Coverage for src/flag_gems/runtime/backend/_ascend/ops/cat.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5 

# Per-op logger: the last dotted component of this module's name (the op file)
# is appended to the Ascend ops namespace.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2])

7 

8 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim`` into a newly allocated tensor.

    Follows ``torch.cat`` semantics:
    - 1-D empty tensors (shape ``[0]``) are skipped.
    - The result dtype is the type promotion of every input's dtype.
    - The result never aliases an input, even when only one tensor remains.

    Args:
        A: Non-empty sequence of tensors. All non-skipped tensors must have
            the same rank and identical sizes in every dimension except ``dim``.
        dim: Dimension to concatenate along; negative values count from the end.

    Returns:
        A new tensor on ``A[0].device`` holding the concatenation.

    Raises:
        RuntimeError: if ``A`` is empty, an input is zero-dimensional, ranks
            differ, or sizes differ outside ``dim``.
        IndexError: if ``dim`` is out of range for the inputs' rank.
    """
    logger.debug("GEMS_ASCEND CAT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")

    device = A[0].device
    # Promote across all inputs (including the skipped empties), like torch.cat,
    # instead of silently casting everything to the first tensor's dtype.
    dtype = A[0].dtype
    for t in A[1:]:
        dtype = torch.promote_types(dtype, t.dtype)

    # Skip legacy 1-D empty tensors, but keep each tensor's original position
    # so error messages refer to the caller's indexing.
    kept = [(i, t) for i, t in enumerate(A) if t.shape != torch.Size([0])]
    if not kept:
        return torch.tensor([], device=device, dtype=dtype)

    for i, t in kept:
        if t.ndim == 0:
            raise RuntimeError(
                f"zero-dimensional tensor (at position {i}) cannot be concatenated"
            )

    ndim = kept[0][1].ndim
    # Explicit check rather than assert: asserts are stripped under
    # ``python -O``, and the modulo below would then wrap a bad dim silently.
    if not (-ndim <= dim < ndim):
        raise IndexError(f"Invalid dim: {dim}")
    dim = dim % ndim

    inp0_shape = list(kept[0][1].shape)
    for _, t in kept[1:]:
        if t.ndim != ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {ndim} and {t.ndim}"
            )
    for i, t in kept:
        for d, (common_length, length) in enumerate(zip(inp0_shape, t.shape)):
            if d != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{i} in the list"
                )

    out_shape = list(inp0_shape)
    out_shape[dim] = sum(t.shape[dim] for _, t in kept)
    # Always allocate, even for a single tensor, so the result never aliases
    # an input (mutating the result must not mutate the caller's tensor).
    out = torch.empty(out_shape, dtype=dtype, device=device)
    _cat_fill(out, [t for _, t in kept], dim)
    return out

51 

52 

53def _cat_fill(out, A, dim): 

54 idx = [slice(None)] * out.ndim 

55 offset = 0 

56 for a in A: 

57 a = a.contiguous() 

58 idx[dim] = slice(offset, offset + a.shape[dim]) 

59 out[tuple(idx)] = a 

60 offset += a.shape[dim] 

61 

62 

def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim`` into ``out``, in place.

    ``out`` is resized to the concatenated shape and keeps its own dtype.
    Values are converted to that dtype by the copy. As with ``torch.cat``,
    1-D empty tensors (shape ``[0]``) are skipped.

    Args:
        A: Non-empty sequence of tensors. All non-skipped tensors must have
            the same rank and identical sizes in every dimension except ``dim``.
        dim: Dimension to concatenate along; negative values count from the end.
        out: Destination tensor; resized in place.

    Returns:
        ``out``.

    Raises:
        RuntimeError: if ``A`` is empty, an input is zero-dimensional, ranks
            differ, or sizes differ outside ``dim``.
        IndexError: if ``dim`` is out of range for the inputs' rank.
    """
    logger.debug("GEMS_ASCEND CAT_OUT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")

    # Skip legacy 1-D empty tensors, but keep each tensor's original position
    # so error messages refer to the caller's indexing.
    kept = [(i, t) for i, t in enumerate(A) if t.shape != torch.Size([0])]
    if not kept:
        out.resize_(0)
        return out

    for i, t in kept:
        if t.ndim == 0:
            raise RuntimeError(
                f"zero-dimensional tensor (at position {i}) cannot be concatenated"
            )

    ndim = kept[0][1].ndim
    # Explicit check rather than assert: asserts are stripped under
    # ``python -O``, and the modulo below would then wrap a bad dim silently.
    # This also validates dim when only a single tensor remains.
    if not (-ndim <= dim < ndim):
        raise IndexError(f"Invalid dim: {dim}")
    dim = dim % ndim

    inp0_shape = list(kept[0][1].shape)
    for _, t in kept[1:]:
        if t.ndim != ndim:
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {ndim} and {t.ndim}"
            )
    for i, t in kept:
        for d, (common_length, length) in enumerate(zip(inp0_shape, t.shape)):
            if d != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{i} in the list"
                )

    out_shape = list(inp0_shape)
    out_shape[dim] = sum(t.shape[dim] for _, t in kept)
    # NOTE(review): there is no check that `out` does not overlap an input.
    # torch.cat rejects that case (confirm); here resize_ followed by the fill
    # could read data that has already been overwritten.
    out.resize_(out_shape)
    _cat_fill(out, [t for _, t in kept], dim)
    return out