Coverage for src/flag_gems/ops/stack.py: 60%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

# Copies up to four flat input buffers into their slots of the stacked output.
#
# Program axis 0 walks BLOCK_X-sized tiles of one input's elements; program
# axis 1 picks which of the four (pointer, slot offset, element count) triples
# this program works on. Each input is read linearly (in_ptr + idx), and the
# output is addressed as a flat buffer laid out as
# (outer, dim_size_out, dim_prod_post).
@triton.jit
def stack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    tile_id = tl.program_id(0)
    slot = tl.program_id(1)

    # Select the input handled by this program from its position on axis 1.
    if slot == 0:
        src_ptr = in_ptr_a
        slot_offset = dim_offset_a
        src_numel = total_elements_a
    elif slot == 1:
        src_ptr = in_ptr_b
        slot_offset = dim_offset_b
        src_numel = total_elements_b
    elif slot == 2:
        src_ptr = in_ptr_c
        slot_offset = dim_offset_c
        src_numel = total_elements_c
    else:
        src_ptr = in_ptr_d
        slot_offset = dim_offset_d
        src_numel = total_elements_d

    # Flat element indices of this tile, computed in int64.
    lanes = tl.arange(0, BLOCK_X).to(tl.int64)
    src_idx = tile_id.to(tl.int64) * BLOCK_X + lanes

    # Adding an int64 zero vector turns every scalar into a BLOCK_X-wide value.
    # NOTE(review): presumably done so the index math below runs on int64
    # vectors regardless of how the scalar arguments were typed - confirm.
    zero_vec = lanes * 0
    out_dim_vec = dim_size_out + zero_vec
    inner_vec = dim_prod_post + zero_vec
    slot_vec = slot_offset + zero_vec
    numel_vec = src_numel + zero_vec

    # Lanes past the end of the input neither load nor store.
    in_bounds = src_idx < numel_vec

    # Split the flat input index into (outer, inner) around the stacked dim.
    outer_idx = src_idx // inner_vec
    inner_idx = src_idx % inner_vec

    # Re-linearise with the slot index inserted between outer and inner.
    dst_idx = outer_idx * out_dim_vec * inner_vec + slot_vec * inner_vec + inner_idx

    values = tl.load(src_ptr + src_idx, mask=in_bounds)
    tl.store(out_ptr + dst_idx, values, mask=in_bounds)

71 

72 

73def stack( 

74 tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0 

75) -> torch.Tensor: 

76 logger.debug("GEMS STACK") 

77 

78 if len(tensors) == 0: 

79 raise RuntimeError("stack expected a non-empty TensorList") 

80 

81 inp_shapes = [list(_.shape) for _ in tensors] 

82 inp0_shape = inp_shapes[0] 

83 for i, s in enumerate(inp_shapes[1:]): 

84 if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()): 

85 raise IndexError( 

86 "Dimension out of range (expected to be in range of [{}, {}], but got {})".format( 

87 -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim 

88 ) 

89 ) 

90 if s != inp0_shape: 

91 raise RuntimeError( 

92 f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}" 

93 ) 

94 

95 if dim < 0: 

96 dim = dim + len(inp0_shape) + 1 

97 

98 # Type promotion: find the common dtype for all tensors 

99 dtypes = [t.dtype for t in tensors] 

100 dtype = dtypes[0] 

101 for dt in dtypes[1:]: 

102 dtype = torch.promote_types(dtype, dt) 

103 # Convert all tensors to the result dtype if needed 

104 tensors = [t.to(dtype) if t.dtype != dtype else t for t in tensors] 

105 device = tensors[0].device 

106 out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:] 

107 out = torch.empty(out_shape, dtype=dtype, device=device) 

108 

109 dim_prod_post = 1 

110 for s in inp0_shape[dim:]: 

111 dim_prod_post *= s 

112 

113 BLOCK = 1024 

114 i = 0 

115 while i < len(tensors): 

116 tensors_in_batch = tensors[i : i + 4] 

117 num_tensors_in_batch = len(tensors_in_batch) 

118 

119 args = [] 

120 total_elements_list = [] 

121 

122 for j in range(4): 

123 if j < num_tensors_in_batch: 

124 tensor = tensors_in_batch[j].contiguous() 

125 total_elements = tensor.numel() 

126 args.extend([tensor, i + j, total_elements]) 

127 total_elements_list.append(total_elements) 

128 else: 

129 args.extend([tensors_in_batch[0], 0, 0]) 

130 total_elements_list.append(0) 

131 

132 dim_size_out = len(tensors) 

133 

134 grid_y = num_tensors_in_batch 

135 max_elements_in_batch = tensors[0].numel() if total_elements_list else 0 

136 grid = (triton.cdiv(max_elements_in_batch, BLOCK), grid_y) 

137 

138 ( 

139 tensor_a, 

140 dim_offset_a, 

141 total_elements_a, 

142 tensor_b, 

143 dim_offset_b, 

144 total_elements_b, 

145 tensor_c, 

146 dim_offset_c, 

147 total_elements_c, 

148 tensor_d, 

149 dim_offset_d, 

150 total_elements_d, 

151 ) = args 

152 

153 stack_copy_func_kernel_4[grid]( 

154 out, 

155 tensor_a, 

156 tensor_b, 

157 tensor_c, 

158 tensor_d, 

159 dim_size_out, 

160 dim_prod_post, 

161 dim_offset_a, 

162 dim_offset_b, 

163 dim_offset_c, 

164 dim_offset_d, 

165 total_elements_a, 

166 total_elements_b, 

167 total_elements_c, 

168 total_elements_d, 

169 BLOCK_X=BLOCK, 

170 ) 

171 i += num_tensors_in_batch 

172 

173 return out