Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/index_put.py: 0%

282 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

# Child of the "flag_gems" logger, named after this module; index_put and
# index_put_ emit their debug traces through it.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the common broadcast shape of the tensor entries of ``indices``.

    ``None`` entries (basic-indexing markers) are ignored. Shapes are combined
    with PyTorch's broadcasting rules. A size-0 dimension against a size-1
    dimension therefore yields 0, where an element-wise ``max`` would wrongly
    give 1.

    Args:
        indices: sequence of index tensors and/or ``None``.

    Returns:
        The broadcast shape as a list of ints; ``[]`` when there are no tensor
        entries (or when they are all 0-d).

    Raises:
        RuntimeError: if the index shapes are not broadcast-compatible.
    """
    tensor_indices = [idx for idx in indices if idx is not None]
    if not tensor_indices:
        return []
    return list(torch.broadcast_shapes(*(idx.shape for idx in tensor_indices)))

29 

30 

def broadcast_indices(indices, target_shape):
    """Broadcast, in place, every tensor in ``indices`` to ``target_shape``.

    ``None`` entries and tensors that already have exactly ``target_shape``
    are left untouched (same object); every other entry is replaced by a
    ``torch.broadcast_to`` view. Nothing is returned.
    """
    wanted = tuple(target_shape)
    for pos in range(len(indices)):
        entry = indices[pos]
        if entry is None:
            continue
        if tuple(entry.shape) == wanted:
            continue
        indices[pos] = torch.broadcast_to(entry, target_shape)

35 

36 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the prelude shared by every generated index_put module.

    Emits the imports the generated kernel and launcher rely on, followed by
    the two ``triton.heuristics`` callbacks ``heur_block_m`` and
    ``heur_block_n``. These pick BLOCK_SIZE0 / BLOCK_SIZE1 from the launch
    arguments ``M`` and ``N``.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    emit = code.writeline

    for stmt in ("import triton", "import triton.language as tl", "import builtins"):
        emit(stmt)
    code.newline()
    for stmt in (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
    ):
        emit(stmt)
    code.newline()
    code.newline()

    # BLOCK_SIZE0: next power of two of ceil(M / 12); 2 when M is 0.
    emit("def heur_block_m(args):")
    with code.indent():
        emit('if args["M"] == 0:')
        with code.indent():
            emit("return 2")
        emit('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')
    code.newline()

    # BLOCK_SIZE1: next power of two of N, capped at 8192.
    emit("def heur_block_n(args):")
    with code.indent():
        emit('return builtins.min(triton.next_power_of_2(args["N"]), 8192)')
    code.newline()
    code.newline()
    return code

65 

66 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Append the Triton ``index_put`` kernel source to ``code``.

    The kernel works on a 2-D grid:

    - Axis 0 walks ``M`` positions of the index tensors. The launcher passes
      the element count of the first index tensor.
    - Axis 1 walks ``N`` elements of the non-indexed trailing dims. The
      launcher passes the volume of ``input.shape[indices_len:]``.

    Offsets on axis 0 are unravelled with ``indices0_shape*``. All index
    tensors are therefore assumed to share the first one's shape; callers
    broadcast them beforehand.

    Negative index values are wrapped by the size of the indexed dimension,
    as in PyTorch. Values that are still out of range are masked out, so they
    are skipped rather than raising.

    Args:
        inp_rank: number of dimensions of the destination tensor.
        indices_len: number of index tensors; they address the leading
            ``indices_len`` dims of the destination.
        index_rank: rank shared by all index tensors (may be 0).
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    # Used when an offset sum would otherwise be empty (0-d index tensors):
    # it is zero but keeps the block shape of ``offset0``, so tl.load still
    # receives a block of pointers matching its block mask.
    zero_offset = "offset0 * 0"

    code.writeline("@libentry()")
    # Block sizes come from the heur_block_m / heur_block_n heuristics emitted
    # by generate_imports.
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_SIZE0": heur_block_m,')
            code.writeline('"BLOCK_SIZE1": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}: tl.constexpr," for j in range(index_rank)]
        args += [f"input_stride{i}: tl.constexpr," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}: tl.constexpr," for j in range(index_rank)]
        # values has one dim per index dim plus one per non-indexed input dim.
        args += [
            f"values_stride{i}: tl.constexpr,"
            for i in range(index_rank + inp_rank - indices_len)
        ]
        args += [
            "M: tl.constexpr,",
            "N: tl.constexpr,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = tl.program_id(axis=0)")
        code.writeline("pid1 = tl.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # Every dim is indexed: one element per index position.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()
        # Unravel offset0 into coordinates of the (shared) index shape.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()
        # Unravel offset1 into coordinates of the non-indexed trailing dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()
        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            comp = [f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)]
            index_offset = " + ".join(comp) or zero_offset
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {index_offset}, mask=mask0, other=0)"
            )
            # Negative indices count from the end of the dimension (PyTorch
            # semantics); without this they would fail the >= 0 mask below
            # and be silently dropped.
            code.writeline(
                f"cur_index{i} = tl.where(cur_index{i} < 0, "
                f"cur_index{i} + input_shape{i}, cur_index{i})"
            )
        code.newline()
        # Indices still outside [0, size) after wrapping are skipped.
        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()
        comp = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        comp += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(comp) or zero_offset}")
        comp = [f"indices_idx{i} * values_stride{i}" for i in range(index_rank)]
        comp += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(comp) or zero_offset}")
        code.newline()
        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

171 

172 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Append the Python launcher for the generated index_put kernel to ``code``.

    The emitted ``{wrapper_name}(input, indices, values, accumulate)`` does
    the following:

    - Reads shapes and strides from its tensor arguments.
    - Sets ``M`` to the element count of ``indices[0]`` and ``N`` to the
      volume of the non-indexed trailing input dims.
    - Launches ``kernel_name`` on a ``(cdiv(M, BLOCK_SIZE0),
      cdiv(N, BLOCK_SIZE1))`` grid and returns ``input``.

    Returns:
        The same ``code`` buffer.
    """
    emit = code.writeline
    values_rank = index_rank + inp_rank - indices_len

    # Launch arguments, in exactly the order the kernel signature declares.
    launch_args = ["input,"]
    launch_args.extend(f"indices[{k}]," for k in range(indices_len))
    launch_args.append("values,")
    launch_args.extend(f"input_shape[{d}]," for d in range(inp_rank))
    launch_args.extend(
        f"indices{k}_shape[{d}],"
        for k in range(indices_len)
        for d in range(index_rank)
    )
    launch_args.extend(f"input_stride[{d}]," for d in range(inp_rank))
    launch_args.extend(
        f"indices{k}_stride[{d}],"
        for k in range(indices_len)
        for d in range(index_rank)
    )
    launch_args.extend(f"values_stride[{d}]," for d in range(values_rank))
    launch_args.extend(("M,", "N,", "accumulate==True,"))

    emit(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        emit("input_shape = input.shape")
        emit("input_stride = input.stride()")
        for k in range(indices_len):
            emit(f"indices{k}_shape = indices[{k}].shape")
            emit(f"indices{k}_stride = indices[{k}].stride()")
        emit("values_shape = values.shape")
        emit("values_stride = values.stride()")
        emit("M = indices[0].numel()")
        emit(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()
        emit("grid = lambda meta: (")
        with code.indent():
            emit("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            emit("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        emit(")")
        code.newline()
        emit(f"{kernel_name}[grid](")
        with code.indent():
            code.writelines(launch_args)
        emit(")")
        emit("return input")
    code.newline()
    code.newline()
    return code

221 

222 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Build the complete source for one index_put specialisation.

    ``inputs`` is ``(inp, indices, values)``. The generated text depends on
    three things:

    - the rank of ``inp``;
    - how many entries of ``indices`` are not ``None``;
    - the rank of the first such entry.

    The prelude, kernel and launcher are appended to ``code``, which is
    returned.

    Raises:
        ValueError: if ``inputs[1]`` contains no tensors.
    """
    input_rank = inputs[0].ndim
    present = [entry for entry in inputs[1] if entry is not None]
    if not present:
        raise ValueError("At least one non-None index tensor is required")
    shape_params = (input_rank, len(present), present[0].ndim)

    code = generate_imports(code)
    generate_index_put_kernel(*shape_params, kernel_name, code)
    generate_index_put_wrapper(*shape_params, wrapper_name, kernel_name, code)
    return code

242 

243 

class IndexPutFunction:
    """Callable that generates, caches and runs specialised index_put code.

    The first time a given ``arg_key`` is seen, it:

    - writes a Python module containing a Triton kernel plus launcher
      (built by ``generate_code``) into the code cache directory;
    - imports that module;
    - memoises its ``_index_put_wrapper`` in ``self.overloads``.
    """

    def __init__(self):
        # Process id at construction time; not read by any method here.
        self.pid = os.getpid()
        # arg_key -> generated ``_index_put_wrapper`` function.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run index_put for positional ``(inp, tensor_indices, values, accumulate)``.

        ``tensor_indices`` must hold only tensors (no ``None``). The generated
        wrapper indexes ``indices[i]`` for every ``i``, and the kernel
        unravels offsets with the first index's shape, so all indices are
        expected to share one shape. ``kwargs`` are accepted but ignored.

        Returns:
            Whatever the generated wrapper returns (its ``input`` argument).
        """
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, full_args)
            self.overloads[key] = overload

        return overload(*args)

    def _build_overload(self, key, full_args):
        """Generate, write and import the module for ``key``; return its wrapper."""
        # ``import importlib`` alone does not guarantee that the
        # ``importlib.util`` submodule is loaded, so import it explicitly.
        import importlib.util

        code = generate_code(
            full_args,
            "_index_put_wrapper",
            "_index_put_jit_function",
            IndentedBuffer(),
        )
        file_path = code_cache_dir() / f"index_put_{key}.py"
        write_atomic(file_path, code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_put_wrapper")

    def arg_key(self, *args, **kwargs):
        """Return the cache key for ``(inp, tensor_indices, values)``.

        The key encodes the input rank, the number of index entries and the
        rank of the first index entry (0 when there are none). These are the
        only quantities the generated source depends on.
        """
        inp, tensor_indices, _ = args[0], args[1], args[2]
        inp_rank = inp.ndim
        indices_len = len(tensor_indices)
        if indices_len == 0:
            index_rank = 0
        else:
            index_rank = tensor_indices[0].ndim
        return f"inp_rank_{inp_rank}_indices_len_{indices_len}_index_rank_{index_rank}"

289 

290 

# Module-level dispatcher shared by index_put / index_put_; it caches one
# generated launcher per (input rank, number of indices, index rank) key.
_index_put_func = IndexPutFunction()

292 

293 

def index_put(inp, indices, values, accumulate=False):
    """Out-of-place ``index_put``: return a copy of ``inp`` with ``values``
    written (or accumulated) at ``indices``.

    Args:
        inp: destination tensor; it is not modified.
        indices: sequence of index tensors and/or ``None`` (whole dimension).
            A sequence holding a single boolean tensor is treated as a mask.
            Missing trailing entries are treated as ``None``.
        values: tensor broadcast to the indexing result.
        accumulate: if true, add ``values`` instead of overwriting.

    Returns:
        A new tensor with the update applied.

    Raises:
        ValueError: if ``indices`` contains no tensor.
    """
    logger.debug("GEMS INDEX PUT")

    indices = list(indices)
    if (
        len(indices) == 1
        and indices[0] is not None
        and indices[0].dtype == torch.bool
    ):
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        # A boolean mask selects the same elements as the integer coordinates
        # of its True entries (one index tensor per mask dimension).
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K and len(target_shape) == 1:
            # Accept any K-element layout (e.g. (K, 1) or (1, K)) when each
            # selected position is a single element. With trailing dims,
            # the ordinary broadcast below applies instead, so a (K, 1) tensor
            # fills whole rows rather than being flattened into a row vector.
            values = values.reshape((K,))

    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    # Pad missing indices with None to match input dimensions
    if len(indices) < inp.ndim:
        indices.extend([None] * (inp.ndim - len(indices)))

    # Broadcast tensor indices against each other
    tensor_pos = [i for i, x in enumerate(indices) if x is not None]
    if not tensor_pos:
        raise ValueError("At least one non-None index tensor is required")

    tensor_indices_list = [indices[i] for i in tensor_pos]
    if len(tensor_indices_list) > 1:
        broadcasted = torch.broadcast_tensors(*tensor_indices_list)
        for i, pos in enumerate(tensor_pos):
            indices[pos] = broadcasted[i]

    # The kernel indexes the leading dims of its input, so the indexed dims
    # must form a prefix; otherwise permute them to the front.
    is_contiguous = (tensor_pos[-1] - tensor_pos[0] + 1) == len(tensor_pos)
    starts_with_none = indices[0] is None
    need_transpose = not is_contiguous or starts_with_none

    if need_transpose:
        perm_order = tensor_pos + [i for i, x in enumerate(indices) if x is None]
        final_indices = [indices[i] for i in tensor_pos] + [None] * (
            len(indices) - len(tensor_pos)
        )
    else:
        perm_order = None
        final_indices = indices

    out = inp.clone()

    if need_transpose:
        # Create a contiguous permuted copy for the kernel
        out_perm = out.permute(perm_order).contiguous()
    else:
        out_perm = out

    # Compute target_shape: broadcast_shape + slice_shape (for None dims)
    tensors = [x for x in final_indices if x is not None]
    broadcast_shape = list(tensors[0].shape)
    slice_shape = [out_perm.shape[i] for i, x in enumerate(final_indices) if x is None]
    target_shape = broadcast_shape + slice_shape

    if values.device != inp.device:
        values = values.to(inp.device)

    if need_transpose and is_contiguous:
        # Adjacent advanced indices keep their position in the result, so
        # ``values`` is laid out as before + broadcast + after; permute it to
        # broadcast + before + after to match out_perm.
        num_before = tensor_pos[0]
        before_dims = slice_shape[:num_before]
        after_dims = slice_shape[num_before:]
        natural_shape = before_dims + broadcast_shape + after_dims
        values = values.broadcast_to(natural_shape)
        B, T = len(before_dims), len(broadcast_shape)
        val_perm = (
            list(range(B, B + T)) + list(range(0, B)) + list(range(B + T, values.ndim))
        )
        values = values.permute(val_perm).contiguous()
    else:
        values = torch.broadcast_to(values, target_shape).contiguous()

    _index_put_func(out_perm, tensors, values, accumulate)

    if need_transpose:
        # Copy results back to original dimension order
        out.permute(perm_order).copy_(out_perm)

    return out

390 

391 

def index_put_(inp, indices, values, accumulate=False):
    """In-place ``index_put``: write (or accumulate) ``values`` into ``inp``.

    Args:
        inp: destination tensor, modified in place.
        indices: sequence of index tensors and/or ``None`` (whole dimension).
            A sequence holding a single boolean tensor is treated as a mask.
        values: tensor broadcast to the indexing result.
        accumulate: if true, add ``values`` instead of overwriting.

    Returns:
        ``inp``.

    Raises:
        ValueError: if ``indices`` contains no tensor.
    """
    logger.debug("GEMS INDEX PUT_")

    indices = list(indices)
    if (
        len(indices) == 1
        and indices[0] is not None
        and indices[0].dtype == torch.bool
    ):
        mask = indices[0]

        if mask.device != inp.device:
            mask = mask.to(inp.device)

        # A boolean mask selects the same elements as the integer coordinates
        # of its True entries (one index tensor per mask dimension).
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]

        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K and len(target_shape) == 1:
            # Accept any K-element layout (e.g. (K, 1) or (1, K)) when each
            # selected position is a single element. With trailing dims,
            # the ordinary broadcast below applies instead.
            values = values.reshape((K,))

    indices = [
        index.to(inp.device)
        if index is not None and index.device != inp.device
        else index
        for index in indices
    ]

    tensor_pos = [i for i, idx in enumerate(indices) if idx is not None]
    if not tensor_pos:
        raise ValueError("At least one non-None index tensor is required")

    if tensor_pos != list(range(len(tensor_pos))):
        # The generated kernel indexes the leading len(tensor_pos) dims of its
        # input. When a None precedes a tensor index, the indexed dims must be
        # permuted to the front first. index_put does that permutation, so
        # compute out of place and copy the result back.
        inp.copy_(index_put(inp, indices, values, accumulate))
        return inp

    # Tensor indices form a prefix; any remaining entries are trailing Nones.
    tensor_indices = indices[: len(tensor_pos)]
    target_shape = get_max_rank_shape(tensor_indices)
    broadcast_indices(tensor_indices, target_shape)
    # Trailing None entries select whole dims, so they belong in the value
    # shape; slice by the number of tensor indices, not len(indices).
    target_shape += inp.shape[len(tensor_indices) :]

    if values.device != inp.device:
        values = values.to(inp.device)
    values = torch.broadcast_to(values, target_shape)

    _index_put_func(inp, tensor_indices, values, accumulate)
    return inp