Coverage for src/flag_gems/runtime/backend/_sunrise/ops/scatter_reduce.py: 0%
470 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
"""Triton implementation of torch.scatter_reduce for FlagGems.

Supports all reduce modes: sum, prod, mean, amax, amin.
Handles 1D-5D tensors by left-padding shapes and strides to 5D, so the generic
kernels can decode coordinates uniformly.

Vendor-dependent behavior:
  - Iluvatar: CAS-based fallback for amax/amin (no native tl.atomic_max/min);
    other vendors (e.g. NVIDIA) use the native atomics.
  - Metax and Iluvatar: larger BLOCK=256 for better occupancy; other vendors
    use BLOCK=128.
  - Sunrise/PTPU: the product CAS loops operate on the raw float32 bit pattern,
    which is more stable than a floating-point CAS on this hardware.

Performance notes:
  - Sum/mean use tl.atomic_add with relaxed semantics for throughput.
  - Prod is dispatched to a deterministic scan kernel (one program per output
    element). CAS-loop prod kernels are also defined, because no tl.atomic_mul
    exists.
  - Element offsets are widened to int64 before address arithmetic, with the
    intent of supporting N > 2^31.
  - LOOP=4: each program processes LOOP*BLOCK elements to amortize launch overhead.
  - 2D fast path: specialized kernels for 2D tensors avoid 5D coordinate decoding.
"""
19import logging
21import torch
22import triton
23import triton.language as tl
25import flag_gems
26from flag_gems.runtime import torch_device_fn
27from flag_gems.utils import libentry
# Module-level logger; scatter_reduce() reports its invocation through it at DEBUG level.
logger = logging.getLogger(__name__)
def heur_block(args):
    """Choose the per-iteration tile size ``BLOCK`` for the scatter kernels.

    ``args`` is the Triton heuristics argument dict and is not consulted.
    Metax and Iluvatar devices get 256 (better occupancy on those parts);
    every other vendor gets 128, balancing occupancy against register
    pressure.
    """
    wide_block_vendors = ("metax", "iluvatar")
    return 256 if flag_gems.vendor_name in wide_block_vendors else 128
def heur_loop(args):
    """Return the unroll factor ``LOOP`` used by the element-wise kernels.

    Every program walks ``LOOP * BLOCK`` elements, amortizing launch
    overhead; 4 was found optimal on Iluvatar BI-V150. ``args`` is ignored.
    """
    unroll_factor = 4
    return unroll_factor
def heur_scan_block(args):
    """Return the tile length along the scatter dimension for the prod scan kernel.

    ``args`` (the Triton heuristics dict) is ignored; the tile is always 128.
    """
    scan_tile = 128
    return scan_tile
58# ---------------------------------------------------------------------------
59# Helpers
60# ---------------------------------------------------------------------------
63def _pad5(lst, fill):
64 """Pad a list to exactly 5 elements from the left with `fill`.
66 This enables uniform 5D coordinate decoding in kernels regardless
67 of the actual tensor dimensionality (1D-5D). Shapes are padded with 1,
68 strides with 0.
69 """
70 return [fill] * (5 - len(lst)) + lst if len(lst) < 5 else lst
def _needs_cas_fallback():
    """Report whether amax/amin must use a Compare-And-Swap loop.

    Iluvatar GPUs have no native ``tl.atomic_max``/``tl.atomic_min``, so on
    that vendor the kernels emulate them with a CAS loop. Every other vendor
    uses the native atomics.
    """
    cas_only_vendors = ("iluvatar",)
    return flag_gems.vendor_name in cas_only_vendors
@libentry()
@triton.heuristics({"BLOCK": heur_scan_block})
@triton.jit(do_not_specialize=["out_numel"])
def scatter_reduce_prod_scan_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    out_numel,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    # NOTE(review): declared constexpr, so a new kernel is compiled for every
    # distinct size along DIM — confirm this specialization is intended.
    src_shape_dim: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    idx_shape_0,
    idx_shape_1,
    idx_shape_2,
    idx_shape_3,
    idx_shape_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_shape_0,
    out_shape_1,
    out_shape_2,
    out_shape_3,
    out_shape_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
):
    """Deterministic product scatter-reduce: one program per output element.

    The program id is decoded row-major over the padded 5D output shape.
    The program then scans the source/index line along DIM that passes
    through that output coordinate, in tiles of BLOCK. Every source value
    whose index equals the output's coordinate along DIM is multiplied into
    an accumulator, which is seeded with the current output value.

    Because no atomics are used, the result is independent of scheduling
    order. Each program reads the whole src extent along DIM.

    If USE_MASK, mask_ptr receives 1 where at least one source element
    contributed and 0 otherwise. It is addressed with the output offsets, so
    it must share out's layout.
    """
    pid = tl.program_id(axis=0).to(tl.int64)
    in_bounds = pid < out_numel

    # Row-major decode of the flat output position into padded 5D coordinates.
    remaining = pid
    coord0 = remaining // (out_shape_1 * out_shape_2 * out_shape_3 * out_shape_4)
    remaining = remaining % (out_shape_1 * out_shape_2 * out_shape_3 * out_shape_4)
    coord1 = remaining // (out_shape_2 * out_shape_3 * out_shape_4)
    remaining = remaining % (out_shape_2 * out_shape_3 * out_shape_4)
    coord2 = remaining // (out_shape_3 * out_shape_4)
    remaining = remaining % (out_shape_3 * out_shape_4)
    coord3 = remaining // out_shape_4
    coord4 = remaining % out_shape_4

    out_offset = (
        coord0 * out_stride_0
        + coord1 * out_stride_1
        + coord2 * out_stride_2
        + coord3 * out_stride_3
        + coord4 * out_stride_4
    )
    idx_full_offset = (
        coord0 * idx_stride_0
        + coord1 * idx_stride_1
        + coord2 * idx_stride_2
        + coord3 * idx_stride_3
        + coord4 * idx_stride_4
    )
    src_full_offset = (
        coord0 * src_stride_0
        + coord1 * src_stride_1
        + coord2 * src_stride_2
        + coord3 * src_stride_3
        + coord4 * src_stride_4
    )

    # For the scatter dimension DIM:
    #   target          - the output coordinate that index values must match;
    #   *_base          - the start of the index/src line along DIM (DIM
    #                     coordinate removed);
    #   *_scan_stride   - the step along that line;
    #   valid_other     - whether the non-DIM coordinates exist in both index
    #                     and src (output positions outside them get no
    #                     contributions).
    if DIM == 0:
        target = coord0
        idx_base = idx_full_offset - coord0 * idx_stride_0
        src_base = src_full_offset - coord0 * src_stride_0
        idx_scan_stride = idx_stride_0
        src_scan_stride = src_stride_0
        idx_scan_shape = idx_shape_0
        valid_other = (
            (coord1 < idx_shape_1)
            & (coord2 < idx_shape_2)
            & (coord3 < idx_shape_3)
            & (coord4 < idx_shape_4)
            & (coord1 < src_shape_1)
            & (coord2 < src_shape_2)
            & (coord3 < src_shape_3)
            & (coord4 < src_shape_4)
        )
    elif DIM == 1:
        target = coord1
        idx_base = idx_full_offset - coord1 * idx_stride_1
        src_base = src_full_offset - coord1 * src_stride_1
        idx_scan_stride = idx_stride_1
        src_scan_stride = src_stride_1
        idx_scan_shape = idx_shape_1
        valid_other = (
            (coord0 < idx_shape_0)
            & (coord2 < idx_shape_2)
            & (coord3 < idx_shape_3)
            & (coord4 < idx_shape_4)
            & (coord0 < src_shape_0)
            & (coord2 < src_shape_2)
            & (coord3 < src_shape_3)
            & (coord4 < src_shape_4)
        )
    elif DIM == 2:
        target = coord2
        idx_base = idx_full_offset - coord2 * idx_stride_2
        src_base = src_full_offset - coord2 * src_stride_2
        idx_scan_stride = idx_stride_2
        src_scan_stride = src_stride_2
        idx_scan_shape = idx_shape_2
        valid_other = (
            (coord0 < idx_shape_0)
            & (coord1 < idx_shape_1)
            & (coord3 < idx_shape_3)
            & (coord4 < idx_shape_4)
            & (coord0 < src_shape_0)
            & (coord1 < src_shape_1)
            & (coord3 < src_shape_3)
            & (coord4 < src_shape_4)
        )
    elif DIM == 3:
        target = coord3
        idx_base = idx_full_offset - coord3 * idx_stride_3
        src_base = src_full_offset - coord3 * src_stride_3
        idx_scan_stride = idx_stride_3
        src_scan_stride = src_stride_3
        idx_scan_shape = idx_shape_3
        valid_other = (
            (coord0 < idx_shape_0)
            & (coord1 < idx_shape_1)
            & (coord2 < idx_shape_2)
            & (coord4 < idx_shape_4)
            & (coord0 < src_shape_0)
            & (coord1 < src_shape_1)
            & (coord2 < src_shape_2)
            & (coord4 < src_shape_4)
        )
    else:
        target = coord4
        idx_base = idx_full_offset - coord4 * idx_stride_4
        src_base = src_full_offset - coord4 * src_stride_4
        idx_scan_stride = idx_stride_4
        src_scan_stride = src_stride_4
        idx_scan_shape = idx_shape_4
        valid_other = (
            (coord0 < idx_shape_0)
            & (coord1 < idx_shape_1)
            & (coord2 < idx_shape_2)
            & (coord3 < idx_shape_3)
            & (coord0 < src_shape_0)
            & (coord1 < src_shape_1)
            & (coord2 < src_shape_2)
            & (coord3 < src_shape_3)
        )

    lanes = tl.arange(0, BLOCK)
    # Seed with the current output value (the caller decides whether that is
    # inp or the multiplicative identity).
    acc = tl.load(out_ptr + out_offset, mask=in_bounds, other=1.0).to(tl.float32)
    has_contrib = False

    for start in range(0, src_shape_dim, BLOCK):
        scan = start + lanes
        valid = (
            in_bounds & valid_other & (scan < src_shape_dim) & (scan < idx_scan_shape)
        )
        # other=-1 never equals a valid target, so padded lanes never match.
        idx_val = tl.load(
            index_ptr + idx_base + scan * idx_scan_stride,
            mask=valid,
            other=-1,
        ).to(tl.int64)
        match = valid & (idx_val == target)
        src_val = tl.load(
            src_ptr + src_base + scan * src_scan_stride,
            mask=valid,
            other=1.0,
        ).to(tl.float32)
        # Non-matching lanes contribute the identity 1.0; the last lane of the
        # inclusive cumprod holds the product of the whole tile, extracted
        # here with a masked sum.
        factors = tl.where(match, src_val, 1.0)
        prefix = tl.cumprod(factors, 0)
        tile_prod = tl.sum(tl.where(lanes == (BLOCK - 1), prefix, 0.0))
        acc *= tile_prod
        has_contrib |= tl.sum(match.to(tl.int32)) > 0

    tl.store(out_ptr + out_offset, acc, mask=in_bounds)
    if USE_MASK:
        tl.store(mask_ptr + out_offset, has_contrib.to(tl.int32), mask=in_bounds)
278# ---------------------------------------------------------------------------
279# 2D Fast Path Kernels with LOOP
280# Specialized for 2D tensors to avoid 5D coordinate decoding overhead.
281# Uses 1D grid with LOOP=4 to amortize kernel launch overhead.
282# ---------------------------------------------------------------------------
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """2D fast path for reduce="sum": atomically add src into out.

    Each program handles LOOP * BLOCK consecutive flat positions of the
    row-major index matrix (N elements, idx_ncols columns). index, src and
    out are addressed as contiguous row-major matrices with idx_ncols,
    src_ncols and out_ncols columns respectively.

    DIM == 0 scatters along rows (out[idx, col]); otherwise along columns
    (out[row, idx]). When USE_MASK is set, mask_ptr (same layout as out)
    counts the contributions per output element.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # index may have fewer columns than src, so each tensor is addressed
        # with its own row pitch (the index offset is the flat position).
        idx_offsets = row * idx_ncols + col
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """2D fast path for reduce="prod": multiply src into out via a CAS loop.

    There is no tl.atomic_mul, so each lane repeatedly reads the float32
    output bits, multiplies, and publishes the result with a 32-bit
    compare-and-swap until the swap succeeds. Comparing raw bits (not float
    values) means the loop terminates even when the slot holds NaN, and NaN
    products propagate just as in torch.

    Layout and DIM semantics match scatter_reduce_sum_2d_kernel. When
    USE_MASK is set, mask_ptr (same layout as out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # index may have fewer columns than src: use each tensor's own pitch.
        idx_offsets = row * idx_ncols + col
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas takes no mask, so inactive lanes are redirected to
        # element 0 instead of a possibly out-of-bounds address. Their CAS
        # compares against and writes the same bit pattern, so it can never
        # modify out[0].
        cas_offsets = tl.where(mask, out_offsets, 0)
        out_ptr_u32 = (out_ptr + cas_offsets).to(
            tl.pointer_type(tl.uint32, 1), bitcast=True
        )
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            prod_bits = (cur_val * src_val).to(tl.uint32, bitcast=True)
            # Lanes that already succeeded re-publish the bits they just read.
            new_bits = tl.where(stop, cur_bits, prod_bits)
            # Sunrise/PTPU is more stable when product CAS operates on the raw
            # float32 bit pattern instead of a floating-point CAS.
            cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="acq_rel")
            stop |= cur_bits == cas_res
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """2D fast path for reduce="mean": accumulate sums and counts.

    Atomically adds each src value into out and 1.0 into count_ptr (float32,
    same layout as out) at the scattered position. The final division is
    presumably done by the caller — not visible in this kernel.

    Layout and DIM semantics match scatter_reduce_sum_2d_kernel. When
    USE_MASK is set, mask_ptr (same layout as out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # index may have fewer columns than src: use each tensor's own pitch.
        idx_offsets = row * idx_ncols + col
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_2d_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    idx_ncols,
    src_ncols,
    out_ncols,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """2D fast path for reduce="amax" (IS_AMAX) or reduce="amin".

    With USE_CAS the max/min is emulated by a 32-bit compare-and-swap loop
    on the raw float32 bits. Success is detected by comparing bits, so a NaN
    already in the slot cannot make the loop spin forever, and a +0.0/-0.0
    race cannot be mistaken for success. Without USE_CAS the native
    tl.atomic_max / tl.atomic_min are used.

    Layout and DIM semantics match scatter_reduce_sum_2d_kernel. When
    USE_MASK is set, mask_ptr (same layout as out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        row = offsets // idx_ncols
        col = offsets % idx_ncols

        # index may have fewer columns than src: use each tensor's own pitch.
        idx_offsets = row * idx_ncols + col
        src_offsets = row * src_ncols + col
        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)
        if DIM == 0:
            out_offsets = idx * out_ncols + col
        else:
            out_offsets = row * out_ncols + idx

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas takes no mask: redirect inactive lanes to element 0
            # instead of a possibly out-of-bounds address. Their CAS compares
            # against and writes the same bits, so it can never modify out[0].
            cas_offsets = tl.where(mask, out_offsets, 0)
            out_ptr_u32 = (out_ptr + cas_offsets).to(
                tl.pointer_type(tl.uint32, 1), bitcast=True
            )
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
                cur_val = cur_bits.to(tl.float32, bitcast=True)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                # Lanes that already succeeded re-publish the bits they read.
                new_bits = tl.where(stop, cur_bits, new_val.to(tl.uint32, bitcast=True))
                cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="relaxed")
                stop |= cur_bits == cas_res
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
504# ---------------------------------------------------------------------------
505# Generic 5D Kernels with LOOP optimization
506# For tensors with ndim != 2.
507# ---------------------------------------------------------------------------
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_sum_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    idx_shape_0,
    idx_shape_1,
    idx_shape_2,
    idx_shape_3,
    idx_shape_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """Generic (padded-5D) reduce="sum": atomically add src into out.

    Each program handles LOOP * BLOCK consecutive flat positions of index
    (N elements). Each position is decoded row-major over the padded index
    shape into 5D coordinates. index and src are read at those coordinates
    through their strides. out is written at the same coordinates, except
    along DIM, where the loaded index value is used.

    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and
    src_shape_0..4 are accepted for interface compatibility but not read.
    When USE_MASK is set, mask_ptr (addressed like out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Row-major decode over the padded index shape.
        remaining = offsets
        coord0 = remaining // (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord1 = remaining // (idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord2 = remaining // (idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_3 * idx_shape_4)
        coord3 = remaining // idx_shape_4
        coord4 = remaining % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # Same coordinates as src, with the DIM coordinate replaced by idx.
        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)
        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_prod_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    idx_shape_0,
    idx_shape_1,
    idx_shape_2,
    idx_shape_3,
    idx_shape_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """Generic (padded-5D) reduce="prod" using a compare-and-swap loop.

    Addressing is the same as in the generic sum kernel: flat index position,
    then 5D coordinates, then out coordinates with DIM replaced by the index
    value. Because there is no tl.atomic_mul, each lane reads the float32
    output bits, multiplies, and publishes the result with a 32-bit CAS until
    it succeeds. Comparing raw bits means the loop terminates even when the
    slot holds NaN, so NaN products propagate as in torch.

    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and
    src_shape_0..4 are accepted for interface compatibility but not read.
    When USE_MASK is set, mask_ptr (addressed like out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Row-major decode over the padded index shape.
        remaining = offsets
        coord0 = remaining // (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord1 = remaining // (idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord2 = remaining // (idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_3 * idx_shape_4)
        coord3 = remaining // idx_shape_4
        coord4 = remaining % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # Same coordinates as src, with the DIM coordinate replaced by idx.
        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        # tl.atomic_cas takes no mask. Inactive lanes decode to coordinates
        # past the index shape, which for an unpadded leading dim can address
        # memory outside out, so they are redirected to element 0. Their CAS
        # compares against and writes the same bits, so it never modifies out[0].
        cas_offsets = tl.where(mask, out_offsets, 0)
        out_ptr_u32 = (out_ptr + cas_offsets).to(
            tl.pointer_type(tl.uint32, 1), bitcast=True
        )
        stop = tl.where(mask, 0, 1).to(tl.int1)
        block_stop = False
        while not block_stop:
            cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
            cur_val = cur_bits.to(tl.float32, bitcast=True)
            prod_bits = (cur_val * src_val).to(tl.uint32, bitcast=True)
            # Lanes that already succeeded re-publish the bits they just read.
            new_bits = tl.where(stop, cur_bits, prod_bits)
            # Sunrise/PTPU is more stable when product CAS operates on the raw
            # float32 bit pattern instead of a floating-point CAS.
            cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="acq_rel")
            stop |= cur_bits == cas_res
            block_stop = tl.sum(stop.to(tl.int32)) == BLOCK

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_mean_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    count_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    USE_MASK: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    idx_shape_0,
    idx_shape_1,
    idx_shape_2,
    idx_shape_3,
    idx_shape_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """Generic (padded-5D) reduce="mean": accumulate sums and counts.

    Addressing is the same as in the generic sum kernel. Each src value is
    atomically added into out, and 1.0 into count_ptr (float32, addressed
    like out). The final division is presumably done by the caller — not
    visible in this kernel.

    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and
    src_shape_0..4 are accepted for interface compatibility but not read.
    When USE_MASK is set, mask_ptr (addressed like out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Row-major decode over the padded index shape.
        remaining = offsets
        coord0 = remaining // (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord1 = remaining // (idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord2 = remaining // (idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_3 * idx_shape_4)
        coord3 = remaining // idx_shape_4
        coord4 = remaining % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # Same coordinates as src, with the DIM coordinate replaced by idx.
        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        tl.atomic_add(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
        ones_f = tl.full((BLOCK,), 1.0, dtype=tl.float32)
        tl.atomic_add(count_ptr + out_offsets, ones_f, mask=mask, sem="relaxed")

        if USE_MASK:
            ones_i = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones_i, mask=mask, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": heur_block, "LOOP": heur_loop})
@triton.jit(do_not_specialize=["N"])
def scatter_reduce_amax_kernel(
    index_ptr,
    src_ptr,
    out_ptr,
    mask_ptr,
    N,
    out_stride_dim,
    src_stride_dim,
    src_shape_dim,
    out_shape_dim,
    DIM: tl.constexpr,
    IS_AMAX: tl.constexpr,
    USE_MASK: tl.constexpr,
    USE_CAS: tl.constexpr,
    src_stride_0,
    src_stride_1,
    src_stride_2,
    src_stride_3,
    src_stride_4,
    idx_shape_0,
    idx_shape_1,
    idx_shape_2,
    idx_shape_3,
    idx_shape_4,
    src_shape_0,
    src_shape_1,
    src_shape_2,
    src_shape_3,
    src_shape_4,
    idx_stride_0,
    idx_stride_1,
    idx_stride_2,
    idx_stride_3,
    idx_stride_4,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    out_stride_3,
    out_stride_4,
    BLOCK: tl.constexpr,
    LOOP: tl.constexpr,
):
    """Generic (padded-5D) reduce="amax" (IS_AMAX) or reduce="amin".

    Addressing is the same as in the generic sum kernel. With USE_CAS the
    max/min is emulated by a 32-bit compare-and-swap loop on the raw float32
    bits. Success is detected by comparing bits, so a NaN already in the slot
    cannot make the loop spin forever, and a +0.0/-0.0 race cannot be
    mistaken for success. Without USE_CAS the native tl.atomic_max /
    tl.atomic_min are used.

    out_stride_dim, src_stride_dim, src_shape_dim, out_shape_dim and
    src_shape_0..4 are accepted for interface compatibility but not read.
    When USE_MASK is set, mask_ptr (addressed like out) counts contributions.
    """
    # int64 program id: pid * BLOCK * LOOP would overflow int32 for N > 2^31.
    pid = tl.program_id(axis=0).to(tl.int64)
    base_offsets = pid * BLOCK * LOOP + tl.arange(0, BLOCK)

    for i in range(LOOP):
        offsets = base_offsets + i * BLOCK
        mask = offsets < N

        # Row-major decode over the padded index shape.
        remaining = offsets
        coord0 = remaining // (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_1 * idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord1 = remaining // (idx_shape_2 * idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_2 * idx_shape_3 * idx_shape_4)
        coord2 = remaining // (idx_shape_3 * idx_shape_4)
        remaining = remaining % (idx_shape_3 * idx_shape_4)
        coord3 = remaining // idx_shape_4
        coord4 = remaining % idx_shape_4

        idx_offsets = (
            coord0 * idx_stride_0
            + coord1 * idx_stride_1
            + coord2 * idx_stride_2
            + coord3 * idx_stride_3
            + coord4 * idx_stride_4
        )
        src_offsets = (
            coord0 * src_stride_0
            + coord1 * src_stride_1
            + coord2 * src_stride_2
            + coord3 * src_stride_3
            + coord4 * src_stride_4
        )

        idx = tl.load(index_ptr + idx_offsets, mask=mask, other=0).to(tl.int64)

        # Same coordinates as src, with the DIM coordinate replaced by idx.
        if DIM == 0:
            out_offsets = (
                idx * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 1:
            out_offsets = (
                coord0 * out_stride_0
                + idx * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 2:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + idx * out_stride_2
                + coord3 * out_stride_3
                + coord4 * out_stride_4
            )
        elif DIM == 3:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + idx * out_stride_3
                + coord4 * out_stride_4
            )
        else:
            out_offsets = (
                coord0 * out_stride_0
                + coord1 * out_stride_1
                + coord2 * out_stride_2
                + coord3 * out_stride_3
                + idx * out_stride_4
            )

        src_val = tl.load(src_ptr + src_offsets, mask=mask, other=0).to(tl.float32)

        if USE_CAS:
            # tl.atomic_cas takes no mask: redirect inactive lanes to element 0
            # instead of a possibly out-of-bounds address. Their CAS compares
            # against and writes the same bits, so it can never modify out[0].
            cas_offsets = tl.where(mask, out_offsets, 0)
            out_ptr_u32 = (out_ptr + cas_offsets).to(
                tl.pointer_type(tl.uint32, 1), bitcast=True
            )
            stop = tl.where(mask, 0, 1).to(tl.int1)
            block_stop = False
            while not block_stop:
                cur_bits = tl.load(out_ptr_u32, mask=mask, other=0)
                cur_val = cur_bits.to(tl.float32, bitcast=True)
                if IS_AMAX:
                    new_val = tl.maximum(cur_val, src_val)
                else:
                    new_val = tl.minimum(cur_val, src_val)
                # Lanes that already succeeded re-publish the bits they read.
                new_bits = tl.where(stop, cur_bits, new_val.to(tl.uint32, bitcast=True))
                cas_res = tl.atomic_cas(out_ptr_u32, cur_bits, new_bits, sem="relaxed")
                stop |= cur_bits == cas_res
                block_stop = tl.sum(stop.to(tl.int32)) == BLOCK
        else:
            if IS_AMAX:
                tl.atomic_max(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")
            else:
                tl.atomic_min(out_ptr + out_offsets, src_val, mask=mask, sem="relaxed")

        if USE_MASK:
            ones = tl.full((BLOCK,), 1, dtype=tl.int32)
            tl.atomic_add(mask_ptr + out_offsets, ones, mask=mask, sem="relaxed")
1060# ---------------------------------------------------------------------------
1061# Python entry points
1062# ---------------------------------------------------------------------------
1065def scatter_reduce(inp, dim, index, src, reduce, *, include_self=True):
1066 """Triton-accelerated scatter_reduce operation.
1068 Scatters src values into the output tensor at positions determined by index,
1069 applying the specified reduction. Supports sum, prod, mean, amax, amin.
1071 Args:
1072 inp: Input tensor (1D-5D).
1073 dim: Dimension along which to scatter.
1074 index: Index tensor mapping source elements to output positions.
1075 src: Source tensor containing values to scatter.
1076 reduce: Reduction mode - "sum", "prod", "mean", "amax", or "amin".
1077 include_self: If True, include inp values in the reduction.
1079 Returns:
1080 Output tensor with same shape and dtype as inp.
1081 """
1082 logger.debug("GEMS SCATTER_REDUCE_TWO")
1084 assert reduce in (
1085 "sum",
1086 "prod",
1087 "mean",
1088 "amax",
1089 "amin",
1090 ), f"Unsupported reduce: {reduce}"
1091 assert inp.ndim <= 5, f"scatter_reduce supports up to 5D tensors, got {inp.ndim}D"
1093 dim = dim % inp.ndim
1094 padded_dim = dim + (5 - inp.ndim)
1096 out_stride_dim = inp.stride(dim)
1097 out_shape_dim = inp.size(dim)
1098 src_stride_dim = src.stride(dim)
1099 src_shape_dim = src.size(dim)
1100 N = index.numel()
1102 # Avoid double clone: merge contiguous + float32 cast
1103 inp_f32 = inp.to(torch.float32).contiguous()
1105 if include_self:
1106 out = inp_f32.clone()
1107 else:
1108 if reduce in ("sum", "mean"):
1109 out = torch.zeros_like(inp_f32)
1110 elif reduce == "prod":
1111 out = torch.ones_like(inp_f32)
1112 elif reduce == "amax":
1113 out = torch.full_like(inp_f32, float("-inf"))
1114 elif reduce == "amin":
1115 out = torch.full_like(inp_f32, float("inf"))
1117 if N == 0:
1118 return out.to(inp.dtype) if not include_self else inp_f32.to(inp.dtype)
1120 use_mask = not include_self
1121 if use_mask:
1122 reduced_mask = torch.zeros(out.shape, dtype=torch.int32, device=inp.device)
1124 if reduce == "mean":
1125 if include_self:
1126 count = torch.ones_like(out, dtype=torch.float32)
1127 else:
1128 count = torch.zeros_like(out, dtype=torch.float32)
1130 src = src.contiguous()
1131 index = index.contiguous()
1133 # Convert strides/shapes to int64 to avoid overflow in kernel arithmetic
1134 idx_shapes = [int(x) for x in _pad5(list(index.shape), 1)]
1135 src_shapes = [int(x) for x in _pad5(list(src.shape), 1)]
1136 src_strides_p = [int(x) for x in _pad5(list(src.stride()), 0)]
1137 idx_strides_p = [int(x) for x in _pad5(list(index.stride()), 0)]
1138 out_shapes = [int(x) for x in _pad5(list(out.shape), 1)]
1139 out_strides_p = [int(x) for x in _pad5(list(out.stride()), 0)]
1141 grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]),)
1143 dummy_mask = torch.empty(1, dtype=torch.int32, device=inp.device)
1144 mask_ptr = reduced_mask if use_mask else dummy_mask
1146 # Use 2D fast path for 2D tensors (most common case)
1147 use_2d = inp.ndim == 2
1149 # For 2D kernels, use raw dim (0 or 1) instead of padded_dim
1150 dim_2d = dim
1152 with torch_device_fn.device(inp.device):
1153 if reduce == "sum":
1154 if use_2d:
1155 idx_ncols = index.shape[1]
1156 src_ncols = src.shape[1]
1157 out_ncols = out.shape[1]
1158 scatter_reduce_sum_2d_kernel[grid](
1159 index,
1160 src,
1161 out,
1162 mask_ptr,
1163 N,
1164 idx_ncols,
1165 src_ncols,
1166 out_ncols,
1167 dim_2d,
1168 use_mask,
1169 )
1170 else:
1171 scatter_reduce_sum_kernel[grid](
1172 index,
1173 src,
1174 out,
1175 mask_ptr,
1176 N,
1177 out_stride_dim,
1178 src_stride_dim,
1179 src_shape_dim,
1180 out_shape_dim,
1181 padded_dim,
1182 use_mask,
1183 src_strides_p[0],
1184 src_strides_p[1],
1185 src_strides_p[2],
1186 src_strides_p[3],
1187 src_strides_p[4],
1188 idx_shapes[0],
1189 idx_shapes[1],
1190 idx_shapes[2],
1191 idx_shapes[3],
1192 idx_shapes[4],
1193 src_shapes[0],
1194 src_shapes[1],
1195 src_shapes[2],
1196 src_shapes[3],
1197 src_shapes[4],
1198 idx_strides_p[0],
1199 idx_strides_p[1],
1200 idx_strides_p[2],
1201 idx_strides_p[3],
1202 idx_strides_p[4],
1203 out_strides_p[0],
1204 out_strides_p[1],
1205 out_strides_p[2],
1206 out_strides_p[3],
1207 out_strides_p[4],
1208 )
1209 elif reduce == "prod":
1210 scan_grid = (out.numel(),)
1211 scatter_reduce_prod_scan_kernel[scan_grid](
1212 index,
1213 src,
1214 out,
1215 mask_ptr,
1216 out.numel(),
1217 padded_dim,
1218 use_mask,
1219 src_shape_dim,
1220 src_strides_p[0],
1221 src_strides_p[1],
1222 src_strides_p[2],
1223 src_strides_p[3],
1224 src_strides_p[4],
1225 idx_shapes[0],
1226 idx_shapes[1],
1227 idx_shapes[2],
1228 idx_shapes[3],
1229 idx_shapes[4],
1230 src_shapes[0],
1231 src_shapes[1],
1232 src_shapes[2],
1233 src_shapes[3],
1234 src_shapes[4],
1235 idx_strides_p[0],
1236 idx_strides_p[1],
1237 idx_strides_p[2],
1238 idx_strides_p[3],
1239 idx_strides_p[4],
1240 out_shapes[0],
1241 out_shapes[1],
1242 out_shapes[2],
1243 out_shapes[3],
1244 out_shapes[4],
1245 out_strides_p[0],
1246 out_strides_p[1],
1247 out_strides_p[2],
1248 out_strides_p[3],
1249 out_strides_p[4],
1250 )
1251 elif reduce == "mean":
1252 if use_2d:
1253 idx_ncols = index.shape[1]
1254 src_ncols = src.shape[1]
1255 out_ncols = out.shape[1]
1256 scatter_reduce_mean_2d_kernel[grid](
1257 index,
1258 src,
1259 out,
1260 count,
1261 mask_ptr,
1262 N,
1263 idx_ncols,
1264 src_ncols,
1265 out_ncols,
1266 dim_2d,
1267 use_mask,
1268 )
1269 else:
1270 scatter_reduce_mean_kernel[grid](
1271 index,
1272 src,
1273 out,
1274 count,
1275 mask_ptr,
1276 N,
1277 out_stride_dim,
1278 src_stride_dim,
1279 src_shape_dim,
1280 out_shape_dim,
1281 padded_dim,
1282 use_mask,
1283 src_strides_p[0],
1284 src_strides_p[1],
1285 src_strides_p[2],
1286 src_strides_p[3],
1287 src_strides_p[4],
1288 idx_shapes[0],
1289 idx_shapes[1],
1290 idx_shapes[2],
1291 idx_shapes[3],
1292 idx_shapes[4],
1293 src_shapes[0],
1294 src_shapes[1],
1295 src_shapes[2],
1296 src_shapes[3],
1297 src_shapes[4],
1298 idx_strides_p[0],
1299 idx_strides_p[1],
1300 idx_strides_p[2],
1301 idx_strides_p[3],
1302 idx_strides_p[4],
1303 out_strides_p[0],
1304 out_strides_p[1],
1305 out_strides_p[2],
1306 out_strides_p[3],
1307 out_strides_p[4],
1308 )
1309 has_contributions = count > 0
1310 count = torch.clamp(count, min=1.0)
1311 out = out / count
1312 out = torch.where(has_contributions, out, inp_f32)
1313 elif reduce in ("amax", "amin"):
1314 use_cas = _needs_cas_fallback()
1315 if use_2d:
1316 idx_ncols = index.shape[1]
1317 src_ncols = src.shape[1]
1318 out_ncols = out.shape[1]
1319 scatter_reduce_amax_2d_kernel[grid](
1320 index,
1321 src,
1322 out,
1323 mask_ptr,
1324 N,
1325 idx_ncols,
1326 src_ncols,
1327 out_ncols,
1328 dim_2d,
1329 reduce == "amax",
1330 use_mask,
1331 use_cas,
1332 )
1333 else:
1334 scatter_reduce_amax_kernel[grid](
1335 index,
1336 src,
1337 out,
1338 mask_ptr,
1339 N,
1340 out_stride_dim,
1341 src_stride_dim,
1342 src_shape_dim,
1343 out_shape_dim,
1344 padded_dim,
1345 reduce == "amax",
1346 use_mask,
1347 use_cas,
1348 src_strides_p[0],
1349 src_strides_p[1],
1350 src_strides_p[2],
1351 src_strides_p[3],
1352 src_strides_p[4],
1353 idx_shapes[0],
1354 idx_shapes[1],
1355 idx_shapes[2],
1356 idx_shapes[3],
1357 idx_shapes[4],
1358 src_shapes[0],
1359 src_shapes[1],
1360 src_shapes[2],
1361 src_shapes[3],
1362 src_shapes[4],
1363 idx_strides_p[0],
1364 idx_strides_p[1],
1365 idx_strides_p[2],
1366 idx_strides_p[3],
1367 idx_strides_p[4],
1368 out_strides_p[0],
1369 out_strides_p[1],
1370 out_strides_p[2],
1371 out_strides_p[3],
1372 out_strides_p[4],
1373 )
1375 if use_mask and reduce != "mean":
1376 unreduced = reduced_mask == 0
1377 out = torch.where(unreduced, inp_f32, out)
1379 return out.to(inp.dtype)
def scatter_reduce_(inp, dim, index, src, reduce, *, include_self=True):
    """Apply ``scatter_reduce`` to ``inp`` in place.

    The reduction is first computed out of place by :func:`scatter_reduce`
    (same ``dim``/``index``/``src``/``reduce``/``include_self`` semantics),
    and the resulting values are then copied back into ``inp``.

    Returns:
        ``inp`` itself, now holding the reduced values.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_")

    reduced = scatter_reduce(
        inp,
        dim,
        index,
        src,
        reduce,
        include_self=include_self,
    )
    # Write back into the caller's tensor so the op is observably in-place.
    inp.copy_(reduced)
    return inp
def scatter_reduce_out(inp, dim, index, src, reduce, *, include_self=True, out=None):
    """Out-variant of scatter_reduce.

    Computes the reduction with :func:`scatter_reduce` and, when ``out`` is
    provided, stores the result in it.

    Args:
        inp, dim, index, src, reduce, include_self: Forwarded unchanged to
            ``scatter_reduce``.
        out: Optional destination tensor. If its shape differs from the
            result's shape it is resized in place before the copy, following
            PyTorch's ``out=`` convention of resizing the output.

    Returns:
        ``out`` when it is provided, otherwise the newly computed result.
    """
    logger.debug("GEMS SCATTER_REDUCE_TWO_OUT")

    result = scatter_reduce(inp, dim, index, src, reduce, include_self=include_self)
    if out is None:
        return result

    # ``copy_`` broadcasts ``result`` into ``out``'s current shape. An ``out``
    # with a mismatched shape (e.g. an empty placeholder) would either raise or
    # silently receive a broadcast copy, so match the result's shape first.
    if out.shape != result.shape:
        out.resize_(result.shape)
    out.copy_(result)
    return out