Coverage for src/flag_gems/ops/index_reduce.py: 44%

443 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10 

11logger = logging.getLogger(__name__) 

12 

# Integer codes for the supported reductions. `_reduce_id` maps the user-facing
# reduce string to one of these, and the Triton kernels in this file branch on
# the same numeric values through their `REDUCE` constexpr parameter.
REDUCE_PROD = 0
REDUCE_MEAN = 1
REDUCE_AMAX = 2
REDUCE_AMIN = 3

17 

18 

19def _heur_block_m(args): 

20 M = args["M"] 

21 return 1 if M < 4 else 4 

22 

23 

def _heur_block_n(args):
    """Pick the column tile size: next power of two of ``N``, clamped to [1, 256]."""
    rounded = triton.next_power_of_2(args["N"])
    capped = min(256, rounded)
    return max(1, capped)

27 

28 

def _heur_flat_block(args):
    """Pick the 1-D block size from ``TOTAL`` (or ``N`` when ``TOTAL`` is absent).

    The result is the next power of two of that size, clamped to [1, 256].
    """
    if "TOTAL" in args:
        size = args["TOTAL"]
    else:
        size = args["N"]
    rounded = triton.next_power_of_2(size)
    return max(1, min(256, rounded))

32 

33 

# 2-D tiled index-reduce kernel.
#
# Each program handles a BLOCK_M x BLOCK_N tile of a source laid out as
# (M, N) and combines every element src[row, col] into
# out[row, index[col]], where `out` has OUT_N entries per row.
#
# Arguments:
#   out      destination buffer, updated in place with atomics / CAS loops
#   index    N destination column indices (cast to int64 after loading)
#   src      source values, addressed as row * N + col
#   count    int32 per-output hit counter, only written when REDUCE == 1 (mean)
#   touched  int32 per-output "was written" counter, only written when
#            USE_TOUCHED is set
#   REDUCE   0 = prod, 1 = mean (sum + count), 2 = amax, anything else = amin
#   USE_COUNT is accepted for signature parity but is not read in the body.
#   USE_CAS  selects a compare-and-swap loop instead of atomic_max/atomic_min
@libentry()
@triton.heuristics({"BLOCK_M": _heur_block_m, "BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["M", "N", "OUT_N"])
def _index_reduce_kernel(
    out,
    index,
    src,
    count,
    touched,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (rows < M) & (cols < N)

    dst_cols = tl.load(index + cols, mask=cols < N, other=0).to(tl.int64)
    src_offsets = rows * N + cols
    out_offsets = rows * OUT_N + dst_cols
    values = tl.load(src + src_offsets, mask=mask, other=0.0)

    if REDUCE == 1:
        # mean: accumulate the sum and the number of contributions; the
        # division happens in a separate finalize kernel.
        tl.atomic_add(out + out_offsets, values, mask=mask, sem="relaxed")
        ones_i = tl.full((BLOCK_M, BLOCK_N), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, ones_i, mask=mask, sem="relaxed")
    elif REDUCE == 0:
        # prod: retried with compare-and-swap until every lane's update has
        # landed. Masked-out lanes start with stop == 1 and re-write `cur`.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur = tl.load(out + out_offsets, mask=mask, other=0.0)
            new = tl.where(stop, cur, cur * values)
            # A NaN product can never satisfy `cur == cas`, so such lanes
            # store `values` and are marked finished to guarantee termination.
            is_nan = new != new
            new = tl.where(is_nan, values, new)
            cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
            stop |= (cur == cas) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK_M * BLOCK_N
    else:
        if USE_CAS:
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur = tl.load(out + out_offsets, mask=mask, other=0.0)
                if REDUCE == 2:
                    new = tl.maximum(cur, values)
                else:
                    new = tl.minimum(cur, values)
                # Keep the swap value in the same dtype as the compare value,
                # matching the flat kernels; a no-op when dtypes already agree.
                new = new.to(cur.dtype)
                cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
                stop |= cur == cas
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK_M * BLOCK_N
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, values, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, values, mask=mask, sem="relaxed")

    if USE_TOUCHED:
        # Record which output slots received at least one source element.
        ones_i = tl.full((BLOCK_M, BLOCK_N), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, ones_i, mask=mask, sem="relaxed")

102 

103 

# Flat (1-D launch) index-reduce kernel over a source laid out as (M, N).
#
# Each lane handles one of the TOTAL source elements and combines
# src[row, col] into out[row, index[col]] (OUT_N entries per output row).
# INDEX_MAJOR chooses how the flat offset is split into (row, col):
#   INDEX_MAJOR     -> col = offset // M, row = offset % M
#   not INDEX_MAJOR -> row = offset // N, col = offset % N
# REDUCE: 0 = prod, 1 = mean (sum + count), 2 = amax, anything else = amin.
# `count` is only written for REDUCE == 1, `touched` only when USE_TOUCHED.
# USE_COUNT is accepted but not read in the body.
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL", "M", "N", "OUT_N"])
def _index_reduce_flat_kernel(
    out,
    index,
    src,
    count,
    touched,
    TOTAL,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    INDEX_MAJOR: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offsets = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < TOTAL

    if INDEX_MAJOR:
        cols = offsets // M
        rows = offsets - cols * M
    else:
        rows = offsets // N
        cols = offsets - rows * N

    dst_cols = tl.load(index + cols, mask=mask, other=0).to(tl.int64)
    src_offsets = rows * N + cols
    out_offsets = rows * OUT_N + dst_cols
    values = tl.load(src + src_offsets, mask=mask, other=0.0)

    if REDUCE == 1:
        # mean: accumulate sum and contribution count; division is done later.
        tl.atomic_add(out + out_offsets, values, mask=mask, sem="relaxed")
        ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, ones_i, mask=mask, sem="relaxed")
    elif REDUCE == 0:
        # prod: compare-and-swap retry loop. Masked-out lanes start finished
        # (stop == 1) and only re-write the value they loaded.
        # NOTE(review): atomic_cas is issued without a mask, so masked-out
        # lanes still touch `out + out_offsets` — confirm those addresses are
        # always inside the allocation.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur = tl.load(out + out_offsets, mask=mask, other=0.0)
            new = tl.where(stop, cur, cur * values)
            # NaN never compares equal, so NaN products store `values` and the
            # lane is marked finished to keep the loop from spinning forever.
            is_nan = new != new
            new = tl.where(is_nan, values, new)
            cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
            stop |= (cur == cas) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
    else:
        if USE_CAS:
            # amax/amin emulated with compare-and-swap.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur = tl.load(out + out_offsets, mask=mask, other=0.0)
                if REDUCE == 2:
                    new = tl.maximum(cur, values)
                else:
                    new = tl.minimum(cur, values)
                # Cast back so compare and swap operands share one dtype.
                new = new.to(cur.dtype)
                cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
                stop |= cur == cas
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, values, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, values, mask=mask, sem="relaxed")

    if USE_TOUCHED:
        # Record which output slots received at least one source element.
        ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, ones_i, mask=mask, sem="relaxed")

177 

178 

# Flat index-reduce kernel that addresses the tensors in their original
# layout instead of a dim-compressed copy.
#
# The source is addressed as (PRE, N, POST) and the output as
# (PRE, OUT_N, POST): element (pre, col, post) of `src` is combined into
# out[pre, index[col], post]. INDEX_MAJOR chooses how the flat offset is split:
#   INDEX_MAJOR     -> col = offset // (PRE * POST), element = remainder
#   not INDEX_MAJOR -> element = offset // N, col = remainder
# where `element` is then split into (pre, post) using POST.
# REDUCE: 0 = prod, 1 = mean (sum + count), 2 = amax, anything else = amin.
# `count` is only written for REDUCE == 1, `touched` only when USE_TOUCHED.
# USE_COUNT is accepted but not read in the body.
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL", "PRE", "POST", "N", "OUT_N"])
def _index_reduce_contiguous_flat_kernel(
    out,
    index,
    src,
    count,
    touched,
    TOTAL,
    PRE,
    POST,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    INDEX_MAJOR: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offsets = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < TOTAL
    slice_size = PRE * POST

    if INDEX_MAJOR:
        cols = offsets // slice_size
        element = offsets - cols * slice_size
    else:
        element = offsets // N
        cols = offsets - element * N

    pre = element // POST
    post = element - pre * POST
    dst_cols = tl.load(index + cols, mask=mask, other=0).to(tl.int64)

    src_offsets = pre * N * POST + cols * POST + post
    out_offsets = pre * OUT_N * POST + dst_cols * POST + post
    values = tl.load(src + src_offsets, mask=mask, other=0.0)

    if REDUCE == 1:
        # mean: accumulate sum and contribution count; division is done later.
        tl.atomic_add(out + out_offsets, values, mask=mask, sem="relaxed")
        ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, ones_i, mask=mask, sem="relaxed")
    elif REDUCE == 0:
        # prod: compare-and-swap retry loop. Masked-out lanes start finished
        # (stop == 1) and only re-write the value they loaded.
        # NOTE(review): atomic_cas is issued without a mask, so masked-out
        # lanes still touch `out + out_offsets` — confirm those addresses are
        # always inside the allocation.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur = tl.load(out + out_offsets, mask=mask, other=0.0)
            new = tl.where(stop, cur, cur * values)
            # NaN never compares equal, so NaN products store `values` and the
            # lane is marked finished to keep the loop from spinning forever.
            is_nan = new != new
            new = tl.where(is_nan, values, new)
            cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
            stop |= (cur == cas) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
    else:
        if USE_CAS:
            # amax/amin emulated with compare-and-swap.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur = tl.load(out + out_offsets, mask=mask, other=0.0)
                if REDUCE == 2:
                    new = tl.maximum(cur, values)
                else:
                    new = tl.minimum(cur, values)
                # Cast back so compare and swap operands share one dtype.
                new = new.to(cur.dtype)
                cas = tl.atomic_cas(out + out_offsets, cur, new, sem="relaxed")
                stop |= cur == cas
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, values, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, values, mask=mask, sem="relaxed")

    if USE_TOUCHED:
        # Record which output slots received at least one source element.
        ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, ones_i, mask=mask, sem="relaxed")

257 

258 

# Element-wise finalize step for reduce == "mean".
#
# Turns the accumulated sums in `acc` into means using the per-element hit
# counter `count`, and stores them to `result`:
#   INCLUDE_SELF     -> result = acc / (count + 1)
#   not INCLUDE_SELF -> result = acc / count where count > 0, otherwise the
#                       value from `original` is kept.
# `result`, `acc` and `original` are all indexed with the same flat offsets,
# and the caller may pass the same buffer for more than one of them.
# NOTE(review): the division is carried out in float32 regardless of the
# buffer dtype; for float64 inputs this presumably loses precision — confirm.
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL"])
def _index_reduce_mean_finalize_kernel(
    result,
    acc,
    original,
    count,
    TOTAL,
    INCLUDE_SELF: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offsets = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < TOTAL

    cnt = tl.load(count + offsets, mask=mask, other=0)
    acc_val = tl.load(acc + offsets, mask=mask, other=0.0).to(tl.float32)
    if INCLUDE_SELF:
        # The original value is already part of `acc`, hence the +1.
        denom = cnt + 1
        result_val = acc_val / denom.to(tl.float32)
    else:
        # Clamp to 1 so untouched slots do not divide by zero; their quotient
        # is discarded by the `where` below.
        denom = tl.maximum(cnt, 1)
        mean_val = acc_val / denom.to(tl.float32)
        original_val = tl.load(original + offsets, mask=mask, other=0.0)
        result_val = tl.where(cnt > 0, mean_val, original_val)
    tl.store(result + offsets, result_val, mask=mask)

285 

286 

# 2-D tiled index-reduce kernel that uses plain loads and stores (no atomics).
#
# Each program handles a BLOCK_M x BLOCK_N tile of a source laid out as (M, N)
# and writes out[row, index[col]] directly. Because nothing here is atomic,
# this is only race-free when no two columns map to the same destination; the
# caller in this file launches it only after `_index_is_unique` succeeds.
# REDUCE: 0 = prod, 1 = mean, 2 = amax, anything else = amin.
# With INCLUDE_SELF the existing output value is combined with the source
# value (mean is the average of exactly those two values); without it the
# source value simply replaces the output.
@libentry()
@triton.heuristics({"BLOCK_M": _heur_block_m, "BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["M", "N", "OUT_N"])
def _index_reduce_unique_kernel(
    out,
    index,
    src,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    INCLUDE_SELF: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (rows < M) & (cols < N)

    dst_cols = tl.load(index + cols, mask=cols < N, other=0).to(tl.int64)
    src_offsets = rows * N + cols
    out_offsets = rows * OUT_N + dst_cols
    src_values = tl.load(src + src_offsets, mask=mask, other=0.0)

    if INCLUDE_SELF:
        inp_values = tl.load(out + out_offsets, mask=mask, other=0.0)
        if REDUCE == 0:
            result = inp_values * src_values
        elif REDUCE == 1:
            # Exactly two contributions: the existing value and one source.
            result = (inp_values + src_values) * 0.5
        elif REDUCE == 2:
            result = tl.maximum(inp_values, src_values)
        else:
            result = tl.minimum(inp_values, src_values)
    else:
        result = src_values

    tl.store(out + out_offsets, result, mask=mask)

328 

329 

# Atomic-free index-reduce kernel: one program per output element.
#
# Program `pid` owns output element (row, dst_col) of an output laid out as
# (rows, OUT_N). It scans all N entries of `index`, and folds every
# src[row, col] whose index[col] equals dst_col into an accumulator that is
# finally written to out[pid]. `inp` supplies the original value of the
# output element.
# REDUCE: 0 = prod, 1 = mean, 2 = amax, anything else = amin.
# INCLUDE_SELF seeds the accumulator with the original value (and counts it
# as one hit); otherwise the accumulator starts at the reduction identity and
# the original value is kept when no index matches.
# USE_FP64 selects float64 instead of float32 for the accumulator.
@libentry()
@triton.heuristics({"BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["TOTAL", "N", "OUT_N"])
def _index_reduce_scan_kernel(
    out,
    index,
    src,
    inp,
    TOTAL,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    INCLUDE_SELF: tl.constexpr,
    USE_FP64: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    mask_out = pid < TOTAL
    row = pid // OUT_N
    dst_col = (pid - row * OUT_N).to(tl.int64)
    inp_val = tl.load(inp + pid, mask=mask_out, other=0.0)
    if USE_FP64:
        inp_val = inp_val.to(tl.float64)
    else:
        inp_val = inp_val.to(tl.float32)

    # Seed the accumulator: original value, or the identity of the reduction
    # (1 for prod, 0 for mean, -inf for amax, +inf for amin).
    if REDUCE == 0:
        if USE_FP64:
            acc = inp_val if INCLUDE_SELF else tl.full((), 1.0, dtype=tl.float64)
        else:
            acc = inp_val if INCLUDE_SELF else tl.full((), 1.0, dtype=tl.float32)
    elif REDUCE == 1:
        if USE_FP64:
            acc = inp_val if INCLUDE_SELF else tl.full((), 0.0, dtype=tl.float64)
        else:
            acc = inp_val if INCLUDE_SELF else tl.full((), 0.0, dtype=tl.float32)
    elif REDUCE == 2:
        if USE_FP64:
            acc = (
                inp_val
                if INCLUDE_SELF
                else tl.full((), float("-inf"), dtype=tl.float64)
            )
        else:
            acc = (
                inp_val
                if INCLUDE_SELF
                else tl.full((), float("-inf"), dtype=tl.float32)
            )
    else:
        if USE_FP64:
            acc = (
                inp_val if INCLUDE_SELF else tl.full((), float("inf"), dtype=tl.float64)
            )
        else:
            acc = (
                inp_val if INCLUDE_SELF else tl.full((), float("inf"), dtype=tl.float32)
            )

    # Number of contributions folded into `acc` (the original value counts as
    # one when INCLUDE_SELF is set).
    hit_count = tl.full((), 1 if INCLUDE_SELF else 0, dtype=tl.int32)
    if REDUCE == 0:
        # prod: walk the index one column at a time and multiply matching
        # source values in index order.
        col = 0
        while col < N:
            current_col = tl.load(index + col).to(tl.int64)
            matched = current_col == dst_col
            value = tl.load(src + row * N + col, mask=matched, other=1.0)
            if USE_FP64:
                value = value.to(tl.float64)
            else:
                value = value.to(tl.float32)
            acc *= tl.where(matched, value, 1.0)
            hit_count += matched.to(tl.int32)
            col += 1
    else:
        # mean/amax/amin: process the index in chunks of BLOCK_N columns and
        # reduce each chunk with a vector reduction.
        offsets = tl.arange(0, BLOCK_N)
        start = 0
        while start < N:
            cols = start + offsets
            mask = cols < N
            # Out-of-range lanes load -1 and are additionally excluded by
            # `mask` in the match test below.
            dst_cols = tl.load(index + cols, mask=mask, other=-1).to(tl.int64)
            matched = mask & (dst_cols == dst_col)
            values = tl.load(src + row * N + cols, mask=mask, other=0.0)
            if USE_FP64:
                values = values.to(tl.float64)
            else:
                values = values.to(tl.float32)

            matched_count = tl.sum(matched.to(tl.int32), axis=0)
            hit_count += matched_count
            if REDUCE == 1:
                acc += tl.sum(tl.where(matched, values, 0.0), axis=0)
            elif REDUCE == 2:
                acc = tl.maximum(
                    acc, tl.max(tl.where(matched, values, float("-inf")), axis=0)
                )
            else:
                acc = tl.minimum(
                    acc, tl.min(tl.where(matched, values, float("inf")), axis=0)
                )
            start += BLOCK_N

    if REDUCE == 1:
        # Clamp to 1 so an untouched element does not divide by zero; its
        # quotient is discarded by the `where` below.
        acc = acc / tl.maximum(hit_count, 1).to(tl.float32)
    # Elements that received no contribution keep their original value.
    result = tl.where(hit_count > 0, acc, inp_val)
    tl.store(out + pid, result, mask=mask_out)

435 

436 

def _reduce_id(reduce):
    """Translate a reduce name into the integer code the kernels expect.

    Raises:
        RuntimeError: if ``reduce`` is not one of "prod", "mean", "amax", "amin".
    """
    table = (
        ("prod", REDUCE_PROD),
        ("mean", REDUCE_MEAN),
        ("amax", REDUCE_AMAX),
        ("amin", REDUCE_AMIN),
    )
    for name, code in table:
        if reduce == name:
            return code
    raise RuntimeError(f"Unsupported reduce: {reduce}")

447 

448 

449def _identity_like(inp, reduce): 

450 if reduce == "prod": 

451 return torch.ones_like(inp) 

452 if reduce == "mean": 

453 return torch.zeros_like(inp) 

454 if reduce == "amax": 

455 return torch.full_like(inp, float("-inf")) 

456 if reduce == "amin": 

457 return torch.full_like(inp, float("inf")) 

458 raise RuntimeError(f"Unsupported reduce: {reduce}") 

459 

460 

def _needs_cas(reduce, dtype):
    """Tell whether amax/amin must use a compare-and-swap loop.

    True on the "iluvatar" vendor backend for every input, and elsewhere for
    "amax"/"amin" on float16 or bfloat16 tensors.
    """
    if flag_gems.vendor_name in ("iluvatar",):
        return True
    half_types = (torch.float16, torch.bfloat16)
    return reduce in ("amax", "amin") and dtype in half_types

465 

466 

def _triton_version_at_least(major, minor):
    """Return True when the installed Triton is at least ``major.minor``.

    Any local suffix after "+" is dropped and only the leading digits of the
    first two dot-separated fields are used; a field without leading digits,
    or a missing field, counts as 0.
    """
    release = triton.__version__.partition("+")[0]
    parsed = []
    for field in release.split(".")[:2]:
        end = 0
        while end < len(field) and field[end].isdigit():
            end += 1
        parsed.append(int(field[:end]) if end else 0)
    # Pad "3" to (3, 0) so the tuple comparison is always two-against-two.
    parsed.extend([0] * (2 - len(parsed)))
    return tuple(parsed) >= (major, minor)

480 

481 

# Triton 3.3.x rejects bf16 atomic_add during semantic type checking.
# Evaluated once at import time. When False, `index_reduce_` routes bf16
# "mean" through an fp32 workspace, and `_should_scan_duplicate_index` is
# allowed to select the atomic-free scan kernel.
_TRITON_SUPPORTS_BF16_ATOMIC_ADD = _triton_version_at_least(3, 4)

484 

485 

def _should_scan_duplicate_index(index, out_dim, reduce, dtype):
    """Decide whether to use the atomic-free scan kernel for this call.

    The scan kernel is never chosen on the "ascend" backend, when
    ``_TRITON_SUPPORTS_BF16_ATOMIC_ADD`` is set, or for a reduction that is
    neither "prod" nor in need of a CAS loop. Otherwise it is chosen exactly
    when ``index`` contains duplicates.
    """
    skip_scan = (
        flag_gems.vendor_name == "ascend"
        or _TRITON_SUPPORTS_BF16_ATOMIC_ADD
        or (reduce != "prod" and not _needs_cas(reduce, dtype))
    )
    if skip_scan:
        return False
    has_duplicates = not _index_is_unique(index, out_dim)
    return has_duplicates

494 

495 

496def _index_is_unique(index, out_dim): 

497 if index.numel() > out_dim: 

498 return False 

499 if index.numel() <= 1: 

500 return True 

501 if flag_gems.vendor_name == "ascend": 

502 index_cpu = index.cpu() 

503 return index_cpu.unique().numel() == index_cpu.numel() 

504 return index.unique().numel() == index.numel() 

505 

506 

507def _prod(values): 

508 result = 1 

509 for value in values: 

510 result *= value 

511 return result 

512 

513 

514def _validate_args(inp, dim, index, source, reduce): 

515 assert reduce in ("prod", "mean", "amax", "amin"), f"Unsupported reduce: {reduce}" 

516 assert inp.ndim > 0, "Expected self to have at least one dimension" 

517 assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" 

518 assert index.ndim == 1, "Index is supposed to be a vector" 

519 assert index.dtype in ( 

520 torch.int32, 

521 torch.int64, 

522 ), "Expected dtype int32/int64 for index" 

523 assert ( 

524 inp.is_floating_point() 

525 ), "index_reduce_(): Expected self to be floating point" 

526 assert ( 

527 source.dtype == inp.dtype 

528 ), "index_reduce_(): Expected self and source to have same dtype" 

529 assert ( 

530 inp.ndim == source.ndim 

531 ), "Self and source should have the same number of dimensions" 

532 assert index.numel() == source.size( 

533 dim 

534 ), "The dimth dimension of source must have the same size as the length of index" 

535 assert all( 

536 inp.size(i) == source.size(i) or i == dim for i in range(inp.ndim) 

537 ), "source.size(d) == self.size(d) for all dimensions d != dim" 

538 

539 

540def _restore_dim(out, inp, dim): 

541 if ( 

542 out.data_ptr() == inp.data_ptr() 

543 and out.shape == inp.shape 

544 and out.stride() == inp.stride() 

545 ): 

546 return inp 

547 final_dim = inp.ndim - 1 

548 if dim != final_dim: 

549 order = list(range(out.ndim - 1)) 

550 order.insert(dim, final_dim) 

551 out = out.permute(order).contiguous() 

552 inp.copy_(out) 

553 return inp 

554 

555 

def _alloc_reduce_buffers(base, reduce, include_self):
    """Allocate the ``(out, count, touched)`` buffers used by the atomic kernels."""
    device = base.device
    if include_self:
        # Accumulate straight into the existing values. The kernels only write
        # `touched` when include_self is False, so a placeholder is enough.
        out = base
        touched = torch.empty(1, dtype=torch.int32, device=device)
    else:
        # Start from the reduction identity and track which slots get written.
        out = _identity_like(base, reduce)
        touched = torch.zeros_like(base, dtype=torch.int32)

    if reduce == "mean":
        count = torch.zeros_like(base, dtype=torch.int32)
    else:
        # The kernels only write `count` for "mean".
        count = torch.empty(1, dtype=torch.int32, device=device)
    return out, count, touched


def _scan_reduce(
    inp, dim, index, inp_work, source_work, reduce_id, include_self, compute_dtype
):
    """Run the one-program-per-output scan kernel and copy the result into ``inp``."""
    inp_compute = inp_work.to(compute_dtype)
    source_compute = source_work.to(compute_dtype)
    out = torch.empty_like(inp_compute)
    total = inp_compute.numel()
    grid = (total,)
    with torch_device_fn.device(inp.device):
        _index_reduce_scan_kernel[grid](
            out,
            index,
            source_compute,
            inp_compute,
            total,
            index.numel(),
            inp_work.size(-1),
            reduce_id,
            include_self,
            compute_dtype == torch.float64,
        )
    return _restore_dim(out.to(inp.dtype), inp, dim)


def index_reduce_(inp, dim, index, source, reduce, *, include_self=True):
    """In-place ``index_reduce_``: fold slices of ``source`` into ``inp`` along ``dim``.

    Slice ``i`` of ``source`` along ``dim`` is combined into slice
    ``index[i]`` of ``inp`` using ``reduce``.

    Args:
        inp: floating-point tensor, modified in place.
        dim: dimension to index along; may be negative.
        index: 1-D int32/int64 tensor of destination positions.
        source: tensor with the same dtype and rank as ``inp``.
        reduce: "prod", "mean", "amax" or "amin".
        include_self: when True the existing values of ``inp`` take part in
            the reduction; when False only ``source`` values do and elements
            that receive no contribution keep their original value.

    Returns:
        ``inp``.

    Raises:
        AssertionError: if the arguments fail ``_validate_args``.
    """
    logger.debug("GEMS INDEX_REDUCE_")
    _validate_args(inp, dim, index, source, reduce)

    if index.numel() == 0:
        return inp

    dim = dim % inp.ndim
    index = index.contiguous()
    reduce_id = _reduce_id(reduce)
    is_ascend = flag_gems.vendor_name == "ascend"
    is_mean = reduce == "mean"
    # bf16 sums are accumulated in fp32 when this Triton cannot atomic_add bf16.
    use_fp32_workspace = (
        not is_ascend
        and is_mean
        and inp.dtype == torch.bfloat16
        and not _TRITON_SUPPORTS_BF16_ATOMIC_ADD
    )

    # Path 1: duplicate indices handled by the atomic-free scan kernel.
    if _should_scan_duplicate_index(index, inp.size(dim), reduce, inp.dtype):
        inp_work = dim_compress(inp, dim)
        source_work = dim_compress(source, dim)
        compute_dtype = (
            torch.float64 if inp_work.dtype == torch.float64 else torch.float32
        )
        return _scan_reduce(
            inp,
            dim,
            index,
            inp_work,
            source_work,
            reduce_id,
            include_self,
            compute_dtype,
        )

    # Path 2: contiguous tensors are reduced in their original layout.
    if (
        not is_ascend
        and inp.is_contiguous()
        and source.is_contiguous()
        and not use_fp32_workspace
    ):
        pre = _prod(inp.shape[:dim])
        post = _prod(inp.shape[dim + 1 :])
        N = index.numel()
        out_n = inp.size(dim)
        total = pre * post * N

        out, count, touched = _alloc_reduce_buffers(inp, reduce, include_self)

        use_cas = _needs_cas(reduce, inp.dtype)
        index_major = post > 1 or dim == 0
        with torch_device_fn.device(inp.device):
            _index_reduce_contiguous_flat_kernel[
                (lambda meta: (triton.cdiv(total, meta["BLOCK"]),))
            ](
                out,
                index,
                source,
                count,
                touched,
                total,
                pre,
                post,
                N,
                out_n,
                reduce_id,
                is_mean,
                not include_self,
                use_cas,
                index_major,
            )

        if is_mean:
            # Turn the accumulated sums into means, writing back into `inp`.
            with torch_device_fn.device(inp.device):
                _index_reduce_mean_finalize_kernel[
                    (lambda meta: (triton.cdiv(inp.numel(), meta["BLOCK"]),))
                ](
                    inp,
                    out,
                    inp,
                    count,
                    inp.numel(),
                    include_self,
                )
        elif not include_self:
            # Only slots that received a contribution take the reduced value.
            inp.copy_(torch.where(touched == 0, inp, out))
        return inp

    # Remaining paths work on copies with `dim` moved to the last axis.
    inp_work = dim_compress(inp, dim)
    source_work = dim_compress(source, dim)

    M = source_work.numel() // index.numel()
    N = index.numel()
    out_n = inp_work.size(-1)

    # Path 3 (ascend): unique indices need no atomics.
    if is_ascend and _index_is_unique(index, out_n):
        out = inp_work
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            # BLOCK_M / BLOCK_N come from the kernel's heuristics, so only the
            # eight declared runtime/constexpr arguments are passed here.
            _index_reduce_unique_kernel[grid](
                out,
                index,
                source_work,
                M,
                N,
                out_n,
                reduce_id,
                include_self,
            )
        return _restore_dim(out, inp, dim)

    # Path 4 (ascend): everything else goes through the scan kernel in fp32.
    if is_ascend:
        return _scan_reduce(
            inp,
            dim,
            index,
            inp_work,
            source_work,
            reduce_id,
            include_self,
            torch.float32,
        )

    # Path 5: generic atomic kernel on the dim-compressed copies.
    if use_fp32_workspace:
        inp_compute = inp_work.to(torch.float32)
        source_compute = source_work.to(torch.float32)
    else:
        inp_compute = inp_work
        source_compute = source_work

    out, count, touched = _alloc_reduce_buffers(inp_compute, reduce, include_self)

    use_cas = _needs_cas(reduce, inp_work.dtype)
    total = M * N
    index_major = dim == 0

    with torch_device_fn.device(inp.device):
        _index_reduce_flat_kernel[(lambda meta: (triton.cdiv(total, meta["BLOCK"]),))](
            out,
            index,
            source_compute,
            count,
            touched,
            total,
            M,
            N,
            out_n,
            reduce_id,
            is_mean,
            not include_self,
            use_cas,
            index_major,
        )

    if is_mean:
        # Finalize in place: `out` is both the accumulator and the result.
        with torch_device_fn.device(inp.device):
            _index_reduce_mean_finalize_kernel[
                (lambda meta: (triton.cdiv(inp_compute.numel(), meta["BLOCK"]),))
            ](
                out,
                out,
                inp_compute,
                count,
                inp_compute.numel(),
                include_self,
            )
    elif not include_self:
        out = torch.where(touched == 0, inp_compute, out)

    return _restore_dim(out.to(inp.dtype), inp, dim)