Coverage for src/flag_gems/runtime/backend/_sunrise/ops/scatter_reduce.py: 0%

470 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Triton implementation of torch.scatter_reduce for FlagGems. 

2 

3Supports all reduce modes: sum, prod, mean, amax, amin. 

4Handles 1D-5D tensors with up to 5D coordinate decoding via padding. 

5 

6Vendor compatibility: 

7 - NVIDIA: native atomic_max/min for amax/amin reduce 

8 - Iluvatar: CAS-based fallback for atomic_max/min (no native support) 

9 - Metax: larger BLOCK=256 for better occupancy 

10 

11Performance notes: 

12 - Sum/mean use tl.atomic_add with relaxed semantics for throughput 

13 - Prod uses CAS loop with NaN detection guard (no tl.atomic_mul exists) 

14 - All offset arithmetic uses int64 to avoid overflow for N > 2^31 

15 - LOOP=4: each program processes LOOP*BLOCK elements to amortize launch overhead 

16 - 2D fast path: specialized kernels for 2D tensors avoid 5D coordinate decoding 

17""" 

18 

19import logging 

20 

21import torch 

22import triton 

23import triton.language as tl 

24 

25import flag_gems 

26from flag_gems.runtime import torch_device_fn 

27from flag_gems.utils import libentry 

28 

29logger = logging.getLogger(__name__) 

30 

31 

def heur_block(args):
    """Pick the BLOCK size for the scatter kernels based on the active vendor.

    Returns 256 when ``flag_gems.vendor_name`` is ``"metax"`` or
    ``"iluvatar"`` (the original author cites better occupancy with larger
    blocks there) and 128 for every other vendor. ``args`` is the kernel
    argument mapping handed in by ``triton.heuristics``; it is not consulted.
    """
    return 256 if flag_gems.vendor_name in ("metax", "iluvatar") else 128

42 

43 

def heur_loop(args):
    """Return the number of BLOCK-sized tiles each program walks through.

    A program handles ``LOOP * BLOCK`` elements so that launch overhead is
    spread over more work. The value is fixed at 4 (reported by the original
    author as the best setting on Iluvatar BI-V150); ``args`` is ignored.
    """
    loop_factor = 4
    return loop_factor

51 

52 

def heur_scan_block(args):
    """Return the tile width (128) used to scan the source dimension.

    Used as the BLOCK heuristic of the deterministic product-scan kernel;
    ``args`` is ignored.
    """
    scan_tile = 128
    return scan_tile

56 

57 

58# --------------------------------------------------------------------------- 

59# Helpers 

60# --------------------------------------------------------------------------- 

61 

62 

63def _pad5(lst, fill): 

64 """Pad a list to exactly 5 elements from the left with `fill`. 

65 

66 This enables uniform 5D coordinate decoding in kernels regardless 

67 of the actual tensor dimensionality (1D-5D). Shapes are padded with 1, 

68 strides with 0. 

69 """ 

70 return [fill] * (5 - len(lst)) + lst if len(lst) < 5 else lst 

71 

72 

def _needs_cas_fallback():
    """Tell whether amax/amin must be emulated with a compare-and-swap loop.

    True only when ``flag_gems.vendor_name`` is ``"iluvatar"``; per the
    original author that backend has no native ``tl.atomic_max`` /
    ``tl.atomic_min``.
    """
    return flag_gems.vendor_name == "iluvatar"

80 

81 

# Deterministic "prod" reduction. Each program owns one output element
# (program id -> 5D output coordinate), walks the scatter dimension of
# index/src in BLOCK-wide tiles and multiplies in every source value whose
# index equals the program's own coordinate along DIM. No atomics are used.
@libentry()
@triton.heuristics({"BLOCK": heur_scan_block})
@triton.jit(do_not_specialize=["out_numel"])
def scatter_reduce_prod_scan_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    out_numel,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_shape_dim: tl.constexpr,
    src_stride_0, src_stride_1, src_stride_2, src_stride_3, src_stride_4,
    idx_shape_0, idx_shape_1, idx_shape_2, idx_shape_3, idx_shape_4,
    src_shape_0, src_shape_1, src_shape_2, src_shape_3, src_shape_4,
    idx_stride_0, idx_stride_1, idx_stride_2, idx_stride_3, idx_stride_4,
    out_shape_0, out_shape_1, out_shape_2, out_shape_3, out_shape_4,
    out_stride_0, out_stride_1, out_stride_2, out_stride_3, out_stride_4,
    BLOCK: tl.constexpr,
):
    linear = tl.program_id(axis=0).to(tl.int64)
    in_bounds = linear < out_numel

    # Number of output elements spanned by one step of each leading axis.
    tail1 = out_shape_1 * out_shape_2 * out_shape_3 * out_shape_4
    tail2 = out_shape_2 * out_shape_3 * out_shape_4
    tail3 = out_shape_3 * out_shape_4

    # Decode the linear program id into a 5D output coordinate.
    coord0 = linear // tail1
    rest = linear % tail1
    coord1 = rest // tail2
    rest = rest % tail2
    coord2 = rest // tail3
    rest = rest % tail3
    coord3 = rest // out_shape_4
    coord4 = rest % out_shape_4

    out_offset = (
        coord0 * out_stride_0
        + coord1 * out_stride_1
        + coord2 * out_stride_2
        + coord3 * out_stride_3
        + coord4 * out_stride_4
    )
    idx_full_offset = (
        coord0 * idx_stride_0
        + coord1 * idx_stride_1
        + coord2 * idx_stride_2
        + coord3 * idx_stride_3
        + coord4 * idx_stride_4
    )
    src_full_offset = (
        coord0 * src_stride_0
        + coord1 * src_stride_1
        + coord2 * src_stride_2
        + coord3 * src_stride_3
        + coord4 * src_stride_4
    )

    # Per-axis check that the output coordinate also exists in index and src.
    ok0 = (coord0 < idx_shape_0) & (coord0 < src_shape_0)
    ok1 = (coord1 < idx_shape_1) & (coord1 < src_shape_1)
    ok2 = (coord2 < idx_shape_2) & (coord2 < src_shape_2)
    ok3 = (coord3 < idx_shape_3) & (coord3 < src_shape_3)
    ok4 = (coord4 < idx_shape_4) & (coord4 < src_shape_4)

    # Select the scanned axis; valid_other covers the four remaining axes.
    if DIM == 0:
        target = coord0
        idx_scan_stride = idx_stride_0
        src_scan_stride = src_stride_0
        idx_scan_shape = idx_shape_0
        valid_other = ok1 & ok2 & ok3 & ok4
    elif DIM == 1:
        target = coord1
        idx_scan_stride = idx_stride_1
        src_scan_stride = src_stride_1
        idx_scan_shape = idx_shape_1
        valid_other = ok0 & ok2 & ok3 & ok4
    elif DIM == 2:
        target = coord2
        idx_scan_stride = idx_stride_2
        src_scan_stride = src_stride_2
        idx_scan_shape = idx_shape_2
        valid_other = ok0 & ok1 & ok3 & ok4
    elif DIM == 3:
        target = coord3
        idx_scan_stride = idx_stride_3
        src_scan_stride = src_stride_3
        idx_scan_shape = idx_shape_3
        valid_other = ok0 & ok1 & ok2 & ok4
    else:
        target = coord4
        idx_scan_stride = idx_stride_4
        src_scan_stride = src_stride_4
        idx_scan_shape = idx_shape_4
        valid_other = ok0 & ok1 & ok2 & ok3

    # Offsets of position 0 along the scanned axis (drop this program's own
    # contribution along DIM from the full offsets).
    idx_base = idx_full_offset - target * idx_scan_stride
    src_base = src_full_offset - target * src_scan_stride

    lanes = tl.arange(0, BLOCK)
    # Start from the value already stored in out; accumulate in float32.
    acc = tl.load(out_ptr + out_offset, mask=in_bounds, other=1.0).to(tl.float32)
    has_contrib = False

    for start in range(0, src_shape_dim, BLOCK):
        scan = start + lanes
        valid = (
            in_bounds & valid_other & (scan < src_shape_dim) & (scan < idx_scan_shape)
        )
        # -1 for inactive lanes so they cannot match a real coordinate.
        idx_val = tl.load(
            index_ptr + idx_base + scan * idx_scan_stride,
            mask=valid,
            other=-1,
        ).to(tl.int64)
        match = valid & (idx_val == target)
        src_val = tl.load(
            src_ptr + src_base + scan * src_scan_stride,
            mask=valid,
            other=1.0,
        ).to(tl.float32)
        # Non-matching lanes contribute the multiplicative identity.
        factors = tl.where(match, src_val, 1.0)
        prefix = tl.cumprod(factors, 0)
        # The last lane of the running product is the product of the tile.
        tile_prod = tl.sum(tl.where(lanes == (BLOCK - 1), prefix, 0.0))
        acc *= tile_prod
        has_contrib |= tl.sum(match.to(tl.int32)) > 0

    tl.store(out_ptr + out_offset, acc, mask=in_bounds)
    if USE_MASK:
        # Record whether any source element landed on this output element.
        tl.store(mask_ptr + out_offset, has_contrib.to(tl.int32), mask=in_bounds)

276 

277 

278# --------------------------------------------------------------------------- 

279# 2D Fast Path Kernels with LOOP 

280# Specialized for 2D tensors to avoid 5D coordinate decoding overhead. 

281# Uses 1D grid with LOOP=4 to amortize kernel launch overhead. 

282# --------------------------------------------------------------------------- 

283 

284 

# 2D fast path for reduce="sum": every lane takes one index element and
# atomically adds the matching src value into out. Each program covers
# LOOP consecutive tiles of BLOCK elements.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # NOTE(review): index is addressed with src_ncols as its row pitch,
        # which is only right when index and src rows have equal length —
        # confirm against the caller's dispatch condition.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

327 

328 

# 2D fast path for reduce="prod": there is no atomic multiply, so each lane
# multiplies its src value into out with a compare-and-swap retry loop that
# works on the raw float32 bit pattern.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # NOTE(review): index is addressed with src_ncols as its row pitch,
        # which is only right when index and src rows have equal length —
        # confirm against the caller's dispatch condition.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas takes no mask, so out-of-range lanes execute it too.
        # Point them at element 0: they compare against bits 0 and write
        # bits 0, which never changes the stored value.
        cas_offsets = tl.where(mask, out_offsets, 0)
        # Sunrise/PTPU is more stable when product CAS operates on the raw
        # float32 bit pattern instead of a floating-point CAS.
        out_ptr_u32 = (out_ptr + cas_offsets).to(
            tl.pointer_type(tl.uint32, 1), bitcast=True
        )

        # Lanes outside the mask start out finished.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            # Finished lanes re-write the value they just read (no-op swap).
            new_val = tl.where(stop, cur_val, cur_val * src_val)
            # A NaN product is replaced by src_val and the lane is retired.
            is_nan = new_val != new_val
            new_val = tl.where(is_nan, src_val, new_val)
            new_bits = new_val.to(tl.uint32, bitcast=True)
            cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="acq_rel")
            # Done when the swap took effect (old bits matched) or on NaN.
            stop |= (cur_bits == cas_res) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

389 

390 

# 2D fast path for reduce="mean": accumulates the sum into out and the
# number of contributions into count (float32); the division itself is not
# done in this kernel.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # NOTE(review): index is addressed with src_ncols as its row pitch,
        # which is only right when index and src rows have equal length —
        # confirm against the caller's dispatch condition.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # Running sum and running element count for the later division.
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")

437 

438 

# 2D fast path for reduce="amax"/"amin" (selected by IS_AMAX). Uses native
# atomic max/min, or a compare-and-swap retry loop when USE_CAS is set.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # NOTE(review): index is addressed with src_ncols as its row pitch,
        # which is only right when index and src rows have equal length —
        # confirm against the caller's dispatch condition.
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas takes no mask, so out-of-range lanes execute it
            # too. Point them at element 0: they compare against 0.0 and
            # write 0.0, which never changes the stored value.
            cas_ptr = out_ptr + tl.where(mask, out_offsets, 0)
            # Lanes outside the mask start out finished.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_val = tl.load(cas_ptr, mask=mask, other=0.0)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                cas_res = tl.atomic_cas(cas_ptr, cur_val, new_val, sem="relaxed")
                # Compare bit patterns, not float values: NaN == NaN is
                # False (the lane would spin forever) and -0.0 == +0.0 is
                # True (a failed swap would be taken for a success).
                # Assumes out holds 32-bit floats, as the prod kernels do.
                stop |= cur_val.to(tl.int32, bitcast=True) == cas_res.to(
                    tl.int32, bitcast=True
                )
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

502 

503 

504# --------------------------------------------------------------------------- 

505# Generic 5D Kernels with LOOP optimization 

506# For tensors with ndim != 2. 

507# --------------------------------------------------------------------------- 

508 

509 

# Generic (up to 5D, strided) kernel for reduce="sum". Each lane decodes
# one linear index-tensor position into a 5D coordinate, reads index/src
# there and atomically adds src into out at the same coordinate with the
# DIM axis replaced by the index value.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0, src_stride_1, src_stride_2, src_stride_3, src_stride_4,
    idx_shape_0, idx_shape_1, idx_shape_2, idx_shape_3, idx_shape_4,
    src_shape_0, src_shape_1, src_shape_2, src_shape_3, src_shape_4,
    idx_stride_0, idx_stride_1, idx_stride_2, idx_stride_3, idx_stride_4,
    out_stride_0, out_stride_1, out_stride_2, out_stride_3, out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    # Number of index elements spanned by one step of each leading axis.
    tail1 = idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4
    tail2 = idx_shape_2 * idx_shape_3 * idx_shape_4
    tail3 = idx_shape_3 * idx_shape_4

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        # Linear position in the index tensor -> 5D coordinate.
        coord0 = offsets // tail1
        rest = offsets % tail1
        coord1 = rest // tail2
        rest = rest % tail2
        coord2 = rest // tail3
        rest = rest % tail3
        coord3 = rest // idx_shape_4
        coord4 = rest % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The destination keeps every coordinate except the one along DIM,
        # which is taken from the index tensor.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

634 

635 

# Generic (up to 5D, strided) kernel for reduce="prod". There is no atomic
# multiply, so each lane multiplies its src value into out with a
# compare-and-swap retry loop on the raw float32 bit pattern.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0, src_stride_1, src_stride_2, src_stride_3, src_stride_4,
    idx_shape_0, idx_shape_1, idx_shape_2, idx_shape_3, idx_shape_4,
    src_shape_0, src_shape_1, src_shape_2, src_shape_3, src_shape_4,
    idx_stride_0, idx_stride_1, idx_stride_2, idx_stride_3, idx_stride_4,
    out_stride_0, out_stride_1, out_stride_2, out_stride_3, out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    # Number of index elements spanned by one step of each leading axis.
    tail1 = idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4
    tail2 = idx_shape_2 * idx_shape_3 * idx_shape_4
    tail3 = idx_shape_3 * idx_shape_4

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        # Linear position in the index tensor -> 5D coordinate.
        coord0 = offsets // tail1
        rest = offsets % tail1
        coord1 = rest // tail2
        rest = rest % tail2
        coord2 = rest // tail3
        rest = rest % tail3
        coord3 = rest // idx_shape_4
        coord4 = rest % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The destination keeps every coordinate except the one along DIM,
        # which is taken from the index tensor.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas takes no mask, so out-of-range lanes execute it too.
        # Point them at element 0: they compare against bits 0 and write
        # bits 0, which never changes the stored value.
        cas_offsets = tl.where(mask, out_offsets, 0)
        # Sunrise/PTPU is more stable when product CAS operates on the raw
        # float32 bit pattern instead of a floating-point CAS.
        out_ptr_u32 = (out_ptr + cas_offsets).to(
            tl.pointer_type(tl.uint32, 1), bitcast=True
        )

        # Lanes outside the mask start out finished.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            # Finished lanes re-write the value they just read (no-op swap).
            new_val = tl.where(stop, cur_val, cur_val * src_val)
            # NaN guard: a NaN product is replaced by src_val and the lane
            # is retired regardless of the swap outcome.
            is_nan = new_val != new_val
            new_val = tl.where(is_nan, src_val, new_val)
            new_bits = new_val.to(tl.uint32, bitcast=True)
            cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="acq_rel")
            # Done when the swap took effect (old bits matched) or on NaN.
            stop |= (cur_bits == cas_res) | is_nan
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

781 

782 

# Generic (up to 5D, strided) kernel for reduce="mean": accumulates the sum
# into out and the number of contributions into count (float32); the
# division itself is not done in this kernel.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0, src_stride_1, src_stride_2, src_stride_3, src_stride_4,
    idx_shape_0, idx_shape_1, idx_shape_2, idx_shape_3, idx_shape_4,
    src_shape_0, src_shape_1, src_shape_2, src_shape_3, src_shape_4,
    idx_stride_0, idx_stride_1, idx_stride_2, idx_stride_3, idx_stride_4,
    out_stride_0, out_stride_1, out_stride_2, out_stride_3, out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    # Number of index elements spanned by one step of each leading axis.
    tail1 = idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4
    tail2 = idx_shape_2 * idx_shape_3 * idx_shape_4
    tail3 = idx_shape_3 * idx_shape_4

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        # Linear position in the index tensor -> 5D coordinate.
        coord0 = offsets // tail1
        rest = offsets % tail1
        coord1 = rest // tail2
        rest = rest % tail2
        coord2 = rest // tail3
        rest = rest % tail3
        coord3 = rest // idx_shape_4
        coord4 = rest % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The destination keeps every coordinate except the one along DIM,
        # which is taken from the index tensor.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # Running sum and running element count for the later division.
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")

911 

912 

# Generic (up to 5D, strided) kernel for reduce="amax"/"amin" (selected by
# IS_AMAX). Uses native atomic max/min, or a compare-and-swap retry loop
# when USE_CAS is set.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    src_stride_0, src_stride_1, src_stride_2, src_stride_3, src_stride_4,
    idx_shape_0, idx_shape_1, idx_shape_2, idx_shape_3, idx_shape_4,
    src_shape_0, src_shape_1, src_shape_2, src_shape_3, src_shape_4,
    idx_stride_0, idx_stride_1, idx_stride_2, idx_stride_3, idx_stride_4,
    out_stride_0, out_stride_1, out_stride_2, out_stride_3, out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id before multiplying: pid * BLOCK * LOOP would
    # otherwise be evaluated in int32 and wrap for very large N.
    pid = tl.program_id(axis=0).to(tl.int64)
    tile_start = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    # Number of index elements spanned by one step of each leading axis.
    tail1 = idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4
    tail2 = idx_shape_2 * idx_shape_3 * idx_shape_4
    tail3 = idx_shape_3 * idx_shape_4

    for i in range(LOOP):
        offsets = tile_start + i * BLOCK
        mask = offsets < N

        # Linear position in the index tensor -> 5D coordinate.
        coord0 = offsets // tail1
        rest = offsets % tail1
        coord1 = rest // tail2
        rest = rest % tail2
        coord2 = rest // tail3
        rest = rest % tail3
        coord3 = rest // idx_shape_4
        coord4 = rest % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # The destination keeps every coordinate except the one along DIM,
        # which is taken from the index tensor.
        if DIM == 0:
            coord0 = idx
        elif DIM == 1:
            coord1 = idx
        elif DIM == 2:
            coord2 = idx
        elif DIM == 3:
            coord3 = idx
        else:
            coord4 = idx
        out_offsets = (
            coord0 * out_stride_0
            + coord1 * out_stride_1
            + coord2 * out_stride_2
            + coord3 * out_stride_3
            + coord4 * out_stride_4
        )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas takes no mask, so out-of-range lanes execute it
            # too. Point them at element 0: they compare against 0.0 and
            # write 0.0, which never changes the stored value.
            cas_ptr = out_ptr + tl.where(mask, out_offsets, 0)
            # Lanes outside the mask start out finished.
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_val = tl.load(cas_ptr, mask=mask, other=0.0)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                cas_res = tl.atomic_cas(cas_ptr, cur_val, new_val, sem="relaxed")
                # Compare bit patterns, not float values: NaN == NaN is
                # False (the lane would spin forever) and -0.0 == +0.0 is
                # True (a failed swap would be taken for a success).
                # Assumes out holds 32-bit floats, as the prod kernels do.
                stop |= cur_val.to(tl.int32, bitcast=True) == cas_res.to(
                    tl.int32, bitcast=True
                )
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            # Count how many source elements hit each output element.
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")

1058 

1059 

1060# --------------------------------------------------------------------------- 

1061# Python entry points 

1062# --------------------------------------------------------------------------- 

1063 

1064 

def scatter_reduce(inp, dim, index, src, reduce, *, include_self=True):
    """Triton-accelerated scatter_reduce operation.

    Scatters src values into a copy of inp at positions determined by index,
    applying the specified reduction. Supports sum, prod, mean, amax, amin.
    Accumulation happens in a contiguous float32 buffer that is cast back to
    inp.dtype on return; inp itself is never written to.

    Args:
        inp: Input tensor (1D-5D).
        dim: Dimension along which to scatter. Negative values wrap around.
        index: Index tensor mapping source elements to output positions.
        src: Source tensor containing values to scatter.
        reduce: Reduction mode - "sum", "prod", "mean", "amax", or "amin".
        include_self: If True, include inp values in the reduction. If False,
            the accumulator starts from the reduction's identity value and
            every position that was never scattered to keeps its inp value.

    Returns:
        New tensor with same shape and dtype as inp.

    Raises:
        AssertionError: If reduce is not a supported mode or inp has more
            than 5 dimensions.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO")

    assert reduce in (
        "sum",
        "prod",
        "mean",
        "amax",
        "amin",
    ), f"Unsupported reduce: {reduce}"
    assert inp.ndim <= 5, f"scatter_reduce supports up to 5D tensors, got {inp.ndim}D"

    dim = dim % inp.ndim
    # Position of the scatter axis once shapes/strides are padded to 5 entries.
    # assumes _pad5 pads on the left -- TODO confirm against its definition
    padded_dim = dim + (5 - inp.ndim)
    N = index.numel()

    if N == 0:
        # An empty index scatters nothing, so every position keeps its inp
        # value whatever include_self is. Clone so the result of this
        # out-of-place op never aliases the caller's tensor.
        return inp.clone(memory_format=torch.contiguous_format)

    # Single float32, contiguous working copy of the input.
    inp_f32 = inp.to(torch.float32).contiguous()

    if include_self:
        out = inp_f32.clone()
    else:
        # Start from the identity element of the reduction so that inp does
        # not take part in it.
        identity = {
            "sum": 0.0,
            "mean": 0.0,
            "prod": 1.0,
            "amax": float("-inf"),
            "amin": float("inf"),
        }[reduce]
        out = torch.full_like(inp_f32, identity)

    use_mask = not include_self
    if use_mask:
        # Handed to the kernels as mask_ptr; entries that are still zero after
        # the launch are treated below as "never scattered to".
        reduced_mask = torch.zeros(out.shape, dtype=torch.int32, device=inp.device)
        mask_ptr = reduced_mask
    else:
        # The kernels always take a mask pointer; give them a placeholder.
        mask_ptr = torch.empty(1, dtype=torch.int32, device=inp.device)

    if reduce == "mean":
        # Per-position divisor; inp itself counts as one contribution when it
        # is part of the reduction.
        if include_self:
            count = torch.ones_like(out, dtype=torch.float32)
        else:
            count = torch.zeros_like(out, dtype=torch.float32)

    src = src.contiguous()
    index = index.contiguous()

    # Read the scatter-axis metadata from the buffers that are actually passed
    # to the kernels (contiguous out/src), not from the caller's tensors, whose
    # strides may differ when they are non-contiguous.
    out_stride_dim = out.stride(dim)
    out_shape_dim = out.size(dim)
    src_stride_dim = src.stride(dim)
    src_shape_dim = src.size(dim)

    # Shapes and strides padded to five entries and converted to plain ints.
    idx_shapes = [int(x) for x in _pad5(list(index.shape), 1)]
    src_shapes = [int(x) for x in _pad5(list(src.shape), 1)]
    src_strides_p = [int(x) for x in _pad5(list(src.stride()), 0)]
    idx_strides_p = [int(x) for x in _pad5(list(index.stride()), 0)]
    out_shapes = [int(x) for x in _pad5(list(out.shape), 1)]
    out_strides_p = [int(x) for x in _pad5(list(out.stride()), 0)]

    # One program handles BLOCK * LOOP elements of index.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)

    # 2D tensors take the specialised kernels (prod has no 2D variant).
    use_2d = inp.ndim == 2 and reduce != "prod"
    if use_2d:
        # The 2D kernels take column counts and the raw dim (0 or 1).
        head_args = (N, index.shape[1], src.shape[1], out.shape[1], dim)
        tail_args = ()
    else:
        head_args = (
            N,
            out_stride_dim,
            src_stride_dim,
            src_shape_dim,
            out_shape_dim,
            padded_dim,
        )
        # Trailing 25 scalars shared by the N-d sum/mean/amax kernels, in the
        # positional order those kernels are called with.
        tail_args = (
            *src_strides_p,
            *idx_shapes,
            *src_shapes,
            *idx_strides_p,
            *out_strides_p,
        )

    with torch_device_fn.device(inp.device):
        if reduce == "sum":
            kernel = scatter_reduce_sum_2d_kernel if use_2d else scatter_reduce_sum_kernel
            kernel[grid](index, src, out, mask_ptr, *head_args, use_mask, *tail_args)
        elif reduce == "prod":
            # The product kernel is launched with one program per output
            # element rather than per chunk of index.
            total_out = out.numel()
            scatter_reduce_prod_scan_kernel[(total_out,)](
                index,
                src,
                out,
                mask_ptr,
                total_out,
                padded_dim,
                use_mask,
                src_shape_dim,
                *src_strides_p,
                *idx_shapes,
                *src_shapes,
                *idx_strides_p,
                *out_shapes,
                *out_strides_p,
            )
        elif reduce == "mean":
            kernel = (
                scatter_reduce_mean_2d_kernel if use_2d else scatter_reduce_mean_kernel
            )
            kernel[grid](
                index, src, out, count, mask_ptr, *head_args, use_mask, *tail_args
            )
            # Divide the accumulated sums by the contribution counts; positions
            # with no contribution fall back to the original input value.
            # NOTE(review): for integer inp the final cast truncates toward
            # zero -- confirm this matches the reference rounding for "mean".
            has_contributions = count > 0
            count = torch.clamp(count, min=1.0)
            out = out / count
            out = torch.where(has_contributions, out, inp_f32)
        elif reduce in ("amax", "amin"):
            # presumably True on vendors without native atomic max/min --
            # verify against _needs_cas_fallback
            use_cas = _needs_cas_fallback()
            # A single kernel serves both modes; the boolean selects amax.
            kernel = (
                scatter_reduce_amax_2d_kernel if use_2d else scatter_reduce_amax_kernel
            )
            kernel[grid](
                index,
                src,
                out,
                mask_ptr,
                *head_args,
                reduce == "amax",
                use_mask,
                use_cas,
                *tail_args,
            )

        if use_mask and reduce != "mean":
            # Positions nothing was scattered to still hold the identity
            # value; put the original input back there.
            unreduced = reduced_mask == 0
            out = torch.where(unreduced, inp_f32, out)

    return out.to(inp.dtype)

1380 

1381 

def scatter_reduce_(inp, dim, index, src, reduce, *, include_self=True):
    """In-place scatter_reduce: compute the result, then overwrite inp with it.

    Takes the same arguments as scatter_reduce and returns inp itself after
    the reduced values have been copied into it.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_")

    # The reduction runs out of place; only the final copy touches inp.
    inp.copy_(
        scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    )
    return inp

1389 

1390 

def scatter_reduce_out(inp, dim, index, src, reduce, *, include_self=True, out=None):
    """Out-variant of scatter_reduce.

    Computes the reduction out of place. When out is given, the result is
    copied into it and out is returned; otherwise the freshly computed
    tensor is returned.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_OUT")

    computed = scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    if out is None:
        return computed
    out.copy_(computed)
    return out