Coverage for src/flag_gems/ops/scatter_reduce.py: 34%
386 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
"""Triton implementation of torch.scatter_reduce for FlagGems.

Supports all reduce modes: sum, prod, mean, amax, amin.
Handles 1D-5D tensors: shapes/strides are left-padded to 5 entries so a single
5D coordinate decoder serves every rank.

Vendor compatibility:
  - NVIDIA: native atomic_max/min for amax/amin reduce
  - Iluvatar: CAS-based fallback for atomic_max/min (no native support)
  - Metax and Iluvatar: larger BLOCK=256 for better occupancy

Performance notes:
  - Sum/mean use tl.atomic_add with relaxed semantics for throughput
  - Prod uses a CAS loop (there is no tl.atomic_mul)
  - Per-element offsets are computed in int64 to avoid overflow for N > 2^31
  - LOOP=4: each program processes LOOP*BLOCK elements to amortize launch overhead
  - 2D fast path: specialized kernels for 2D tensors avoid 5D coordinate decoding
"""
19import logging
21import torch
22import triton
23import triton.language as tl
25import flag_gems
26from flag_gems.runtime import torch_device_fn
27from flag_gems.utils import libentry
# Module-scoped logger; entry points emit a debug line per call.
logger = logging.getLogger(name=__name__)
def heur_block(args):
    """Choose BLOCK (elements per inner iteration) from the GPU vendor.

    Metax and Iluvatar get 256-element blocks. Every other vendor, including
    NVIDIA, gets 128. ``args`` is the Triton heuristics argument dict and is
    not inspected.
    """
    wide_block_vendors = ("metax", "iluvatar")
    if flag_gems.vendor_name not in wide_block_vendors:
        return 128
    return 256
def heur_loop(args):
    """Return LOOP, the number of BLOCK-sized chunks each program handles.

    The factor is fixed at 4 for every launch (tuned on Iluvatar BI-V150) so
    that each program covers LOOP*BLOCK elements. ``args`` is ignored.
    """
    loop_factor = 4
    return loop_factor
53# ---------------------------------------------------------------------------
54# Helpers
55# ---------------------------------------------------------------------------
58def _pad5(lst, fill):
59 """Pad a list to exactly 5 elements from the left with `fill`.
61 This enables uniform 5D coordinate decoding in kernels regardless
62 of the actual tensor dimensionality (1D-5D). Shapes are padded with 1,
63 strides with 0.
64 """
65 return [fill] * (5 - len(lst)) + lst if len(lst) < 5 else lst
def _needs_cas_fallback():
    """Report whether amax/amin must use a CAS loop instead of tl.atomic_max/min.

    Only the "iluvatar" vendor takes the Compare-And-Swap path, because it
    lacks native float atomic max/min.
    """
    return flag_gems.vendor_name == "iluvatar"
77# ---------------------------------------------------------------------------
78# 2D Fast Path Kernels with LOOP
79# Specialized for 2D tensors to avoid 5D coordinate decoding overhead.
80# Uses 1D grid with LOOP=4 to amortize kernel launch overhead.
81# ---------------------------------------------------------------------------
# 2D fast-path scatter kernel for reduce="sum".
#
# Each program handles LOOP * BLOCK consecutive flat positions in [0, N).
# A flat position is decoded as (row, col) with src_ncols columns, and the
# same flat offset is used to read both index and src. NOTE(review): index
# and src must therefore be contiguous with identical shape; the caller has
# to guarantee this. out is addressed as a contiguous row-major matrix with
# out_ncols columns; the coordinate along DIM is replaced by the index value.
# When USE_MASK is set, mask_ptr[target] counts how many elements hit it.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # Promote the program id first: pid * BLOCK * LOOP would overflow int32
    # once N >= 2^31 if computed before the int64 cast.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols
        # row * src_ncols + col == offsets for a contiguous src/index.
        src_offsets = offsets

        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
# 2D fast-path scatter kernel for reduce="prod".
#
# There is no tl.atomic_mul, so each lane multiplies through a CAS retry loop.
# The CAS runs on the int32 bit pattern of the float32 output, so the success
# check is an exact bit comparison. A float == check would never succeed once
# the slot holds NaN (infinite spin) and would confuse +0.0 with -0.0. NaN
# products are written as-is, which matches torch's NaN propagation.
#
# Addressing contract is the same as the other 2D kernels: index and src are
# read with one flat offset, so they must be contiguous with identical shape;
# out is contiguous with out_ncols columns.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)
    # Integer view of the float32 output used for the bitwise CAS.
    out_bits_ptr = out_ptr.to(tl.pointer_type(tl.int32))

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols
        src_offsets = offsets

        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx
        # tl.atomic_cas takes no mask. Send inactive lanes to element 0, which
        # is always in bounds, so they never touch memory past the output.
        out_offsets = tl.where(mask, out_offsets, 0)

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # Inactive lanes start as "done"; done lanes CAS old bits -> old bits
        # (a no-op) while the rest of the block keeps retrying.
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_bits_ptr + out_offsets)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            prod_bits = (cur_val * src_val).to(tl.int32, bitcast=True)
            new_bits = tl.where(stop, cur_bits, prod_bits)
            cas_res = tl.atomic_cas(
                out_bits_ptr + out_offsets, cur_bits, new_bits, sem="relaxed"
            )
            stop |= cas_res == cur_bits
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
# 2D fast-path scatter kernel for reduce="mean".
#
# Adds every src value into out[target] and 1.0 into count_ptr[target]; the
# host divides out by count afterwards. Addressing contract as in the other
# 2D kernels: index and src are read with the same flat offset (contiguous,
# identical shape required); out is contiguous with out_ncols columns.
# When USE_MASK is set, mask_ptr[target] counts hits as int32.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols
        src_offsets = offsets

        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")
# 2D fast-path scatter kernel for reduce="amax" (IS_AMAX) or "amin".
#
# With USE_CAS, max/min is applied through a CAS retry loop on the int32 bit
# pattern of the float32 output. The exact bit comparison terminates even when
# the slot already holds NaN; a float == check would spin forever there.
# Lanes that are done or inactive issue a no-op CAS (old bits -> old bits).
# Otherwise native tl.atomic_max / tl.atomic_min are used.
#
# Addressing contract as in the other 2D kernels: index and src share one flat
# offset (contiguous, identical shape required); out is contiguous with
# out_ncols columns.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)
    # Integer view of the float32 output, used only by the CAS path.
    out_bits_ptr = out_ptr.to(tl.pointer_type(tl.int32))

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        row = offsets // src_ncols
        col = offsets % src_ncols
        src_offsets = offsets

        idx = tl.load(index_ptr + src_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx
        # tl.atomic_cas has no mask: keep inactive lanes on an in-bounds slot.
        out_offsets = tl.where(mask, out_offsets, 0)

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_bits = tl.load(out_bits_ptr + out_offsets)
                cur_val = cur_bits.to(tl.float32, bitcast=True)
                if IS_AMAX:
                    red_val = tl.maximum(cur_val, src_val)
                else:
                    red_val = tl.minimum(cur_val, src_val)
                new_bits = tl.where(
                    stop, cur_bits, red_val.to(tl.int32, bitcast=True)
                )
                cas_res = tl.atomic_cas(
                    out_bits_ptr + out_offsets, cur_bits, new_bits, sem="relaxed"
                )
                stop |= cas_res == cur_bits
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
294# ---------------------------------------------------------------------------
295# Generic 5D Kernels with LOOP optimization
296# For tensors with ndim != 2.
297# ---------------------------------------------------------------------------
# Generic (1D-5D, left-padded to 5D) scatter kernel for reduce="sum".
#
# Each flat position in [0, N) is decoded row-major into 5D coordinates using
# src_shape_1..src_shape_4 (src_shape_0 is never read). The coordinates
# address index via idx_stride_* and src via src_stride_*. In out, the
# coordinate along DIM is replaced by the loaded index value.
# out_stride_dim / src_stride_dim / src_shape_dim / out_shape_dim are accepted
# for interface stability but not read here.
# When USE_MASK is set, mask_ptr[target] counts hits as int32.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        # Row-major decode by successive div/mod on the int64 offset. This
        # avoids forming src_shape_1*...*src_shape_4 as an int32 product,
        # which can overflow for very large tensors.
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
# Generic (1D-5D, left-padded to 5D) scatter kernel for reduce="prod".
#
# Coordinate decoding and addressing are the same as in the generic sum
# kernel. There is no tl.atomic_mul, so the product is applied through a CAS
# retry loop on the int32 bit pattern of the float32 output.
#  - The exact bit comparison terminates when the slot holds NaN; a float ==
#    check spins forever there.
#  - NaN products are stored as-is, which matches torch's NaN propagation.
#  - Inactive lanes are redirected to element 0 because tl.atomic_cas takes
#    no mask. Done lanes issue a no-op CAS.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)
    # Integer view of the float32 output used for the bitwise CAS.
    out_bits_ptr = out_ptr.to(tl.pointer_type(tl.int32))

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        # Row-major decode by successive div/mod (no int32 shape product).
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )
        # Lanes past N can decode to coordinates outside out; tl.atomic_cas
        # has no mask, so pin them to the always-valid element 0.
        out_offsets = tl.where(mask, out_offsets, 0)

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_bits_ptr + out_offsets)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            prod_bits = (cur_val * src_val).to(tl.int32, bitcast=True)
            new_bits = tl.where(stop, cur_bits, prod_bits)
            cas_res = tl.atomic_cas(
                out_bits_ptr + out_offsets, cur_bits, new_bits, sem="relaxed"
            )
            stop |= cas_res == cur_bits
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
# Generic (1D-5D, left-padded to 5D) scatter kernel for reduce="mean".
#
# Coordinate decoding and addressing are the same as in the generic sum
# kernel. Each src value is added into out[target] and 1.0 into
# count_ptr[target]; the host performs the division afterwards.
# When USE_MASK is set, mask_ptr[target] counts hits as int32.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        # Row-major decode by successive div/mod (no int32 shape product).
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")
# Generic (1D-5D, left-padded to 5D) scatter kernel for reduce="amax"
# (IS_AMAX) or "amin".
#
# Coordinate decoding and addressing are the same as in the generic sum
# kernel.
#  - With USE_CAS, max/min is applied through a CAS retry loop on the int32
#    bit pattern of the float32 output. The exact bit comparison terminates
#    even when the slot already holds NaN.
#  - Done or inactive lanes issue a no-op CAS, and inactive lanes are pinned
#    to element 0 because tl.atomic_cas takes no mask.
#  - Without USE_CAS, native tl.atomic_max / tl.atomic_min are used.
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    # int64 program id: avoids int32 overflow of pid * BLOCK * LOOP.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)
    # Integer view of the float32 output, used only by the CAS path.
    out_bits_ptr = out_ptr.to(tl.pointer_type(tl.int32))

    for i in range(LOOP):
        offsets = (base_offsets + i * BLOCK).to(tl.int64)
        mask = offsets < N

        # Row-major decode by successive div/mod (no int32 shape product).
        remaining = offsets
        coord4 = remaining % src_shape_4
        remaining = remaining // src_shape_4
        coord3 = remaining % src_shape_3
        remaining = remaining // src_shape_3
        coord2 = remaining % src_shape_2
        remaining = remaining // src_shape_2
        coord1 = remaining % src_shape_1
        coord0 = remaining // src_shape_1

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )
        # Lanes past N can decode outside out; keep them on element 0.
        out_offsets = tl.where(mask, out_offsets, 0)

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_bits = tl.load(out_bits_ptr + out_offsets)
                cur_val = cur_bits.to(tl.float32, bitcast=True)
                if IS_AMAX:
                    red_val = tl.maximum(cur_val, src_val)
                else:
                    red_val = tl.minimum(cur_val, src_val)
                new_bits = tl.where(
                    stop, cur_bits, red_val.to(tl.int32, bitcast=True)
                )
                cas_res = tl.atomic_cas(
                    out_bits_ptr + out_offsets, cur_bits, new_bits, sem="relaxed"
                )
                stop |= cas_res == cur_bits
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
825# ---------------------------------------------------------------------------
826# Python entry points
827# ---------------------------------------------------------------------------
def _init_accumulator(inp_f32, reduce, include_self):
    """Return a fresh float32 accumulator shaped like ``inp_f32``.

    With ``include_self`` it starts from the input values. Otherwise it starts
    from the identity element of ``reduce``, so only scattered ``src`` values
    contribute.
    """
    if include_self:
        return inp_f32.clone()
    identity = {
        "sum": 0.0,
        "mean": 0.0,
        "prod": 1.0,
        "amax": float("-inf"),
        "amin": float("inf"),
    }[reduce]
    return torch.full_like(inp_f32, identity)


def _generic_layout_args(src, index, out):
    """Flatten padded 5D layouts in the order the generic kernels expect.

    Returns 20 Python ints: src strides (5), src shape (5), index strides (5),
    out strides (5). Shapes are left-padded with 1 and strides with 0.
    """
    src_shapes = [int(x) for x in _pad5(list(src.shape), 1)]
    src_strides = [int(x) for x in _pad5(list(src.stride()), 0)]
    idx_strides = [int(x) for x in _pad5(list(index.stride()), 0)]
    out_strides = [int(x) for x in _pad5(list(out.stride()), 0)]
    return src_strides + src_shapes + idx_strides + out_strides


def scatter_reduce(inp, dim, index, src, reduce, *, include_self=True):
    """Triton-accelerated scatter_reduce operation (out-of-place).

    Scatters ``src`` values into a copy of ``inp`` at the positions given by
    ``index`` along ``dim``, combining collisions with ``reduce``.
    Accumulation happens in float32 and the result is cast back to
    ``inp.dtype``.

    Args:
        inp: Input tensor (up to 5D).
        dim: Dimension along which to scatter; negative values wrap.
        index: Index tensor mapping source elements to output positions. Only
            the ``src`` elements inside ``index.shape`` are read.
        src: Source tensor containing values to scatter.
        reduce: Reduction mode - "sum", "prod", "mean", "amax", or "amin".
        include_self: If True, ``inp`` values take part in the reduction.
            If False, positions hit by ``index`` are reduced over ``src``
            values only, and positions never hit keep their ``inp`` value.

    Returns:
        New tensor with the same shape and dtype as ``inp``. For non-float
        inputs with reduce="mean" the division is floored, as in torch.

    Raises:
        AssertionError: on an unsupported ``reduce``, more than 5 dims, or an
            ``index`` whose rank or size exceeds ``src``'s.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO")

    assert reduce in (
        "sum",
        "prod",
        "mean",
        "amax",
        "amin",
    ), f"Unsupported reduce: {reduce}"
    assert inp.ndim <= 5, f"scatter_reduce supports up to 5D tensors, got {inp.ndim}D"

    dim = dim % inp.ndim
    padded_dim = dim + (5 - inp.ndim)

    # Forwarded to the generic kernels for interface stability; not read there.
    out_stride_dim = inp.stride(dim)
    out_shape_dim = inp.size(dim)
    src_stride_dim = src.stride(dim)
    src_shape_dim = src.size(dim)
    N = index.numel()

    if N == 0:
        # Nothing is scattered, so every position keeps its input value for
        # both include_self settings. Clone instead of round-tripping through
        # float32 so no precision is lost.
        return inp.clone()

    # Kernels read index and src with coordinates decoded from src's shape.
    # torch allows index.size(d) <= src.size(d), and only the leading
    # index.shape sub-block of src is ever used, so narrow src to it.
    if tuple(index.shape) != tuple(src.shape):
        assert index.ndim == src.ndim and all(
            i <= s for i, s in zip(index.shape, src.shape)
        ), f"index shape {tuple(index.shape)} must fit inside src shape {tuple(src.shape)}"
        src = src[tuple(slice(0, size) for size in index.shape)]
    src = src.contiguous()
    index = index.contiguous()

    # A single conversion yields a contiguous float32 view of the input.
    inp_f32 = inp.to(torch.float32).contiguous()
    out = _init_accumulator(inp_f32, reduce, include_self)

    # With include_self=False, mask_buf counts hits per output element so that
    # untouched positions can be restored from inp afterwards.
    use_mask = not include_self
    if use_mask:
        mask_buf = torch.zeros(out.shape, dtype=torch.int32, device=inp.device)
    else:
        mask_buf = torch.empty(1, dtype=torch.int32, device=inp.device)

    if reduce == "mean":
        if include_self:
            count = torch.ones_like(out, dtype=torch.float32)
        else:
            count = torch.zeros_like(out, dtype=torch.float32)

    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)

    # 2D fast path (most common case) avoids 5D coordinate decoding and takes
    # the raw dim (0 or 1); the generic kernels take the padded dim.
    use_2d = inp.ndim == 2
    if use_2d:
        src_ncols = src.shape[1]
        out_ncols = out.shape[1]
    else:
        layout = _generic_layout_args(src, index, out)

    with torch_device_fn.device(inp.device):
        if reduce == "sum":
            if use_2d:
                scatter_reduce_sum_2d_kernel[grid](
                    index, src, out, mask_buf, N, src_ncols, out_ncols,
                    dim, use_mask,
                )
            else:
                scatter_reduce_sum_kernel[grid](
                    index, src, out, mask_buf, N,
                    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim,
                    padded_dim, use_mask, *layout,
                )
        elif reduce == "prod":
            if use_2d:
                scatter_reduce_prod_2d_kernel[grid](
                    index, src, out, mask_buf, N, src_ncols, out_ncols,
                    dim, use_mask,
                )
            else:
                scatter_reduce_prod_kernel[grid](
                    index, src, out, mask_buf, N,
                    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim,
                    padded_dim, use_mask, *layout,
                )
        elif reduce == "mean":
            if use_2d:
                scatter_reduce_mean_2d_kernel[grid](
                    index, src, out, count, mask_buf, N, src_ncols, out_ncols,
                    dim, use_mask,
                )
            else:
                scatter_reduce_mean_kernel[grid](
                    index, src, out, count, mask_buf, N,
                    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim,
                    padded_dim, use_mask, *layout,
                )
            has_contributions = count > 0
            count = torch.clamp(count, min=1.0)
            if inp.is_floating_point():
                out = out / count
            else:
                # torch floors the mean for integral dtypes; a plain divide
                # followed by .to(int) would truncate toward zero instead.
                out = torch.div(out, count, rounding_mode="floor")
            out = torch.where(has_contributions, out, inp_f32)
        else:  # "amax" / "amin"
            use_cas = _needs_cas_fallback()
            is_amax = reduce == "amax"
            if use_2d:
                scatter_reduce_amax_2d_kernel[grid](
                    index, src, out, mask_buf, N, src_ncols, out_ncols,
                    dim, is_amax, use_mask, use_cas,
                )
            else:
                scatter_reduce_amax_kernel[grid](
                    index, src, out, mask_buf, N,
                    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim,
                    padded_dim, is_amax, use_mask, use_cas, *layout,
                )

    # Mean already restored untouched positions via has_contributions.
    if use_mask and reduce != "mean":
        out = torch.where(mask_buf == 0, inp_f32, out)

    return out.to(inp.dtype)
def scatter_reduce_(inp, dim, index, src, reduce, *, include_self=True):
    """In-place scatter_reduce: compute the reduction, then overwrite ``inp``.

    Returns:
        ``inp`` itself, after its contents have been replaced.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_")
    # Tensor.copy_ returns its receiver, so this yields inp.
    return inp.copy_(
        scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    )
def scatter_reduce_out(inp, dim, index, src, reduce, *, include_self=True, out=None):
    """Out-variant of scatter_reduce.

    Computes the reduction out-of-place. When ``out`` is given, the result is
    copied into it and ``out`` is returned; otherwise the freshly computed
    tensor is returned.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_OUT")
    reduced = scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    if out is None:
        return reduced
    return out.copy_(reduced)