Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_collapsed.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1from typing import List, Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Autotune search space for gather_kernel_collapsed. The kernel's
# @triton.autotune decorator takes the cross product of these four lists
# (4 * 2 * 7 * 2 = 112 configurations).
# Candidate values for triton.Config(num_warps=...).
WARP_LIST = [8, 16, 32, 64]
# Candidate values for the "enable_simt_reorder_instruction" config kwarg.
REORDER_LIST = [True, False]
# Candidate values for the "shared_mem_dynamic_size" config kwarg
# (presumably bytes, i.e. 120 KiB and 216 KiB -- confirm against the backend).
MEM_LIST = [120 * 1024, 216 * 1024]
# Candidate values for the kernel's BLOCK_SIZE constexpr.
BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024, 2048]

11 

12 

def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative axis index onto the range ``[0, ndim)``.

    Args:
        dim: Axis index; a negative value counts from the last axis.
        ndim: Number of axes of the tensor being indexed.

    Returns:
        The equivalent non-negative axis index.

    Raises:
        ValueError: If ``dim`` lies outside ``[-ndim, ndim)``. The message
            reports the value the caller passed, not the shifted one.
    """
    wrapped = dim + ndim if dim < 0 else dim
    if not 0 <= wrapped < ndim:
        # Report the caller's original value: the shifted one (e.g. -2 for
        # dim=-5, ndim=3) is misleading when debugging.
        raise ValueError(f"dim={dim} out of range for ndim={ndim}")
    return wrapped

19 

20 

def apply_prefix_narrows(
    inp: torch.Tensor, narrows: List[Tuple[int, int]]
) -> torch.Tensor:
    """Restrict ``inp`` to a leading slice along each listed axis.

    Args:
        inp: Tensor to narrow.
        narrows: ``(axis, new_size)`` pairs, applied in order. Each pair keeps
            elements ``0 .. new_size - 1`` along ``axis``.

    Returns:
        The narrowed tensor, produced with ``Tensor.narrow``. When every
        requested size already matches, ``inp`` itself is returned.
    """
    for axis, new_size in narrows:
        # Skip no-op narrows so an already matching tensor is passed through
        # as the very same object.
        if inp.shape[axis] != new_size:
            inp = inp.narrow(axis, 0, new_size)
    return inp

29 

30 

def can_collapse_axes(
    inp: torch.Tensor, index: torch.Tensor, dim: int
) -> Tuple[bool, List[Tuple[int, int]]]:
    """Decide whether the collapsed (3D) gather kernel can be used.

    Gather definition (dim = d):
        Y[t0..tN-1] =
            inp[t0..t_{d-1}, index[t0..tN-1], t_{d+1}..t_{N-1}]

    Shape constraints of gather:
        - For i != d: index.shape[i] <= inp.shape[i]
        - The output only reads inp at coordinates 0 <= t_i < index.shape[i]

    Collapsed kernel assumption:
        The tensor is folded into (Outer, Dim, Inner) with
            Outer = prod_{i<d} shape[i]
            Inner = prod_{i>d} shape[i]
        and the same (off_outer, off_inner) pair must address matching
        positions in inp and in index/out.

    Policy:
        - Axes before ``dim`` (outer side): index.shape[i] <= inp.shape[i] is
          accepted. A strictly smaller extent is recorded so that inp can be
          prefix-narrowed until the outer extents match.
        - Axes after ``dim`` (inner side): the extents must be equal.
        - The ``dim`` axis itself is not constrained here.

    Args:
        inp: Source tensor of the gather.
        index: Index tensor of the gather.
        dim: Gather axis; negative values count from the last axis.

    Returns:
        ``(ok, narrows)``. ``ok`` tells whether the collapsed kernel applies.
        ``narrows`` lists ``(axis, new_size)`` for each outer axis of ``inp``
        that is larger than the matching axis of ``index``; it is empty
        whenever ``ok`` is False.

    Raises:
        ValueError: From ``normalize_dim`` when the ranks match but ``dim``
            is out of range.
    """
    if inp.ndim != index.ndim:
        return False, []

    dim = normalize_dim(dim, inp.ndim)
    narrows: List[Tuple[int, int]] = []

    for axis, (inp_extent, idx_extent) in enumerate(zip(inp.shape, index.shape)):
        if axis == dim or idx_extent == inp_extent:
            continue
        # A mismatch is only recoverable on the outer side, and only when the
        # index extent is the smaller one.
        if axis > dim or idx_extent > inp_extent:
            return False, []
        narrows.append((axis, int(idx_extent)))

    return True, narrows

84 

85 

# Gather kernel over tensors collapsed to (outer, dim, inner).
#
# Work split: the flat range [0, num_elements) is divided into one contiguous
# slice of cdiv(num_elements, num_programs) elements per program, and each
# program walks its slice in steps of BLOCK_SIZE.
#
# Addressing: index and out are read/written at `ptr + offsets`, i.e. the flat
# element number is used directly as the memory offset. Only the input is
# addressed through its strides. SIZE_OUTER, stride_idx_* and stride_out_* are
# accepted but never read by the body.
#
# Autotuning: one config per element of the cross product
# WARP_LIST x MEM_LIST x BLOCK_SIZE_LIST x REORDER_LIST, re-tuned whenever
# num_elements changes.
@triton.autotune(
    configs=[
        triton.Config(
            kwargs={
                "BLOCK_SIZE": size,
                "shared_mem_dynamic_size": localmem,
                "enable_simt_reorder_instruction": is_reorder,
            },
            num_warps=warp,
        )
        for warp in WARP_LIST
        for localmem in MEM_LIST
        for size in BLOCK_SIZE_LIST
        for is_reorder in REORDER_LIST
    ],
    key=["num_elements"],
    warmup=25,
    rep=100,
)
@triton.jit
def gather_kernel_collapsed(
    inp_ptr,
    index_ptr,
    out_ptr,
    SIZE_OUTER,
    SIZE_DIM,
    SIZE_INNER,
    stride_inp_outer,
    stride_inp_dim,
    stride_inp_inner,
    stride_idx_outer,
    stride_idx_dim,
    stride_idx_inner,
    stride_out_outer,
    stride_out_dim,
    stride_out_inner,
    num_elements,
    with_negative_index: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    # Contiguous slice of the flat output range owned by this program.
    elements_per_prog = tl.cdiv(num_elements, num_programs)
    prog_start = pid * elements_per_prog
    prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)

    for block_start in range(prog_start, prog_end, BLOCK_SIZE):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Masking against prog_end (not num_elements) keeps the last block of
        # this slice from spilling into the next program's slice.
        mask = offsets < prog_end

        # Widen to int64 so the input offset below is computed in 64 bits.
        idx_val = tl.load(index_ptr + offsets, mask=mask, other=0).to(tl.int64)

        if with_negative_index:
            # NOTE(review): the wrap uses SIZE_DIM, the same extent that
            # splits `offsets` below (the index/out extent along dim). It
            # equals the input extent only when index and inp agree there.
            idx_val = tl.where(idx_val < 0, idx_val + SIZE_DIM, idx_val)
        # Coordinate Reconstruction: (outer, dim, inner)
        off_inner = offsets % SIZE_INNER
        tmp = offsets // SIZE_INNER
        off_outer = tmp // SIZE_DIM

        # Input Offset Calculation: the gathered index replaces the dim
        # coordinate of the output element.
        inp_off = (
            off_outer * stride_inp_outer
            + idx_val * stride_inp_dim
            + off_inner * stride_inp_inner
        )
        val = tl.load(inp_ptr + inp_off, mask=mask, other=0.0)

        # Output Store
        tl.store(out_ptr + offsets, val, mask=mask)

155 

156 

def _collapsed_3d_views(
    inp: torch.Tensor, dim: int, index: torch.Tensor, out: torch.Tensor
):
    """Fold ``inp``, ``index`` and ``out`` into 3D ``(outer, dim, inner)`` form.

    ``outer`` is the product of the extents before ``dim`` and ``inner`` the
    product of the extents after it; an empty product is 1.

    Args:
        inp: Source tensor. Made contiguous before reshaping, which copies
            when it is not already contiguous.
        dim: Gather axis; negative values count from the last axis of ``inp``.
        index: Index tensor. Made contiguous before reshaping.
        out: Destination tensor. Reshaped with ``view`` only, so the result
            aliases ``out``; ``out`` must be view-compatible with the
            collapsed shape.

    Returns:
        ``(inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER)`` where
        the three sizes are the collapsed extents of ``index`` (and ``out``),
        not of ``inp``.

    Raises:
        ValueError: From ``normalize_dim`` when ``dim`` is out of range.
    """
    dim = normalize_dim(dim, inp.ndim)

    # Slicing a torch.Size yields a torch.Size, whose numel() is the product
    # of its entries (1 for an empty slice).
    inp_3d = inp.contiguous().view(
        inp.shape[:dim].numel(), inp.shape[dim], inp.shape[dim + 1 :].numel()
    )

    idx_outer = index.shape[:dim].numel()
    idx_inner = index.shape[dim + 1 :].numel()
    collapsed_shape = (idx_outer, index.shape[dim], idx_inner)
    idx_3d = index.contiguous().view(collapsed_shape)
    out_3d = out.view(collapsed_shape)

    return inp_3d, idx_3d, out_3d, idx_outer, idx_3d.shape[1], idx_inner

186 

187 

def gather_collapsed(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    return_run_kernel: bool = True,
    with_negative_index=False,
):
    """Gather ``inp`` along ``dim`` into ``out`` with the collapsed 3D kernel.

    Presumably the caller has already checked the shapes with
    ``can_collapse_axes`` and applied the returned narrows to ``inp``; nothing
    here re-validates that.

    Args:
        inp: Source tensor.
        dim: Gather axis; negative values count from the last axis of ``inp``.
        index: Index tensor.
        out: Destination tensor; must have exactly ``index``'s shape.
        grid_fn: Launch grid handed to ``gather_kernel_collapsed[grid_fn]``.
        return_run_kernel: When true, nothing is launched here and a
            zero-argument callable performing the launch is returned. When
            false, the kernel is launched immediately.
        with_negative_index: Whether ``index`` may hold negative values that
            count from the end of ``inp`` along ``dim``.

    Returns:
        The launch callable when ``return_run_kernel`` is true, else ``None``.

    Raises:
        ValueError: If ``out.shape != index.shape`` or ``dim`` is out of range.
    """
    if out.shape != index.shape:
        raise ValueError(f"out.shape {out.shape} must equal index.shape {index.shape}")

    dim = normalize_dim(dim, inp.ndim)

    # The kernel wraps negative indices by SIZE_DIM, which is index.shape[dim]
    # because the same value splits flat output offsets into coordinates.
    # When that extent differs from inp.shape[dim] the in-kernel wrap would
    # select the wrong element (or read out of bounds), so wrap against the
    # input extent here and switch the in-kernel wrap off.
    if with_negative_index and index.shape[dim] != inp.shape[dim]:
        index = torch.where(index < 0, index + inp.shape[dim], index)
        with_negative_index = False

    # The kernel stores to out_ptr + flat offset and never reads the out
    # strides, so it needs a contiguous destination. Stage through a
    # contiguous buffer when `out` is not one, and copy back after the launch.
    if out.is_contiguous():
        out_buf = out
    else:
        out_buf = torch.empty_like(out, memory_format=torch.contiguous_format)

    inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER = _collapsed_3d_views(
        inp, dim, index, out_buf
    )
    num_elements = out_3d.numel()

    def _run_kernel():
        gather_kernel_collapsed[grid_fn](
            inp_3d,
            idx_3d,
            out_3d,
            # Shapes
            SIZE_OUTER,
            SIZE_DIM,
            SIZE_INNER,
            # Strides
            inp_3d.stride(0),
            inp_3d.stride(1),
            inp_3d.stride(2),
            idx_3d.stride(0),
            idx_3d.stride(1),
            idx_3d.stride(2),
            out_3d.stride(0),
            out_3d.stride(1),
            out_3d.stride(2),
            # Meta
            num_elements,
            with_negative_index,
            force_simt_only=False,
        )
        if out_buf is not out:
            out.copy_(out_buf)

    if return_run_kernel:
        return _run_kernel

    _run_kernel()