Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_collapsed_uintdiv.py: 0%

255 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import importlib.util 

2import os 

3import sys 

4from typing import List, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_utils import IndentedBuffer 

9 

# Tuning lists. NOTE(review): the kernel source emitted by
# generate_collapsed_kernel embeds its own WARP/MEM/BLOCK_SIZE lists instead
# of reading these, so these are presumably consumed elsewhere — confirm.
WARP_LIST = [8, 16, 32, 64]
MEM_LIST = [120 * 1024, 216 * 1024]
BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024, 2048]


# Directory that receives the generated Triton modules. It is resolved
# against the current working directory once, at import time.
CACHE_DIR = os.path.join(os.getcwd(), "__triton_cache__")
# exist_ok=True already tolerates an existing directory (including one created
# concurrently by another process), so a separate exists() check is redundant.
os.makedirs(CACHE_DIR, exist_ok=True)
# Avoid accumulating duplicate entries if this module is executed again.
if CACHE_DIR not in sys.path:
    sys.path.append(CACHE_DIR)

19 

20 

def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative ``dim`` into ``[0, ndim)``.

    Args:
        dim: axis index; negative values count from the end.
        ndim: number of dimensions of the tensor being indexed.

    Returns:
        The equivalent non-negative axis index.

    Raises:
        ValueError: if ``dim`` is outside ``[-ndim, ndim)``. The message
            reports the value the caller passed, not the shifted one.
    """
    wrapped = dim + ndim if dim < 0 else dim
    if not 0 <= wrapped < ndim:
        raise ValueError(
            f"dim={dim} out of range for ndim={ndim} "
            f"(expected {-ndim} <= dim < {ndim})"
        )
    return wrapped

27 

28 

def apply_prefix_narrows(
    inp: torch.Tensor, narrows: List[Tuple[int, int]]
) -> torch.Tensor:
    """Keep only the leading ``new_size`` entries of ``inp`` along listed axes.

    ``narrows`` holds ``(axis, new_size)`` pairs that are applied in order.
    An axis whose current size already equals ``new_size`` is skipped, so
    ``inp`` itself is returned when nothing needs trimming. Trimming uses
    ``Tensor.narrow`` (a view, not a copy).
    """
    result = inp
    for axis, length in narrows:
        if result.shape[axis] != length:
            result = result.narrow(axis, 0, length)
    return result

37 

38 

def can_collapse_axes(
    inp: torch.Tensor, index: torch.Tensor, dim: int
) -> Tuple[bool, List[Tuple[int, int]]]:
    """Check whether gather along ``dim`` can run on the collapsed 3-D kernel.

    Gather along ``d`` produces, for every output coordinate ``t``::

        out[t] = inp[t_0, ..., t_{d-1}, index[t], t_{d+1}, ..., t_{N-1}]

    so on every axis other than ``d`` only ``inp`` coordinates below
    ``index.shape`` are read. The collapsed kernel folds each tensor into
    ``(outer, dim, inner)`` with ``outer = prod(shape[:d])`` and
    ``inner = prod(shape[d+1:])``, and reuses one ``(off_outer, off_inner)``
    pair for ``inp`` and ``index``/``out``. That is only sound when:

    * every axis before ``dim`` is no larger in ``index`` than in ``inp``
      (a smaller one is handled by prefix-narrowing ``inp`` so the outer
      sizes agree), and
    * every axis after ``dim`` has exactly the same size in both.

    Returns:
        ``(ok, narrows)``: ``narrows`` lists the ``(axis, new_size)`` pairs
        to apply to ``inp`` when ``ok`` is true; it is ``[]`` on failure,
        including when the two tensors have different ranks.

    Raises:
        ValueError: if ``dim`` is out of range (raised by ``normalize_dim``).
    """
    if inp.ndim != index.ndim:
        return False, []

    dim = normalize_dim(dim, inp.ndim)
    narrows: List[Tuple[int, int]] = []
    for axis, (inp_size, idx_size) in enumerate(zip(inp.shape, index.shape)):
        inp_size, idx_size = int(inp_size), int(idx_size)
        if axis == dim or inp_size == idx_size:
            continue
        # Sizes differ: only an outer axis that shrinks can be fixed up.
        if axis > dim or idx_size > inp_size:
            return False, []
        narrows.append((axis, idx_size))
    return True, narrows

92 

93 

94LIBDIVIDE_ADD_MARKER = 0x40 

95 

96 

97def _clz32(x: int) -> int: 

98 return 32 - x.bit_length() if x else 32 

99 

100 

101def calc_magic_u32_libdivide(d: int): 

102 """return (magic:uint32, more:uint8)""" 

103 assert 1 <= d <= 0xFFFFFFFF 

104 if (d & (d - 1)) == 0: 

105 shift = d.bit_length() - 1 

106 return 0, shift & 0xFF 

107 floor_log_2_d = 31 - _clz32(d) 

108 two_to = 1 << (32 + floor_log_2_d) 

109 proposed_m = two_to // d 

110 rem = two_to - proposed_m * d 

111 e = d - rem 

112 two_power = 1 << floor_log_2_d 

113 if e < two_power: 

114 magic = (proposed_m + 1) & 0xFFFFFFFF 

115 more = floor_log_2_d & 0xFF 

116 return magic, more 

117 else: 

118 proposed_m2 = proposed_m * 2 

119 twice_rem = rem * 2 

120 if twice_rem >= d or twice_rem < rem: 

121 proposed_m2 += 1 

122 magic = (proposed_m2 + 1) & 0xFFFFFFFF 

123 more = (floor_log_2_d | LIBDIVIDE_ADD_MARKER) & 0xFF 

124 return magic, more 

125 

126 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of a generated Triton module to ``code``.

    Writes the torch / triton / triton.language imports followed by one
    blank line, and returns ``code`` so calls can be chained.
    """
    header = ("import torch", "import triton", "import triton.language as tl")
    for line in header:
        code.writeline(line)
    code.newline()
    return code

133 

134 

def generate_collapsed_device_functions(code: IndentedBuffer) -> IndentedBuffer:
    """Append the ``@triton.jit`` integer-division helpers to ``code``.

    Three helpers are emitted, one per division kind produced by
    ``calc_magic_u32_libdivide``:

    * ``fast_divide_shift`` -- plain right shift;
    * ``fast_divide_mul_noadd`` -- ``umulhi`` by the magic, then shift;
    * ``fast_divide_mul_add`` -- ``umulhi``, the add/halve fix-up, then shift.

    Returns ``code`` so calls can be chained.
    """
    helpers = (
        ("def fast_divide_shift(n, shift):", ("return n >> shift",)),
        (
            "def fast_divide_mul_noadd(n, magic, shift):",
            ("return (tl.umulhi(n, magic) >> shift)",),
        ),
        (
            "def fast_divide_mul_add(n, magic, shift):",
            (
                "q0 = tl.umulhi(n, magic)",
                "t = ((n - q0) >> 1) + q0",
                "return (t >> shift)",
            ),
        ),
    )

    code.writeline("# Device Functions for Fast Division (collapsed path)")
    code.newline()
    for signature, statements in helpers:
        code.writeline("@triton.jit")
        code.writeline(signature)
        with code.indent():
            for stmt in statements:
                code.writeline(stmt)
        code.newline()
    return code

160 

161 

def _collapsed_3d_views(
    inp: torch.Tensor, dim: int, index: torch.Tensor, out: torch.Tensor
):
    """Reshape ``inp``, ``index`` and ``out`` into ``(outer, dim, inner)`` form.

    ``outer`` is the product of the sizes before ``dim`` and ``inner`` the
    product of the sizes after it, computed separately for ``inp`` and for
    ``index``. ``inp`` and ``index`` go through ``.contiguous()`` (which may
    copy) before ``.view``; ``out`` is viewed directly with the index-derived
    shape.

    Returns:
        ``(inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER)``, where
        the three sizes describe the collapsed ``index``/``out`` shape.
    """
    axis = normalize_dim(dim, inp.ndim)

    idx_shape = index.shape
    inp_shape = inp.shape
    dim_len = idx_shape[axis]

    # torch.Size slices support numel(); an empty slice yields 1.
    outer = idx_shape[:axis].numel()
    inner = idx_shape[axis + 1 :].numel()
    src_outer = inp_shape[:axis].numel()
    src_inner = inp_shape[axis + 1 :].numel()

    inp_3d = inp.contiguous().view(src_outer, inp_shape[axis], src_inner)
    idx_3d = index.contiguous().view(outer, dim_len, inner)
    out_3d = out.view(outer, dim_len, inner)

    return inp_3d, idx_3d, out_3d, outer, idx_3d.shape[1], inner

191 

192 

def _fast_divide_line(
    result: str, operand: str, kind: str, magic: str, shift: str
) -> str:
    """Return the generated statement ``result = operand // divisor``.

    ``kind`` selects the device helper: ``"S"`` -> ``fast_divide_shift``,
    ``"A"`` -> ``fast_divide_mul_add``, anything else ->
    ``fast_divide_mul_noadd``. ``magic``/``shift`` name the kernel arguments
    holding the divisor's parameters.
    """
    if kind == "S":
        return f"{result} = fast_divide_shift({operand}, {shift})"
    if kind == "A":
        return f"{result} = fast_divide_mul_add({operand}, {magic}, {shift})"
    return f"{result} = fast_divide_mul_noadd({operand}, {magic}, {shift})"


def generate_collapsed_kernel(
    kernel_name: str, div_kinds: List[str], code: IndentedBuffer
) -> IndentedBuffer:
    """Append the autotuned collapsed-gather Triton kernel ``kernel_name``.

    The generated kernel treats tensors as ``(outer, dim, inner)``. Each
    program handles one contiguous range of ``cdiv(num_elements,
    num_programs)`` flat output positions in ``BLOCK_SIZE`` steps. For every
    flat offset it:

    1. loads the index value;
    2. when ``with_negative_index`` is set, wraps negative values by
       ``INP_SIZE_DIM``, the *input's* size along the gather dim (gather
       indices address ``inp``, whose size there can differ from the index
       tensor's ``SIZE_DIM``);
    3. recovers ``off_inner``/``off_outer`` with two fast divisions (by
       ``SIZE_INNER``, then ``SIZE_DIM``);
    4. loads from ``inp`` via its strides and stores the value.

    Args:
        kernel_name: name of the generated ``@triton.jit`` function.
        div_kinds: ``[inner_kind, dim_kind]``; each is ``"S"``, ``"A"`` or
            any other value for multiply-without-add.
        code: buffer to append to.

    Returns:
        ``code``, for chaining.

    NOTE: ``index`` and ``out`` are addressed with the flat offset, so both
    must be contiguous; their stride arguments are accepted but unused.
    """
    code.newline()

    code.writeline("# Autotune Configuration Lists")
    code.writeline("WARP_LIST = [8, 16, 32]")
    code.writeline("MEM_LIST = [120 * 1024, 216 * 1024]")
    code.writeline("BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024, 2048]")
    code.writeline("REORDER_LIST = [True, False]")
    code.newline()

    code.writeline("@triton.autotune(configs=[")
    with code.indent():
        code.writeline(
            "triton.Config("
            "kwargs={'BLOCK_SIZE': size, 'shared_mem_dynamic_size': localmem, "
            "'enable_simt_reorder_instruction': is_reorder}, num_warps=warp)"
        )
        code.writeline("for warp in WARP_LIST")
        code.writeline("for localmem in MEM_LIST")
        code.writeline("for size in BLOCK_SIZE_LIST")
        code.writeline("for is_reorder in REORDER_LIST")
    code.writeline("],")
    code.writeline("key=['num_elements'], ")
    code.writeline("warmup=25, ")
    code.writeline("rep=100)")

    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = [
            "inp_ptr, ",
            "index_ptr, ",
            "out_ptr, ",
            "SIZE_OUTER, ",
            "SIZE_DIM, ",
            "SIZE_INNER, ",
            # Input size along the gather dim; used only to wrap negatives.
            "INP_SIZE_DIM, ",
            "stride_inp_outer, ",
            "stride_inp_dim, ",
            "stride_inp_inner, ",
            "stride_idx_outer, ",
            "stride_idx_dim, ",
            "stride_idx_inner, ",
            "stride_out_outer, ",
            "stride_out_dim, ",
            "stride_out_inner, ",
            "inner_magic: tl.uint32, ",
            "inner_shift: tl.uint32, ",
            "dim_magic: tl.uint32, ",
            "dim_shift: tl.uint32, ",
            "num_elements, ",
            "with_negative_index: tl.constexpr, ",
            "BLOCK_SIZE: tl.constexpr, ",
        ]
        code.writelines(args)
    code.writeline("):")

    inner_kind = div_kinds[0]
    dim_kind = div_kinds[1]

    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("num_programs = tl.num_programs(0)")
        code.writeline("elements_per_prog = tl.cdiv(num_elements, num_programs)")
        code.writeline("prog_start = pid * elements_per_prog")
        code.writeline(
            "prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)"
        )
        code.newline()

        code.writeline("for block_start in range(prog_start, prog_end, BLOCK_SIZE):")
        with code.indent():
            code.writeline("offsets = block_start + tl.arange(0, BLOCK_SIZE)")
            code.writeline("mask = offsets < prog_end")
            code.newline()
            code.writeline("idx_val = tl.load(index_ptr + offsets, mask=mask, other=0)")
            code.newline()

            code.writeline("if with_negative_index:")
            with code.indent():
                # Wrap against the input's dim size, not the index's.
                code.writeline(
                    "idx_val = tl.where(idx_val < 0, idx_val + INP_SIZE_DIM, idx_val)"
                )
            code.newline()

            # offsets -> (off_outer, off_dim, off_inner)
            # q1 = offsets // SIZE_INNER
            code.writeline(
                _fast_divide_line(
                    "q1", "offsets", inner_kind, "inner_magic", "inner_shift"
                )
            )
            code.writeline("off_inner = offsets - q1 * SIZE_INNER")
            code.writeline("tmp = q1")
            code.newline()

            # q2 = tmp // SIZE_DIM
            code.writeline(
                _fast_divide_line("q2", "tmp", dim_kind, "dim_magic", "dim_shift")
            )
            code.writeline("off_dim = tmp - q2 * SIZE_DIM")
            code.writeline("off_outer = q2")
            code.newline()

            code.writeline("inp_off = (")
            with code.indent():
                code.writeline("off_outer * stride_inp_outer")
                code.writeline("+ idx_val * stride_inp_dim")
                code.writeline("+ off_inner * stride_inp_inner")
            code.writeline(")")
            code.writeline("val = tl.load(inp_ptr + inp_off, mask=mask, other=0.0)")
            code.newline()

            code.writeline("tl.store(out_ptr + offsets, val, mask=mask)")

    code.newline()
    return code


def generate_collapsed_wrapper(
    wrapper_name: str, kernel_name: str, code: IndentedBuffer
) -> IndentedBuffer:
    """Append the host wrapper ``wrapper_name`` that launches ``kernel_name``.

    The generated wrapper receives already-collapsed 3-D ``inp``/``index``/
    ``out`` tensors. It forwards ``index.shape`` as
    ``SIZE_OUTER/SIZE_DIM/SIZE_INNER`` and ``inp.shape[1]`` as
    ``INP_SIZE_DIM``, along with all three stride triples, the division
    parameters and ``num_elements = out.numel()``. It launches the kernel on
    ``grid`` and returns ``out``.

    Returns ``code``, for chaining.
    """
    code.writeline(
        f"def {wrapper_name}("
        "inp, index, out, grid, inner_magic, inner_shift, dim_magic, dim_shift, "
        "with_negative_index):"
    )
    with code.indent():
        code.writeline("inp_shape = inp.shape")
        code.writeline("inp_stride = inp.stride()")
        code.writeline("index_shape = index.shape")
        code.writeline("index_stride = index.stride()")
        code.writeline("out_stride = out.stride()")
        code.writeline("num_elements = out.numel()")
        code.newline()

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            args = [
                "inp, ",
                "index, ",
                "out, ",
                "index_shape[0], ",  # SIZE_OUTER
                "index_shape[1], ",  # SIZE_DIM
                "index_shape[2], ",  # SIZE_INNER
                "inp_shape[1], ",  # INP_SIZE_DIM
                "inp_stride[0], ",
                "inp_stride[1], ",
                "inp_stride[2], ",
                "index_stride[0], ",
                "index_stride[1], ",
                "index_stride[2], ",
                "out_stride[0], ",
                "out_stride[1], ",
                "out_stride[2], ",
                "inner_magic, ",
                "inner_shift, ",
                "dim_magic, ",
                "dim_shift, ",
                "num_elements, ",
                "with_negative_index, ",
                "force_simt_only=False, ",
            ]
            code.writelines(args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    return code

369 

370 

def generate_collapsed_code(
    wrapper_name: str, kernel_name: str, div_kinds: List[str]
) -> str:
    """Return the full source text of one generated collapsed-gather module.

    The module contains, in order: imports, the fast-division device
    helpers, the autotuned kernel ``kernel_name`` specialised for
    ``div_kinds``, and the host wrapper ``wrapper_name``.
    """
    buf = IndentedBuffer()
    for emit in (generate_imports, generate_collapsed_device_functions):
        buf = emit(buf)
    buf = generate_collapsed_kernel(kernel_name, div_kinds, buf)
    buf = generate_collapsed_wrapper(wrapper_name, kernel_name, buf)
    return buf.getvalue()

380 

381 

class CollapsedGatherFunction:
    """Launch the collapsed gather kernel, generating one variant per pattern.

    Each call derives libdivide parameters for the two runtime divisors
    (``index.shape[2]`` and ``index.shape[1]``) and classifies each one as
    ``"S"`` (shift), ``"A"`` (multiply-high with add) or ``"M"``
    (multiply-high). It then dispatches to a generated wrapper for that
    two-letter pattern. Generated modules are written under ``CACHE_DIR``
    and imported from there; loaded wrappers are memoized in
    ``self.overloads``.
    """

    def __init__(self):
        # "collapsed_pat_<pattern>" -> generated wrapper function
        self.overloads = {}

    @staticmethod
    def _div_kind(magic, more) -> str:
        """Classify a libdivide ``(magic, more)`` pair as ``"S"``, ``"A"`` or ``"M"``."""
        if int(magic) == 0:
            return "S"
        return "A" if int(more) & LIBDIVIDE_ADD_MARKER else "M"

    def _load_overload(self, key, pattern, div_kinds):
        """Generate, write and import the module for ``pattern``; return its wrapper."""
        kernel_name = f"_gather_collapsed_kernel_{pattern}"
        wrapper_name = f"_gather_collapsed_wrapper_{pattern}"
        src_code = generate_collapsed_code(wrapper_name, kernel_name, div_kinds)

        file_path = os.path.join(CACHE_DIR, f"{key}.py")
        # CACHE_DIR is shared by every process started from the same working
        # directory. Write to a per-process temp file and atomically replace
        # the target, so a concurrent importer never reads a half-written file.
        tmp_path = f"{file_path}.{os.getpid()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(src_code)
        os.replace(tmp_path, file_path)

        spec = importlib.util.spec_from_file_location(
            f"dynamic_collapsed_mod_{key}", file_path
        )
        mod = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        spec.loader.exec_module(mod)
        return getattr(mod, wrapper_name)

    def __call__(
        self, inp, index, out, grid, magic_shift_map=None, with_negative_index=False
    ):
        """Run the collapsed gather on 3-D ``(outer, dim, inner)`` tensors.

        Args:
            inp, index, out: 3-D tensors; the problem sizes come from ``index``.
            grid: launch grid forwarded to the generated kernel.
            magic_shift_map: optional precomputed
                ``((inner_magic, inner_more), (dim_magic, dim_more))``;
                computed from ``index.shape`` when ``None``.
            with_negative_index: forwarded to the kernel's negative-index wrap.

        Returns:
            ``out``.
        """
        assert inp.ndim == 3
        assert index.ndim == 3
        assert out.ndim == 3

        # Nothing to gather. Returning early also avoids computing libdivide
        # parameters for a zero-sized divisor (calc_magic_u32_libdivide
        # asserts d >= 1) and launching an empty grid.
        if out.numel() == 0:
            return out

        if magic_shift_map is None:
            # Two runtime divisors only: SIZE_INNER and SIZE_DIM.
            inner_magic, inner_more = calc_magic_u32_libdivide(int(index.shape[2]))
            dim_magic, dim_more = calc_magic_u32_libdivide(int(index.shape[1]))
        else:
            (inner_magic, inner_more), (dim_magic, dim_more) = magic_shift_map

        # The low 5 bits of `more` hold the shift amount.
        inner_shift = int(inner_more) & 0x1F
        dim_shift = int(dim_more) & 0x1F

        inner_kind = self._div_kind(inner_magic, inner_more)
        dim_kind = self._div_kind(dim_magic, dim_more)
        pattern = inner_kind + dim_kind
        key = f"collapsed_pat_{pattern}"

        wrapper = self.overloads.get(key)
        if wrapper is None:
            wrapper = self._load_overload(key, pattern, [inner_kind, dim_kind])
            self.overloads[key] = wrapper

        return wrapper(
            inp,
            index,
            out,
            grid,
            inner_magic,
            inner_shift,
            dim_magic,
            dim_shift,
            with_negative_index,
        )

446 

447 

# Process-wide dispatcher; its generated-overload cache is shared by all
# gather_collapsed calls.
collapsed_gather = CollapsedGatherFunction()


def gather_collapsed(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    return_run_kernel: bool = True,
    with_negative_index: bool = False,
):
    """Gather ``inp`` along ``dim`` with ``index`` into ``out`` (collapsed path).

    The tensors are collapsed to ``(outer, dim, inner)`` views up front (this
    may copy ``inp``/``index`` to make them contiguous). Launching the kernel
    is deferred to the returned closure when ``return_run_kernel`` is true.

    NOTE(review): for multi-axis outer shapes where ``index`` is smaller than
    ``inp``, callers presumably check ``can_collapse_axes`` and apply its
    narrows to ``inp`` first — confirm.

    Args:
        inp: source tensor.
        dim: gather axis (negative values count from the end).
        index: index tensor; must have the same shape as ``out``.
        out: destination tensor.
        grid_fn: launch grid forwarded to the kernel.
        return_run_kernel: return a zero-argument launcher instead of running.
        with_negative_index: let the kernel wrap negative index values.

    Returns:
        The launcher when ``return_run_kernel`` is true, otherwise ``None``
        (the kernel has already run and ``out`` holds the result).

    Raises:
        ValueError: if ``out.shape != index.shape`` or ``dim`` is out of range.
    """
    if out.shape != index.shape:
        raise ValueError(f"out.shape {out.shape} must equal index.shape {index.shape}")

    dim = normalize_dim(dim, inp.ndim)

    # The generated kernel stores to `out_ptr + flat_offset`, i.e. it assumes
    # a contiguous output. For a non-contiguous `out` (which `.view` may still
    # accept), compute into a contiguous buffer and copy back afterwards
    # instead of silently writing elements to the wrong positions.
    if out.is_contiguous():
        out_buf = out
    else:
        out_buf = torch.empty_like(out, memory_format=torch.contiguous_format)

    inp_3d, idx_3d, out_3d, _, _, _ = _collapsed_3d_views(inp, dim, index, out_buf)

    def _run_kernel():
        collapsed_gather(
            inp_3d, idx_3d, out_3d, grid_fn, with_negative_index=with_negative_index
        )
        if out_buf is not out:
            out.copy_(out_buf)

    if return_run_kernel:
        return _run_kernel

    _run_kernel()