Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_put.py: 0%

382 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

# Module logger, namespaced under the mthreads ops package and suffixed with
# the last dotted component of this module's ``__name__``.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)

14 

15 

def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the right-aligned, per-dimension maximum size over ``indices``.

    ``None`` entries (basic-indexing markers) are skipped. Shapes are aligned on
    their trailing dimensions and each output entry is the largest size any
    tensor has at that aligned position; tensors shorter than the result simply
    do not contribute to its leading entries. An empty list is returned when no
    tensor is present.
    """
    present = [t for t in indices if t is not None]
    if not present:
        return []

    rank = max(len(t.shape) for t in present)
    dims = []
    # ``back`` is the 1-based offset from the last dimension.
    for back in range(1, rank + 1):
        largest = 0
        for t in present:
            if back <= len(t.shape):
                largest = max(largest, t.shape[-back])
        dims.append(largest)
    dims.reverse()
    return dims

31 

32 

def broadcast_indices(indices, target_shape):
    """Broadcast, in place, every non-``None`` entry of ``indices`` to ``target_shape``.

    Entries that already have exactly ``target_shape`` are left untouched (same
    object); others are replaced by ``torch.broadcast_to`` views. Returns ``None``.
    """
    wanted = tuple(target_shape)
    for pos, idx in enumerate(indices):
        if idx is None or tuple(idx.shape) == wanted:
            continue
        indices[pos] = torch.broadcast_to(idx, target_shape)

37 

38 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of a generated index_put module into ``code``.

    Returns the same buffer so the call can be chained.
    """
    # ``None`` marks a blank line in the emitted header.
    header = (
        "import triton",
        "import triton.language as tl",
        None,
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as ext",
        None,
        None,
    )
    for line in header:
        if line is None:
            code.newline()
        else:
            code.writeline(line)
    return code

51 

52 

def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Write the Triton ``index_put`` kernel named ``kernel_name`` into ``code``.

    The emitted kernel uses a 2-D launch grid: axis 0 tiles the ``M`` positions
    of the index tensors (their coordinates are decomposed using ``indices0``'s
    shape, so all index tensors are expected to share it), and axis 1 tiles the
    ``N`` elements of the trailing, non-indexed input dimensions.

    Negative index values are wrapped by adding the indexed dimension's size;
    indices still outside ``[0, size)`` after wrapping are masked out.

    Args:
        inp_rank: number of dimensions of the input tensor.
        indices_len: number of index tensors; they address the leading
            ``indices_len`` input dimensions.
        index_rank: number of dimensions of each index tensor (0 is allowed).
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the source is appended to.

    Returns:
        ``code``.
    """
    # Rank of ``values``: the index dims followed by the non-indexed input dims.
    values_rank = index_rank + inp_rank - indices_len
    # A 0-d index tensor contributes no per-dimension offset term. An empty join
    # would emit invalid source (``ptr + , mask=...``); a zero offset shaped like
    # ``offset0`` keeps the pointers block-shaped so they broadcast with the masks.
    zero_offset = "offset0 * 0"

    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        code.writeline('configs=runtime.get_tuned_config("index_put"),')
        code.writeline('key=["M", "N"],')
        code.writeline('restore_value=["input_ptr"],')
        code.writeline('strategy=["align32", "align32"],')
        code.writeline("warmup=5,")
        code.writeline("rep=10,")
    code.writeline(")")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = ["input_ptr,"]
        args += [f"indices{i}_ptr," for i in range(indices_len)]
        args += ["values_ptr,"]
        args += [f"input_shape{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_shape{j}," for j in range(index_rank)]
        args += [f"input_stride{i}," for i in range(inp_rank)]
        for i in range(indices_len):
            args += [f"indices{i}_stride{j}," for j in range(index_rank)]
        args += [f"values_stride{i}," for i in range(values_rank)]
        args += [
            "M,",
            "N,",
            "IS_ACCUMULATE: tl.constexpr,",
            "BLOCK_SIZE0: tl.constexpr,",
            "BLOCK_SIZE1: tl.constexpr,",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid0 = ext.program_id(axis=0)")
        code.writeline("pid1 = ext.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        if inp_rank == indices_len:
            # Every input dim is indexed, so there is a single column per row.
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        else:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        code.newline()

        # Split the flat row offset into coordinates of the index tensors.
        code.writeline("cur_idx = offset0")
        for i in range(index_rank - 1, -1, -1):
            code.writeline(f"indices_idx{i} = cur_idx % indices0_shape{i}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{i}")
        code.newline()

        # Split the flat column offset into coordinates of the trailing input dims.
        code.writeline("cur_idx = offset1")
        for i in range(inp_rank - 1, indices_len - 1, -1):
            code.writeline(f"input_idx{i} = cur_idx % input_shape{i}")
            code.writeline(f"cur_idx = cur_idx // input_shape{i}")
        code.newline()

        code.writeline("mask0 = offset0 < M")
        for i in range(indices_len):
            terms = [
                f"indices_idx{j} * indices{i}_stride{j}" for j in range(index_rank)
            ] or [zero_offset]
            code.writeline(
                f"cur_index{i} = tl.load(indices{i}_ptr + {' + '.join(terms)}, "
                "mask=mask0, other=0)"
            )
            # Negative indices count from the end of the dimension (PyTorch semantics).
            code.writeline(
                f"cur_index{i} = tl.where(cur_index{i} < 0, "
                f"cur_index{i} + input_shape{i}, cur_index{i})"
            )
        code.newline()

        index_mask = [
            f"(cur_index{i} >= 0) & (cur_index{i} < input_shape{i})"
            for i in range(indices_len)
        ]
        code.writeline(f"index_mask = {' & '.join(index_mask)}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()

        input_terms = [f"cur_index{i} * input_stride{i}" for i in range(indices_len)]
        input_terms += [
            f"input_idx{i} * input_stride{i}" for i in range(indices_len, inp_rank)
        ]
        code.writeline(f"input_offset = {' + '.join(input_terms)}")

        value_terms = [
            f"indices_idx{i} * values_stride{i}" for i in range(index_rank)
        ] or [zero_offset]
        value_terms += [
            f"input_idx{indices_len + i} * values_stride{index_rank + i}"
            for i in range(inp_rank - indices_len)
        ]
        code.writeline(f"values_offset = {' + '.join(value_terms)}")
        code.newline()

        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code

154 

155 

def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Write the Python launcher ``wrapper_name`` for ``kernel_name`` into ``code``.

    The emitted function reads the shapes and strides of ``input``, each index
    tensor and ``values``, takes ``M`` from ``indices[0].numel()`` and ``N`` as the
    volume of the input dims after the first ``indices_len``, launches the
    kernel on a ``(cdiv(M, BLOCK_SIZE0), cdiv(N, BLOCK_SIZE1))`` grid and returns
    ``input``. Returns ``code``.
    """
    values_rank = index_rank + inp_rank - indices_len
    index_dims = range(index_rank)

    preamble = ["input_shape = input.shape", "input_stride = input.stride()"]
    for k in range(indices_len):
        preamble.append(f"indices{k}_shape = indices[{k}].shape")
        preamble.append(f"indices{k}_stride = indices[{k}].stride()")
    preamble.extend(
        [
            "values_shape = values.shape",
            "values_stride = values.stride()",
            "M = indices[0].numel()",
            f"N = volume(input_shape[{indices_len}: ])",
        ]
    )

    launch_args = ["input,"]
    launch_args.extend(f"indices[{k}]," for k in range(indices_len))
    launch_args.append("values,")
    launch_args.extend(f"input_shape[{d}]," for d in range(inp_rank))
    launch_args.extend(
        f"indices{k}_shape[{d}]," for k in range(indices_len) for d in index_dims
    )
    launch_args.extend(f"input_stride[{d}]," for d in range(inp_rank))
    launch_args.extend(
        f"indices{k}_stride[{d}]," for k in range(indices_len) for d in index_dims
    )
    launch_args.extend(f"values_stride[{d}]," for d in range(values_rank))
    launch_args.extend(["M,", "N,", "accumulate==True,"])

    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        for line in preamble:
            code.writeline(line)
        code.newline()
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline("triton.cdiv(M, meta['BLOCK_SIZE0']), ")
            code.writeline("triton.cdiv(N, meta['BLOCK_SIZE1']), ")
        code.writeline(")")
        code.newline()
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writelines(launch_args)
        code.writeline(")")
        code.writeline("return input")
    code.newline()
    code.newline()
    return code

204 

205 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate a complete index_put module (imports, kernel, wrapper) into ``code``.

    ``inputs`` is ``(input, indices, values)``; ``None`` entries of ``indices``
    are ignored. The kernel and wrapper are specialised on the input rank, the
    number of index tensors and the rank of the first index tensor.

    Raises:
        ValueError: if ``indices`` contains no tensor.
    """
    inp_rank = inputs[0].ndim
    index_tensors = [t for t in inputs[1] if t is not None]
    if not index_tensors:
        raise ValueError("At least one non-None index tensor is required")
    shape_params = (inp_rank, len(index_tensors), index_tensors[0].ndim)

    code = generate_imports(code)
    generate_index_put_kernel(*shape_params, kernel_name, code)
    generate_index_put_wrapper(*shape_params, wrapper_name, kernel_name, code)
    return code

225 

226 

class IndexPutFunction:
    """Per-process cache of generated index_put launchers.

    Calling an instance with ``(inp, tensor_indices, values, accumulate)`` looks
    up the launcher for ``arg_key(inp, tensor_indices, values)``; on a miss it
    generates the module source, writes it into the code cache directory,
    imports it and stores its ``_index_put_wrapper``. The wrapper is then called
    with the same four arguments and its result returned.
    """

    def __init__(self):
        self.pid = os.getpid()
        # arg_key(...) string -> generated ``_index_put_wrapper`` callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        inp, tensor_indices, values, accumulate = args
        full_args = (inp, tensor_indices, values)

        key = self.arg_key(*full_args)
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, full_args)
            self.overloads[key] = overload

        return overload(*args)

    def _build_overload(self, key, full_args):
        """Generate, write and import the module for ``key``; return its wrapper."""
        # ``import importlib`` alone does not guarantee that the ``importlib.util``
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        code = IndentedBuffer()
        code = generate_code(
            full_args,
            "_index_put_wrapper",
            "_index_put_jit_function",
            code,
        )
        file_path = code_cache_dir() / f"index_put_{key}.py"
        write_atomic(file_path, code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_put_wrapper")

    def arg_key(self, *args, **kwargs):
        """Return the specialisation key for ``(inp, tensor_indices, values)``.

        ``None`` entries of ``tensor_indices`` are skipped, matching how
        ``generate_code`` counts index tensors, so the key always describes the
        code that is generated for it. With no index tensor the index rank is 0.
        """
        inp, tensor_indices = args[0], args[1]
        present = [idx for idx in tensor_indices if idx is not None]
        index_rank = present[0].ndim if present else 0
        return (
            f"inp_rank_{inp.ndim}_indices_len_{len(present)}"
            f"_index_rank_{index_rank}"
        )


# Module-wide dispatcher shared by the index_put entry points.
_index_put_func = IndexPutFunction()

275 

276 

# Index dtypes PyTorch interprets as boolean masks. The "byte" in PyTorch's
# index error message is torch.uint8; torch.int8 is a signed integer dtype.
_MASK_INDEX_DTYPES = (torch.bool, torch.uint8)


def _is_mask_index(index):
    """Return True if ``index`` is a tensor whose dtype marks it as a mask."""
    return torch.is_tensor(index) and index.dtype in _MASK_INDEX_DTYPES


def _expand_single_mask(inp, mask, values):
    """Replace a lone mask index by its coordinate tensors and pre-shape ``values``.

    Returns ``(indices, values)``, where ``indices`` are the ``torch.where``
    coordinates of ``mask`` (moved to ``inp.device``). A one-element ``values``
    is materialised as a filled tensor of ``inp.dtype`` with the selected shape;
    when the selection is 1-D, a ``values`` with one element per selected
    position is flattened to that shape.
    """
    if mask.device != inp.device:
        mask = mask.to(inp.device)

    indices = list(torch.where(mask))
    num_selected = indices[0].numel()
    target_shape = (num_selected,) + tuple(inp.shape[len(indices) :])

    if values.numel() == 1:
        values = torch.full(
            target_shape, values.item(), dtype=inp.dtype, device=inp.device
        )
    elif len(target_shape) == 1 and values.numel() == num_selected:
        # Only flatten when there are no trailing slice dims. With trailing dims
        # the regular right-aligned broadcast handles ``values``; flattening
        # there would misalign it (e.g. an empty ``(0, D)`` update would fail).
        values = values.reshape((num_selected,)).expand(target_shape)
    return indices, values


def _prepare_index_put_args(inp, indices, values):
    """Canonicalise ``indices`` and ``values`` for an index_put on ``inp``.

    Returns ``(indices, values)`` where ``indices`` has exactly ``inp.ndim``
    entries, each ``None`` (dimension not indexed) or an integer index tensor
    on ``inp.device``; each mask tensor is expanded into one coordinate tensor
    per dimension it covers.

    Raises:
        ValueError: if ``indices`` is empty.
        TypeError: if an entry is neither ``None`` nor an integer/bool tensor.
        IndexError: if the indices address more dimensions than ``inp`` has.
    """
    indices = list(indices)
    if len(indices) == 1 and _is_mask_index(indices[0]):
        indices, values = _expand_single_mask(inp, indices[0], values)

    if not indices:
        raise ValueError("At least one index tensor is required")

    processed = []
    for idx in indices:
        if idx is None:
            processed.append(None)
            continue
        # Validate before touching ``.device``/``.dtype`` so non-tensors and
        # floating/complex tensors raise the intended error.
        if not torch.is_tensor(idx) or idx.is_floating_point() or idx.is_complex():
            raise TypeError(
                "tensors used as indices must be long, int, byte or bool tensors"
            )
        if idx.device != inp.device:
            idx = idx.to(inp.device)
        if idx.dtype in _MASK_INDEX_DTYPES:
            processed.extend(idx.nonzero(as_tuple=True))
        else:
            processed.append(idx)

    if len(processed) < inp.ndim:
        processed.extend([None] * (inp.ndim - len(processed)))

    if len(processed) > inp.ndim:
        raise IndexError("too many indices for tensor of dimension {}".format(inp.ndim))

    return processed, values


def _launch_index_put(target, indices, values, accumulate):
    """Write (or accumulate) ``values`` into ``target`` in place at ``indices``.

    ``indices`` must come from ``_prepare_index_put_args``. Index tensors are
    broadcast together. When they are not the leading dims of ``target`` (a
    ``None`` precedes them or separates them), ``target`` is viewed with the
    indexed dims moved to the front and ``values`` is arranged to match,
    following PyTorch's advanced-indexing result layout.

    Raises:
        ValueError: if every entry of ``indices`` is ``None``.
    """
    tensor_pos = [i for i, idx in enumerate(indices) if idx is not None]
    if not tensor_pos:
        raise ValueError("At least one non-None index tensor is required")

    if len(tensor_pos) > 1:
        broadcasted = torch.broadcast_tensors(*(indices[i] for i in tensor_pos))
        for pos, idx in zip(tensor_pos, broadcasted):
            indices[pos] = idx

    is_adjacent = (tensor_pos[-1] - tensor_pos[0] + 1) == len(tensor_pos)
    need_transpose = not is_adjacent or indices[0] is None
    slice_pos = [i for i, idx in enumerate(indices) if idx is None]

    # The kernel expects the indexed dims first, followed by the sliced dims.
    target_view = target.permute(tensor_pos + slice_pos) if need_transpose else target

    tensors = [indices[i] for i in tensor_pos]
    broadcast_shape = list(tensors[0].shape)
    slice_shape = [target.shape[i] for i in slice_pos]

    values = values.to(target.device)
    if need_transpose and is_adjacent:
        # Adjacent advanced indices keep their place in the result, so values
        # are laid out as (before, broadcast, after); reorder them to
        # (broadcast, before, after) to match ``target_view``.
        num_before = tensor_pos[0]
        before_dims = slice_shape[:num_before]
        after_dims = slice_shape[num_before:]
        values = values.broadcast_to(before_dims + broadcast_shape + after_dims)

        b, t = len(before_dims), len(broadcast_shape)
        values = values.permute(
            list(range(b, b + t)) + list(range(0, b)) + list(range(b + t, values.ndim))
        )
    else:
        values = values.broadcast_to(broadcast_shape + slice_shape)

    # Nothing to write (no selected positions or an empty slice): skip code
    # generation and the kernel launch.
    if values.numel() == 0:
        return

    _index_put_func(target_view, tensors, values, accumulate)


def index_put(inp, indices, values, accumulate=False):
    """Return a copy of ``inp`` with ``values`` written at ``indices``.

    With ``accumulate=True`` the values are added instead of stored.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT")

    indices, values = _prepare_index_put_args(inp, indices, values)
    out = inp.clone()
    _launch_index_put(out, indices, values, accumulate)
    return out


def index_put_(inp, indices, values, accumulate=False):
    """Write ``values`` into ``inp`` at ``indices`` in place and return ``inp``.

    With ``accumulate=True`` the values are added instead of stored.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT_")

    indices, values = _prepare_index_put_args(inp, indices, values)
    _launch_index_put(inp, indices, values, accumulate)
    return inp


def _index_put_impl_(inp, indices, values, accumulate=False, unsafe=False):
    """In-place ``index_put`` backing ``aten::_index_put_impl_``; returns ``inp``."""
    logger.debug("GEMS_MTHREADS _INDEX_PUT_IMPL_")

    # ``unsafe`` only tells PyTorch it may skip bounds checks. The generated
    # kernel always masks out-of-range indices, so the flag is ignored here.
    indices, values = _prepare_index_put_args(inp, indices, values)
    _launch_index_put(inp, indices, values, accumulate)
    return inp