Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_ascend.py: 0%

234 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import functools 

2import importlib.util 

3import logging 

4import os 

5import sys 

6from typing import List, Tuple 

7 

8import torch 

9 

10from flag_gems.utils.code_utils import IndentedBuffer 

11 

logger = logging.getLogger(__name__)

# Bit layout of the libdivide `more` byte produced by calc_magic_u32_libdivide:
LIBDIVIDE_32_SHIFT_MASK = 0x1F  # low 5 bits: shift amount
LIBDIVIDE_ADD_MARKER = 0x40  # bit 6: the division needs the "add" fix-up path

# Generated kernel modules are written into this directory, which is resolved
# against the working directory at import time, created when missing, and
# appended to sys.path.
CACHE_DIR = os.path.join(os.getcwd(), "__triton_cache__")
if not os.path.exists(CACHE_DIR):
    os.makedirs(CACHE_DIR, exist_ok=True)
sys.path.append(CACHE_DIR)

20 

21 

22def _clz32(x: int) -> int: 

23 return 32 - x.bit_length() if x else 32 

24 

25 

def calc_magic_u32_libdivide(d: int) -> Tuple[int, int]:
    """
    Compute libdivide's unsigned 32-bit fast-division parameters for divisor ``d``.

    Returns ``(magic, more)`` where ``magic`` fits in uint32 and ``more`` in uint8:

    - ``magic == 0`` selects the power-of-two path (a plain right shift);
    - the low 5 bits of ``more`` hold the shift amount;
    - bit 6 (``LIBDIVIDE_ADD_MARKER``, 0x40) of ``more`` flags the "add" path.

    Raises:
        ValueError: if ``d`` is outside ``[1, 2**32 - 1]``.
    """
    if not 1 <= d <= 0xFFFFFFFF:
        raise ValueError(f"d must be in [1, 2^32-1], got {d}")

    floor_log_2_d = 31 - _clz32(d)

    # Power of two: division is just a shift by log2(d).
    if (d & (d - 1)) == 0:
        return 0, floor_log_2_d & 0xFF

    # Quotient and remainder of 2^(32 + floor_log_2_d) / d.
    proposed_m, rem = divmod(1 << (32 + floor_log_2_d), d)

    if d - rem < (1 << floor_log_2_d):
        # The smaller power already gives an exact magic number.
        return (proposed_m + 1) & 0xFFFFFFFF, floor_log_2_d & 0xFF

    # General 33-bit case: double quotient and remainder to obtain
    # 2^(33 + floor_log_2_d) / d, correcting the quotient when the doubled
    # remainder reaches d. (The C original's overflow test on the doubled
    # remainder can never fire with Python's unbounded ints.)
    doubled_m = 2 * proposed_m + (1 if 2 * rem >= d else 0)
    return (doubled_m + 1) & 0xFFFFFFFF, (floor_log_2_d | LIBDIVIDE_ADD_MARKER) & 0xFF

67 

68 

@functools.lru_cache(maxsize=128)
def get_all_magics(shape_tuple: Tuple[int, ...]) -> Tuple[List[int], List[int]]:
    """Return per-dimension libdivide ``(magics, mores)`` lists for ``shape_tuple``.

    Results are memoized (up to 128 shapes). The returned lists are the cached
    objects themselves, so callers must not mutate them.
    """
    params = [calc_magic_u32_libdivide(int(size)) for size in shape_tuple]
    magics = [magic for magic, _ in params]
    mores = [more for _, more in params]
    return magics, mores

77 

78 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of the generated module, then a blank line; return ``code``."""
    for import_line in ("import torch", "import triton", "import triton.language as tl"):
        code.writeline(import_line)
    code.newline()
    return code

85 

86 

def generate_device_functions(code: IndentedBuffer) -> IndentedBuffer:
    """Append the three ``@triton.jit`` fast-division helpers and return ``code``.

    The helpers implement libdivide's unsigned 32-bit quotient for each
    parameter kind. ``shift`` must be the plain shift amount; the add-marker
    flag is not stripped inside the helpers.
    """
    helpers = (
        # shift-only path (magic == 0)
        ("def fast_divide_shift(n, shift):", ("return n >> shift",)),
        # multiply-high without the add fix-up (add marker clear)
        (
            "def fast_divide_mul_noadd(n, magic, shift):",
            ("return (tl.umulhi(n, magic) >> shift)",),
        ),
        # multiply-high with the add fix-up (add marker set)
        (
            "def fast_divide_mul_add(n, magic, shift):",
            (
                "q0 = tl.umulhi(n, magic)",
                "t = ((n - q0) >> 1) + q0",
                "return (t >> shift)",
            ),
        ),
    )

    code.writeline(
        "# Device Functions for Fast Division (assume uint32 inputs, no casts/masks)"
    )
    code.newline()
    for signature, body_lines in helpers:
        code.writeline("@triton.jit")
        code.writeline(signature)
        with code.indent():
            for body_line in body_lines:
                code.writeline(body_line)
        code.newline()
    return code

117 

118 

def generate_gather_kernel(
    rank: int, kernel_name: str, div_kinds: List[str], code: IndentedBuffer
) -> IndentedBuffer:
    """Append the autotuned Triton gather kernel ``kernel_name`` for ``rank`` dims.

    Each program handles a contiguous slice of flat element offsets. For every
    offset, the kernel recovers per-dimension coordinates by repeatedly
    dividing by ``index_shape{i}`` (last dimension first), using the libdivide
    parameters ``index_magic{i}`` / ``index_more{i}`` and the path chosen by
    ``div_kinds[i]``:

    - ``"S"``: shift only (power-of-two size)
    - ``"A"``: multiply-high with the add fix-up
    - anything else: multiply-high without the add fix-up

    ``div_kinds[0]`` is never consulted: after peeling the other dimensions,
    the remaining offset already is the dimension-0 coordinate.

    NOTE(review): ``index`` is loaded and ``out`` is stored at the flat
    ``offsets`` (``index_stride*`` / ``out_stride*`` are accepted but unused),
    so both are assumed contiguous. Indices are truncated to int32 and offsets
    are int32, so very large tensors may overflow. Confirm with callers.

    Args:
        rank: number of dimensions of ``inp`` / ``index``.
        kernel_name: name of the generated ``@triton.jit`` function.
        div_kinds: division path code per dimension, indexed by dimension.
        code: buffer the kernel source is appended to.

    Returns:
        ``code`` with the kernel appended.
    """
    code.newline()

    # Autotune lists
    code.writeline("# Autotune Configuration Lists")
    code.writeline("WARP_LIST = [8, 16, 32]")
    code.writeline("MEM_LIST = [120 * 1024, 216 * 1024]")
    code.writeline("BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024]")
    code.writeline("REORDER_LIST = [True, False]")
    code.newline()

    # One config per (warps, local memory, block size, reorder) combination,
    # re-tuned whenever num_elements changes.
    code.writeline("@triton.autotune(configs=[")
    with code.indent():
        code.writeline(
            "triton.Config("
            "kwargs={'BLOCK_SIZE': size, 'shared_mem_dynamic_size': localmem, "
            "'enable_simt_reorder_instruction': is_reorder}, num_warps=warp)"
        )
        code.writeline("for warp in WARP_LIST")
        code.writeline("for localmem in MEM_LIST")
        code.writeline("for size in BLOCK_SIZE_LIST")
        code.writeline("for is_reorder in REORDER_LIST")
    code.writeline("],")
    code.writeline("key=['num_elements'], ")
    code.writeline("warmup=25, ")
    code.writeline("rep=100) ")

    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = [
            "inp_ptr, ",
            "index_ptr, ",
            "out_ptr, ",
        ]
        # Unroll shapes and strides in signature to avoid metadata loads
        args += [f"inp_shape{i}, " for i in range(rank)]
        args += [f"index_shape{i}, " for i in range(rank)]

        # libdivide params per dimension
        args += [f"index_magic{i}: tl.uint32, " for i in range(rank)]
        args += [f"index_more{i}: tl.uint32, " for i in range(rank)]

        args += [f"inp_stride{i}, " for i in range(rank)]
        args += [f"index_stride{i}, " for i in range(rank)]
        args += [f"out_stride{i}, " for i in range(rank)]

        args += [
            "dim: tl.constexpr, ",
            "num_elements, ",
            "with_negative_index: tl.constexpr, ",
            "BLOCK_SIZE: tl.constexpr, ",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        # Each program owns the flat range [prog_start, prog_end).
        code.writeline("pid = tl.program_id(0)")
        code.writeline("num_programs = tl.num_programs(0)")
        code.writeline("elements_per_prog = tl.cdiv(num_elements, num_programs)")
        code.writeline("prog_start = pid * elements_per_prog")
        code.writeline(
            "prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)"
        )
        code.newline()

        code.writeline(
            "# Block-Stride Loop (Processing contiguous chunks for better cache hit rate)"
        )
        code.writeline("for block_start in range(prog_start, prog_end, BLOCK_SIZE):")
        with code.indent():
            code.writeline("offsets = block_start + tl.arange(0, BLOCK_SIZE)")
            # Bound by this program's range rather than the whole tensor.
            # Otherwise the last block of a program spills into the next
            # program's range, and those elements are gathered and stored twice.
            code.writeline("mask = offsets < prog_end")
            code.newline()

            code.writeline(
                "gather_index = tl.load(index_ptr + offsets, mask=mask, other=0).to(tl.int32)"
            )

            code.writeline("base_inp_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int32)")
            code.writeline("cur_offset = offsets.to(tl.int32)")
            code.newline()

            code.writeline("dim_stride = 0")
            code.writeline("dim_size = 0")
            code.newline()

            for i in range(rank - 1, -1, -1):
                if i == 0:
                    # After processing dims [rank-1 .. 1], cur_offset is already in [0, index_shape0).
                    # So: next_offset = cur_offset // index_shape0 == 0, coord_0 == cur_offset.
                    code.writeline("coord_0 = cur_offset")
                    code.writeline("cur_offset = 0")
                else:
                    # index_more{i} carries the add-marker flag (bit 6) on top
                    # of the shift amount (low 5 bits). Strip the flag before
                    # shifting, or the "A" path shifts by 64 + k bits.
                    code.writeline(
                        f"shift = index_more{i} & {LIBDIVIDE_32_SHIFT_MASK}"
                    )
                    if div_kinds[i] == "S":
                        code.writeline(
                            "next_offset = fast_divide_shift(cur_offset, shift)"
                        )
                    elif div_kinds[i] == "A":
                        code.writeline(f"magic = index_magic{i}")
                        code.writeline(
                            "next_offset = fast_divide_mul_add(cur_offset, magic, shift)"
                        )
                    else:
                        code.writeline(f"magic = index_magic{i}")
                        code.writeline(
                            "next_offset = fast_divide_mul_noadd(cur_offset, magic, shift)"
                        )

                    # Remainder of the division = coordinate along dim i.
                    code.writeline(
                        f"coord_{i} = cur_offset - next_offset * index_shape{i}"
                    )
                    code.writeline("cur_offset = next_offset")

                # The gathered dimension contributes via the loaded index; all
                # other dimensions contribute their own coordinate.
                code.writeline(f"if dim == {i}:")
                with code.indent():
                    code.writeline(f"dim_stride = inp_stride{i}")
                    code.writeline(f"dim_size = inp_shape{i}")
                code.writeline("else:")
                with code.indent():
                    code.writeline(f"base_inp_offset += coord_{i} * inp_stride{i}")
                code.newline()

            code.writeline("# Handle negative indices")
            code.writeline("if with_negative_index:")
            with code.indent():
                code.writeline(
                    "gather_index = tl.where(gather_index < 0, gather_index + dim_size, gather_index).to(tl.int32)"
                )

            code.writeline(
                "final_inp_offset = base_inp_offset + gather_index * dim_stride"
            )
            code.writeline(
                "val = tl.load(inp_ptr + final_inp_offset, mask=mask, other=0.0)"
            )
            code.writeline("tl.store(out_ptr + offsets, val, mask=mask)")

    code.newline()
    return code

262 

263 

def generate_gather_wrapper(
    rank: int, wrapper_name: str, kernel_name: str, code: IndentedBuffer
) -> IndentedBuffer:
    """Append the host-side launcher ``wrapper_name`` and return ``code``.

    The launcher reads shapes and strides from its tensor arguments and calls
    ``kernel_name[grid](...)``. Per-dimension values are passed positionally
    in the kernel's signature order, and the launcher returns ``out``.
    """
    # Per-dimension argument groups, in the kernel's parameter order.
    per_dim_groups = (
        "inp_shape",
        "index_shape",
        "magic",
        "more",
        "inp_stride",
        "index_stride",
        "out_stride",
    )
    launch_args = ["inp, ", "index, ", "out, "]
    for group in per_dim_groups:
        launch_args.extend(f"{group}[{i}], " for i in range(rank))
    launch_args.extend(
        [
            "dim, ",
            "num_elements, ",
            "with_negative_index, ",
            "force_simt_only=False, ",
        ]
    )

    prologue = (
        "inp_shape = inp.shape",
        "inp_stride = inp.stride()",
        "index_shape = index.shape",
        "index_stride = index.stride()",
        "out_stride = out.stride()",
        "num_elements = index.numel()",
    )

    code.writeline(
        f"def {wrapper_name}(inp, dim, index, out, grid, magic, more, with_negative_index):"
    )
    with code.indent():
        for statement in prologue:
            code.writeline(statement)
        code.newline()

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writelines(launch_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    return code

310 

311 

def generate_code(
    inputs, wrapper_name: str, kernel_name: str, div_kinds: List[str]
) -> str:
    """Return the full source of a generated gather module.

    The module contains the imports, the fast-division helpers, the kernel
    ``kernel_name`` and the launcher ``wrapper_name``. Its rank is taken from
    ``inputs[0].ndim``.
    """
    rank = inputs[0].ndim
    buf = IndentedBuffer()
    buf = generate_imports(buf)
    buf = generate_device_functions(buf)
    buf = generate_gather_kernel(rank, kernel_name, div_kinds, buf)
    buf = generate_gather_wrapper(rank, wrapper_name, kernel_name, buf)
    return buf.getvalue()

322 

323 

class GatherFunction:
    """Dispatch gathers to generated kernels specialised per rank and division pattern.

    For each new ``(rank, pattern)`` key, the class generates module source
    with ``generate_code``, writes it to ``CACHE_DIR``, imports it, and caches
    the launcher in ``self.overloads``.
    """

    def __init__(self):
        # "gather_rank_{rank}_pat_{pattern}" -> generated launcher function
        self.overloads = {}
        # NOTE(review): nothing in this module populates this dict, so
        # get_kernel() currently always returns None.
        self.kernels = {}

    def __call__(
        self, inp, dim, index, out, grid, magic_map=None, with_negative_index=False
    ):
        """Gather ``inp`` along ``dim`` at ``index`` into ``out`` and return ``out``.

        Args:
            inp: source tensor.
            dim: dimension to gather along (a constexpr of the kernel).
            index: index tensor; its shape drives the coordinate decomposition.
            out: destination tensor, written at the flat offsets of ``index``.
            grid: Triton launch grid, forwarded unchanged to the kernel launch.
            magic_map: optional precomputed ``(magic, more)`` sequences for
                ``index.shape``; computed with ``get_all_magics`` when ``None``.
            with_negative_index: when true, negative indices are wrapped by
                the size of ``inp`` along ``dim``.

        Returns:
            ``out``.
        """
        # Nothing to gather. Returning early also avoids computing libdivide
        # parameters for a zero-sized dimension, which
        # calc_magic_u32_libdivide rejects with ValueError.
        if index.numel() == 0:
            return out

        rank = inp.ndim

        if magic_map is None:
            magic, more = get_all_magics(tuple(index.shape))
        else:
            magic, more = magic_map

        # div_kinds: 'S' (shift-only), 'M' (mul-noadd), 'A' (mul-add)
        div_kinds = [self._div_kind(m, mo) for m, mo in zip(magic, more)]
        pattern = "".join(div_kinds)
        key = f"gather_rank_{rank}_pat_{pattern}"

        wrapper = self.overloads.get(key)
        if wrapper is None:
            kernel_name = f"_gather_kernel_{rank}"
            wrapper_name = f"_gather_wrapper_{rank}"
            src_code = generate_code([inp], wrapper_name, kernel_name, div_kinds)
            wrapper = self._load_wrapper(key, src_code, wrapper_name)
            self.overloads[key] = wrapper

        return wrapper(inp, dim, index, out, grid, magic, more, with_negative_index)

    @staticmethod
    def _div_kind(magic, more) -> str:
        """Classify one dimension's libdivide params: 'S' shift-only, 'A' mul-add, 'M' mul-noadd."""
        if int(magic) == 0:
            return "S"
        if int(more) & LIBDIVIDE_ADD_MARKER:
            return "A"
        return "M"

    @staticmethod
    def _load_wrapper(key, src_code, wrapper_name):
        """Write ``src_code`` to ``CACHE_DIR/{key}.py``, import it, and return ``wrapper_name`` from it.

        Raises:
            ImportError: if no loader can be created for the written file.
        """
        import threading

        file_path = os.path.join(CACHE_DIR, f"{key}.py")
        # Write to a private temp file and rename it into place atomically.
        # A concurrent writer (e.g. another process sharing the working
        # directory) then cannot expose a truncated module to an importer.
        tmp_path = f"{file_path}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(src_code)
            os.replace(tmp_path, file_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        spec = importlib.util.spec_from_file_location(f"dynamic_mod_{key}", file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load generated gather module {file_path}")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return getattr(mod, wrapper_name)

    def get_kernel(self, rank: int):
        """Return the kernel registered under ``gather_rank_{rank}``, or ``None``."""
        return self.kernels.get(f"gather_rank_{rank}")

378 

379 

_gather_func = GatherFunction()


def gather(
    inp,
    dim: int,
    index,
    out=None,
    grid_fn=None,
    magic_map=None,
    with_negative_index=False,
):
    """Gather values of ``inp`` along ``dim`` at the positions in ``index``.

    Args:
        inp: source tensor.
        dim: dimension of ``inp`` to gather along.
        index: index tensor; the result has its shape.
        out: optional destination of ``index``'s shape. When ``None``, a
            contiguous tensor with ``inp``'s dtype and device is allocated.
        grid_fn: Triton launch grid forwarded to the kernel launch.
            NOTE(review): the default ``None`` is forwarded as-is; confirm
            callers always provide a grid.
        magic_map: optional precomputed libdivide ``(magic, more)`` for
            ``index.shape``.
        with_negative_index: wrap negative indices by ``inp.shape[dim]``.

    Returns:
        ``out``.
    """
    # The generated kernel reads ``index`` and writes its destination at flat
    # logical offsets, i.e. it assumes both are contiguous. contiguous() is a
    # no-op for tensors that already are.
    index = index.contiguous()
    if out is None:
        # empty_like(index) would copy a non-contiguous index layout; allocate
        # a plain contiguous tensor instead.
        out = torch.empty(index.shape, dtype=inp.dtype, device=inp.device)

    if index.numel() == 0:
        return out

    # A user-supplied non-contiguous ``out`` is filled through a contiguous
    # staging tensor and copied back afterwards.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    _gather_func(
        inp,
        dim,
        index,
        target,
        grid_fn,
        magic_map=magic_map,
        with_negative_index=with_negative_index,
    )
    if target is not out:
        out.copy_(target)
    return out