Coverage for src/flag_gems/ops/scatter_reduce.py: 34%

386 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1"""Triton implementation of torch.scatter_reduce for FlagGems. 

2 

3Supports all reduce modes: sum, prod, mean, amax, amin. 

4Handles 1D-5D tensors with up to 5D coordinate decoding via padding. 

5 

6Vendor compatibility: 

7 - NVIDIA: native atomic_max/min for amax/amin reduce 

8 - Iluvatar: CAS-based fallback for atomic_max/min (no native support) 

9 - Metax, Iluvatar: larger BLOCK=256 for better occupancy 

10 

11Performance notes: 

12 - Sum/mean use tl.atomic_add with relaxed semantics for throughput 

13 - Prod uses CAS loop with NaN detection guard (no tl.atomic_mul exists) 

14 - All offset arithmetic uses int64 to avoid overflow for N > 2^31 

15 - LOOP=4: each program processes LOOP*BLOCK elements to amortize launch overhead 

16 - 2D fast path: specialized kernels for 2D tensors avoid 5D coordinate decoding 

17""" 

18 

19import logging 

20 

21import torch 

22import triton 

23import triton.language as tl 

24 

25import flag_gems 

26from flag_gems.runtime import torch_device_fn 

27from flag_gems.utils import libentry 

28 

29logger = logging.getLogger(__name__) 

30 

31 

def heur_block(args):
    """Choose the Triton BLOCK size for the active vendor backend.

    ``args`` is the argument mapping handed over by ``triton.heuristics``;
    it is not consulted. Metax and Iluvatar get a block of 256, every other
    vendor gets 128.
    """
    wide_block_vendors = ("metax", "iluvatar")
    block = 256 if flag_gems.vendor_name in wide_block_vendors else 128
    return block

42 

43 

def heur_loop(args):
    """Return the per-program loop count used by every kernel in this file.

    ``args`` is the argument mapping handed over by ``triton.heuristics``;
    it is ignored. A program handles LOOP * BLOCK elements, so a larger
    value means fewer, longer-running programs.
    """
    loop_count = 4
    return loop_count

51 

52 

53# --------------------------------------------------------------------------- 

54# Helpers 

55# --------------------------------------------------------------------------- 

56 

57 

58def _pad5(lst, fill): 

59 """Pad a list to exactly 5 elements from the left with `fill`. 

60 

61 This enables uniform 5D coordinate decoding in kernels regardless 

62 of the actual tensor dimensionality (1D-5D). Shapes are padded with 1, 

63 strides with 0. 

64 """ 

65 return [fill] * (5 - len(lst)) + lst if len(lst) < 5 else lst 

66 

67 

def _needs_cas_fallback():
    """Tell whether amax/amin must use the compare-and-swap code path.

    Returns True only when the active vendor is Iluvatar; the kernels then
    emulate atomic max/min with a CAS retry loop instead of calling
    ``tl.atomic_max`` / ``tl.atomic_min``.
    """
    return flag_gems.vendor_name == "iluvatar"

75 

76 

77# --------------------------------------------------------------------------- 

78# 2D Fast Path Kernels with LOOP 

79# Specialized for 2D tensors to avoid 5D coordinate decoding overhead. 

80# Uses 1D grid with LOOP=4 to amortize kernel launch overhead. 

81# --------------------------------------------------------------------------- 

82 

83 

# 2D scatter-sum kernel.
#
# Each program handles LOOP consecutive tiles of BLOCK flat element ids.
# A flat id is split into (row, col) using src_ncols, the index value is read
# at that same flat position, and src is atomically added into out at the
# position where the DIM coordinate is replaced by the index value.
#
# index_ptr / src_ptr are read at row * src_ncols + col, i.e. both are treated
# as contiguous with src_ncols columns. out_ptr is addressed with out_ncols
# columns and receives float32 atomic adds. When USE_MASK is set, mask_ptr
# (same addressing as out_ptr) is incremented by 1 for every touched element.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols

        # The index load does not depend on DIM, so it is done once.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

125 

126 

# 2D scatter-product kernel.
#
# Same tiling and addressing as the 2D sum kernel, but the reduction is a
# multiplication implemented with a compare-and-swap retry loop on out_ptr.
# When USE_MASK is set, mask_ptr is incremented by 1 for every touched element.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols

        # The index load does not depend on DIM, so it is done once.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas is called without a mask, so inactive tail lanes would
        # otherwise touch addresses derived from out-of-range ids. Point them
        # at element 0; they compare 0.0 against 0.0 and never change a value.
        cas_offsets = tl.where(mask, out_offsets, 0)

        # CAS retry loop: a lane is finished once its swap succeeded. Inactive
        # lanes start out finished.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_val = tl.load(out_ptr + cas_offsets, mask=mask, other=0.0)
            new_val = tl.where(stop, cur_val, cur_val * src_val)
            # If the product is NaN, src_val is stored instead and the lane is
            # marked finished whether or not the swap succeeded.
            # NOTE(review): this does not propagate NaN the way a plain
            # product would - confirm this is the intended semantics.
            is_nan = new_val != new_val
            new_val = tl.where(is_nan, src_val, new_val)
            cas_res = tl.atomic_cas(
                out_ptr + cas_offsets, cur_val, new_val, sem="relaxed"
            )
            stop |= (cur_val == cas_res) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

181 

182 

# 2D scatter-mean accumulation kernel.
#
# Same tiling and addressing as the 2D sum kernel. It atomically adds src into
# out_ptr and 1.0 into count_ptr at the scattered position; the division by
# the count is not done here. When USE_MASK is set, mask_ptr is incremented
# by 1 for every touched element.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols

        # The index load does not depend on DIM, so it is done once.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # Accumulate the running sum and the number of contributions.
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")

228 

229 

# 2D scatter-amax / scatter-amin kernel.
#
# Same tiling and addressing as the 2D sum kernel. IS_AMAX selects maximum
# versus minimum. USE_CAS selects a compare-and-swap retry loop instead of
# tl.atomic_max / tl.atomic_min. When USE_MASK is set, mask_ptr is incremented
# by 1 for every touched element.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols

        # The index load does not depend on DIM, so it is done once.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas is called without a mask, so inactive tail lanes
            # would otherwise touch addresses derived from out-of-range ids.
            # Point them at element 0; they compare 0.0 against 0.0 (their
            # masked loads yield 0) and never change a value.
            cas_offsets = tl.where(mask, out_offsets, 0)

            # A lane is finished once its swap succeeded. Inactive lanes start
            # out finished.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_val = tl.load(out_ptr + cas_offsets, mask=mask, other=0.0)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                cas_res = tl.atomic_cas(
                    out_ptr + cas_offsets, cur_val, new_val, sem="relaxed"
                )
                stop |= cur_val == cas_res
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

292 

293 

294# --------------------------------------------------------------------------- 

295# Generic 5D Kernels with LOOP optimization 

296# For tensors with ndim != 2. 

297# --------------------------------------------------------------------------- 

298 

299 

# Generic (5D-padded) scatter-sum kernel.
#
# Each program handles LOOP consecutive tiles of BLOCK flat element ids. A
# flat id is decoded into five coordinates using src_shape_1..4, the index and
# source elements are located through their own strides, and src is atomically
# added into out at the position where the DIM coordinate is replaced by the
# index value. out_ptr receives float32 atomic adds. When USE_MASK is set,
# mask_ptr (addressed with the out strides) is incremented by 1 for every
# touched element.
#
# out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and src_shape_0
# are accepted for interface compatibility but are not read in the body.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Peel coordinates off innermost-first. Every step works on the int64
        # id, so no product of the scalar shape arguments (which may be
        # narrower than int64) is ever formed.
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The output position equals the source position with the coordinate
        # along DIM replaced by the index value. DIM is a compile-time
        # constant, so only one branch is compiled; any DIM outside 0..3 is
        # treated as the last axis.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

419 

420 

# Generic (5D-padded) scatter-product kernel.
#
# Same tiling, coordinate decoding and addressing as the generic sum kernel,
# but the reduction is a multiplication implemented with a compare-and-swap
# retry loop on out_ptr. When USE_MASK is set, mask_ptr is incremented by 1
# for every touched element.
#
# out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and src_shape_0
# are accepted for interface compatibility but are not read in the body.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Peel coordinates off innermost-first. Every step works on the int64
        # id, so no product of the scalar shape arguments (which may be
        # narrower than int64) is ever formed.
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The output position equals the source position with the coordinate
        # along DIM replaced by the index value. DIM is a compile-time
        # constant, so only one branch is compiled; any DIM outside 0..3 is
        # treated as the last axis.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas is called without a mask, so inactive tail lanes would
        # otherwise touch addresses derived from out-of-range ids. Point them
        # at element 0; they compare 0.0 against 0.0 and never change a value.
        cas_offsets = tl.where(mask, out_offsets, 0)

        # CAS retry loop: a lane is finished once its swap succeeded. Inactive
        # lanes start out finished.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_val = tl.load(out_ptr + cas_offsets, mask=mask, other=0.0)
            new_val = tl.where(stop, cur_val, cur_val * src_val)
            # If the product is NaN, src_val is stored instead and the lane is
            # marked finished whether or not the swap succeeded; this keeps
            # the loop from spinning on values that never compare equal.
            # NOTE(review): this does not propagate NaN the way a plain
            # product would - confirm this is the intended semantics.
            is_nan = new_val != new_val
            new_val = tl.where(is_nan, src_val, new_val)
            cas_res = tl.atomic_cas(
                out_ptr + cas_offsets, cur_val, new_val, sem="relaxed"
            )
            stop |= (cur_val == cas_res) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

556 

557 

# Generic (5D-padded) scatter-mean accumulation kernel.
#
# Same tiling, coordinate decoding and addressing as the generic sum kernel.
# It atomically adds src into out_ptr and 1.0 into count_ptr at the scattered
# position; the division by the count is not done here. When USE_MASK is set,
# mask_ptr is incremented by 1 for every touched element.
#
# out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and src_shape_0
# are accepted for interface compatibility but are not read in the body.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Peel coordinates off innermost-first. Every step works on the int64
        # id, so no product of the scalar shape arguments (which may be
        # narrower than int64) is ever formed.
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The output position equals the source position with the coordinate
        # along DIM replaced by the index value. DIM is a compile-time
        # constant, so only one branch is compiled; any DIM outside 0..3 is
        # treated as the last axis.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # Accumulate the running sum and the number of contributions.
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")

681 

682 

# Generic (5D-padded) scatter-amax / scatter-amin kernel.
#
# Same tiling, coordinate decoding and addressing as the generic sum kernel.
# IS_AMAX selects maximum versus minimum. USE_CAS selects a compare-and-swap
# retry loop instead of tl.atomic_max / tl.atomic_min. When USE_MASK is set,
# mask_ptr is incremented by 1 for every touched element.
#
# out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and src_shape_0
# are accepted for interface compatibility but are not read in the body.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Widen the program id first so pid * BLOCK * LOOP is evaluated in int64
    # and cannot wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Peel coordinates off innermost-first. Every step works on the int64
        # id, so no product of the scalar shape arguments (which may be
        # narrower than int64) is ever formed.
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The output position equals the source position with the coordinate
        # along DIM replaced by the index value. DIM is a compile-time
        # constant, so only one branch is compiled; any DIM outside 0..3 is
        # treated as the last axis.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas is called without a mask, so inactive tail lanes
            # would otherwise touch addresses derived from out-of-range ids.
            # Point them at element 0; they compare 0.0 against 0.0 (their
            # masked loads yield 0) and never change a value.
            cas_offsets = tl.where(mask, out_offsets, 0)

            # A lane is finished once its swap succeeded. Inactive lanes start
            # out finished.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_val = tl.load(out_ptr + cas_offsets, mask=mask, other=0.0)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                cas_res = tl.atomic_cas(
                    out_ptr + cas_offsets, cur_val, new_val, sem="relaxed"
                )
                stop |= cur_val == cas_res
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

823 

824 

825# --------------------------------------------------------------------------- 

826# Python entry points 

827# --------------------------------------------------------------------------- 

828 

829 

def scatter_reduce(inp, dim, index, src, reduce, *, include_self=True):
    """Triton-accelerated out-of-place ``scatter_reduce``.

    Every element of ``src`` covered by ``index`` is combined into the output
    at the position obtained by replacing its coordinate along ``dim`` with
    the corresponding ``index`` value.

    The reduction is accumulated in float32 and cast back to ``inp.dtype`` at
    the end, so wider dtypes (float64, large integers) may lose precision.
    Index values are not range-checked.

    Args:
        inp: Input tensor with at most 5 dimensions.
        dim: Dimension along which to scatter; negative values wrap around.
        index: Integer tensor with the same number of dimensions as ``inp``.
        src: Source tensor; only the leading region covered by ``index`` is
            read.
        reduce: One of "sum", "prod", "mean", "amax", "amin".
        include_self: If True, the values of ``inp`` take part in the
            reduction. If False, positions that receive at least one source
            element are reduced over the source elements only; positions that
            receive none keep their ``inp`` value.

    Returns:
        A new tensor with the shape and dtype of ``inp``.

    Raises:
        AssertionError: If ``reduce`` is not supported or ``inp`` has more
            than 5 dimensions.
        RuntimeError: If the tensors differ in dimensionality, or ``index`` is
            larger than ``src`` (any dimension) or than ``inp`` (any dimension
            other than ``dim``).
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO")

    assert reduce in (
        "sum",
        "prod",
        "mean",
        "amax",
        "amin",
    ), f"Unsupported reduce: {reduce}"
    assert inp.ndim <= 5, f"scatter_reduce supports up to 5D tensors, got {inp.ndim}D"

    dim = dim % inp.ndim
    padded_dim = dim + (5 - inp.ndim)

    N = index.numel()
    if N == 0:
        # Nothing is scattered, so every position keeps its input value
        # regardless of include_self.
        return inp.clone()

    # The kernels walk index.numel() elements and cannot bounds-check, so
    # reject shapes they would read out of range.
    if not (inp.ndim == index.ndim == src.ndim):
        raise RuntimeError(
            "scatter_reduce: inp, index and src must have the same number of "
            f"dimensions, got {inp.ndim}, {index.ndim} and {src.ndim}"
        )
    for d in range(inp.ndim):
        too_big_for_src = index.size(d) > src.size(d)
        too_big_for_inp = d != dim and index.size(d) > inp.size(d)
        if too_big_for_src or too_big_for_inp:
            raise RuntimeError(
                f"scatter_reduce: index shape {tuple(index.shape)} exceeds src "
                f"shape {tuple(src.shape)} or inp shape {tuple(inp.shape)} "
                f"(apart from dimension {dim})"
            )

    # Passed through to the generic kernels, which accept but do not read them.
    out_stride_dim = inp.stride(dim)
    out_shape_dim = inp.size(dim)
    src_stride_dim = src.stride(dim)
    src_shape_dim = src.size(dim)

    inp_f32 = inp.to(torch.float32).contiguous()

    if include_self:
        out = inp_f32.clone()
    elif reduce in ("sum", "mean"):
        out = torch.zeros_like(inp_f32)
    elif reduce == "prod":
        out = torch.ones_like(inp_f32)
    elif reduce == "amax":
        out = torch.full_like(inp_f32, float("-inf"))
    else:  # "amin"
        out = torch.full_like(inp_f32, float("inf"))

    # With include_self=False a per-element hit counter records which output
    # positions were written, so untouched ones can be restored from inp.
    use_mask = not include_self
    if use_mask:
        reduced_mask = torch.zeros(out.shape, dtype=torch.int32, device=inp.device)
        mask_buf = reduced_mask
    else:
        # Placeholder pointer; the kernels never touch it when USE_MASK is off.
        mask_buf = torch.empty(1, dtype=torch.int32, device=inp.device)

    if reduce == "mean":
        # include_self counts the original value as one contribution.
        if include_self:
            count = torch.ones_like(out, dtype=torch.float32)
        else:
            count = torch.zeros_like(out, dtype=torch.float32)

    # The kernels decode element coordinates from src's shape while iterating
    # over index.numel() elements, so src is restricted to the region covered
    # by index; after this both tensors share one shape.
    src = src[tuple(slice(0, n) for n in index.shape)].contiguous()
    index = index.contiguous()

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)

    # 2D tensors use specialised kernels that take column counts and the raw
    # dim; everything else uses the 5D-padded kernels.
    use_2d = inp.ndim == 2
    if use_2d:
        head_args = (N, src.shape[1], out.shape[1])
        dim_arg = dim
        tail_args = ()
    else:
        head_args = (N, out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim)
        dim_arg = padded_dim
        tail_args = (
            *(int(x) for x in _pad5(list(src.stride()), 0)),
            *(int(x) for x in _pad5(list(src.shape), 1)),
            *(int(x) for x in _pad5(list(index.stride()), 0)),
            *(int(x) for x in _pad5(list(out.stride()), 0)),
        )

    with torch_device_fn.device(inp.device):
        if reduce == "sum":
            kernel = scatter_reduce_sum_2d_kernel if use_2d else scatter_reduce_sum_kernel
            kernel[grid](
                index, src, out, mask_buf, *head_args, dim_arg, use_mask, *tail_args
            )
        elif reduce == "prod":
            kernel = (
                scatter_reduce_prod_2d_kernel if use_2d else scatter_reduce_prod_kernel
            )
            kernel[grid](
                index, src, out, mask_buf, *head_args, dim_arg, use_mask, *tail_args
            )
        elif reduce == "mean":
            kernel = (
                scatter_reduce_mean_2d_kernel if use_2d else scatter_reduce_mean_kernel
            )
            kernel[grid](
                index,
                src,
                out,
                count,
                mask_buf,
                *head_args,
                dim_arg,
                use_mask,
                *tail_args,
            )
            has_contributions = count > 0
            count = torch.clamp(count, min=1.0)
            if inp.dtype.is_floating_point:
                out = out / count
            else:
                # Integer means round toward negative infinity; the final
                # dtype cast alone would truncate toward zero instead.
                out = torch.div(out, count, rounding_mode="floor")
            out = torch.where(has_contributions, out, inp_f32)
        else:  # "amax" / "amin"
            kernel = (
                scatter_reduce_amax_2d_kernel if use_2d else scatter_reduce_amax_kernel
            )
            kernel[grid](
                index,
                src,
                out,
                mask_buf,
                *head_args,
                dim_arg,
                reduce == "amax",
                use_mask,
                _needs_cas_fallback(),
                *tail_args,
            )

    if use_mask and reduce != "mean":
        # Positions nothing was scattered into keep their original value
        # rather than the reduction's identity element.
        out = torch.where(reduced_mask == 0, inp_f32, out)

    return out.to(inp.dtype)

1129 

1130 

def scatter_reduce_(inp, dim, index, src, reduce, *, include_self=True):
    """In-place ``scatter_reduce``: compute the result, copy it into ``inp``.

    Delegates to ``scatter_reduce`` with the same arguments and returns
    ``inp`` itself after overwriting its contents.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_")

    inp.copy_(
        scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    )
    return inp

1138 

1139 

def scatter_reduce_out(inp, dim, index, src, reduce, *, include_self=True, out=None):
    """``scatter_reduce`` with an optional destination tensor.

    Delegates to ``scatter_reduce``. When ``out`` is given, the result is
    copied into it and ``out`` is returned; otherwise the freshly computed
    tensor is returned.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_OUT")

    computed = scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    if out is None:
        return computed
    out.copy_(computed)
    return out