Coverage for src/flag_gems/ops/index_reduce.py: 44%
443 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
logger = logging.getLogger(__name__)

# Integer codes for the supported ``reduce`` modes. The Triton kernels in this
# module compare their ``REDUCE`` constexpr against these literal values, so
# the numbers must not change.
REDUCE_PROD, REDUCE_MEAN, REDUCE_AMAX, REDUCE_AMIN = 0, 1, 2, 3
def _heur_block_m(args):
    """Heuristic for ``BLOCK_M``: 4 rows per tile, or 1 when ``M`` is below 4."""
    return 4 if args["M"] >= 4 else 1
def _heur_block_n(args):
    """Heuristic for ``BLOCK_N``: next power of two of ``N``, clamped to [1, 256]."""
    rounded = triton.next_power_of_2(args["N"])
    capped = min(256, rounded)
    return max(1, capped)
def _heur_flat_block(args):
    """Heuristic for the 1D ``BLOCK`` size, clamped to [1, 256].

    Uses ``TOTAL`` when the kernel has such an argument and falls back to
    ``N`` otherwise. The fallback is looked up lazily because kernels that
    provide ``TOTAL`` do not necessarily provide ``N``.
    """
    size_key = "TOTAL" if "TOTAL" in args else "N"
    rounded = triton.next_power_of_2(args[size_key])
    return max(1, min(256, rounded))
@libentry()
@triton.heuristics({"BLOCK_M": _heur_block_m, "BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["M", "N", "OUT_N"])
def _index_reduce_kernel(
    out,
    index,
    src,
    count,
    touched,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tiled 2D index-reduce: combine ``src[r, c]`` into ``out[r, index[c]]``.

    ``src`` is addressed row-major as ``r * N + c`` and ``out`` as
    ``r * OUT_N + index[c]``. Each program covers one BLOCK_M x BLOCK_N tile.

    ``REDUCE`` selects the combiner:
      * 1 - atomic add into ``out`` and a matching increment of ``count``;
      * 0 - product, through a compare-and-swap retry loop;
      * 2 - maximum, any other value - minimum (CAS loop when ``USE_CAS``,
        hardware atomics otherwise).

    With ``USE_TOUCHED`` every written destination also increments
    ``touched``. ``USE_COUNT`` is part of the signature but is not read in
    the body.
    """
    row_ids = tl.program_id(axis=0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    col_ids = tl.program_id(axis=1) * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    in_tile = (row_ids < M) & (col_ids < N)

    # Destination column of every source column in this tile.
    target_cols = tl.load(index + col_ids, mask=col_ids < N, other=0).to(tl.int64)
    src_offsets = row_ids * N + col_ids
    out_offsets = row_ids * OUT_N + target_cols
    src_vals = tl.load(src + src_offsets, mask=in_tile, other=0.0)

    if REDUCE == 1:
        # Mean: accumulate the sum and how many values landed on each slot.
        tl.atomic_add(out + out_offsets, src_vals, mask=in_tile, sem="relaxed")
        count_inc = tl.full((BLOCK_M, BLOCK_N), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, count_inc, mask=in_tile, sem="relaxed")
    elif REDUCE == 0:
        # Product: lanes outside the tile start out finished; the loop runs
        # until every lane has either won its CAS or produced a NaN.
        finished = tl.where(in_tile, 0, 1).to(tl.int1)
        all_finished = False
        while not all_finished:
            observed = tl.load(out + out_offsets, mask=in_tile, other=0.0)
            proposed = tl.where(finished, observed, observed * src_vals)
            produced_nan = proposed != proposed
            # A NaN product is replaced by the source value before the swap.
            proposed = tl.where(produced_nan, src_vals, proposed)
            swapped = tl.atomic_cas(out + out_offsets, observed, proposed, sem="relaxed")
            finished |= (observed == swapped) | produced_nan
            all_finished = tl.sum(finished.to(tl.int32)) == BLOCK_M * BLOCK_N
    else:
        if USE_CAS:
            # Max/min emulated with a CAS retry loop.
            finished = tl.where(in_tile, 0, 1).to(tl.int1)
            all_finished = False
            while not all_finished:
                observed = tl.load(out + out_offsets, mask=in_tile, other=0.0)
                if REDUCE == 2:
                    proposed = tl.maximum(observed, src_vals)
                else:
                    proposed = tl.minimum(observed, src_vals)
                swapped = tl.atomic_cas(
                    out + out_offsets, observed, proposed, sem="relaxed"
                )
                finished |= observed == swapped
                all_finished = tl.sum(finished.to(tl.int32)) == BLOCK_M * BLOCK_N
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, src_vals, mask=in_tile, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, src_vals, mask=in_tile, sem="relaxed")

    if USE_TOUCHED:
        touch_inc = tl.full((BLOCK_M, BLOCK_N), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, touch_inc, mask=in_tile, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL", "M", "N", "OUT_N"])
def _index_reduce_flat_kernel(
    out,
    index,
    src,
    count,
    touched,
    TOTAL,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    INDEX_MAJOR: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """1D index-reduce over a row-major (row, col) source of ``TOTAL`` elements.

    Every lane owns one flat work item that is split into a ``(row, col)``
    pair. With ``INDEX_MAJOR`` consecutive work items walk the rows of one
    column (``col = item // M``); otherwise they walk the columns of one row
    (``row = item // N``). The source element is read from ``row * N + col``
    and combined into ``out[row * OUT_N + index[col]]``.

    ``REDUCE``: 1 adds (and counts hits in ``count``), 0 multiplies through
    a CAS loop, 2 takes the maximum and any other value the minimum.
    ``USE_TOUCHED`` additionally increments ``touched`` per written slot.
    ``USE_COUNT`` is part of the signature but is not read in the body.
    """
    work_ids = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    active = work_ids < TOTAL

    if INDEX_MAJOR:
        col_ids = work_ids // M
        row_ids = work_ids - col_ids * M
    else:
        row_ids = work_ids // N
        col_ids = work_ids - row_ids * N

    target_cols = tl.load(index + col_ids, mask=active, other=0).to(tl.int64)
    src_offsets = row_ids * N + col_ids
    out_offsets = row_ids * OUT_N + target_cols
    src_vals = tl.load(src + src_offsets, mask=active, other=0.0)

    if REDUCE == 1:
        tl.atomic_add(out + out_offsets, src_vals, mask=active, sem="relaxed")
        count_inc = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, count_inc, mask=active, sem="relaxed")
    elif REDUCE == 0:
        # Inactive lanes start finished; retry until all lanes are finished.
        finished = tl.where(active, 0, 1).to(tl.int1)
        all_finished = False
        while not all_finished:
            observed = tl.load(out + out_offsets, mask=active, other=0.0)
            proposed = tl.where(finished, observed, observed * src_vals)
            produced_nan = proposed != proposed
            # A NaN product is replaced by the source value before the swap.
            proposed = tl.where(produced_nan, src_vals, proposed)
            swapped = tl.atomic_cas(out + out_offsets, observed, proposed, sem="relaxed")
            finished |= (observed == swapped) | produced_nan
            all_finished = tl.sum(finished.to(tl.int32)) == BLOCK
    else:
        if USE_CAS:
            finished = tl.where(active, 0, 1).to(tl.int1)
            all_finished = False
            while not all_finished:
                observed = tl.load(out + out_offsets, mask=active, other=0.0)
                if REDUCE == 2:
                    proposed = tl.maximum(observed, src_vals)
                else:
                    proposed = tl.minimum(observed, src_vals)
                # Keep the swapped-in value in the dtype that was loaded.
                proposed = proposed.to(observed.dtype)
                swapped = tl.atomic_cas(
                    out + out_offsets, observed, proposed, sem="relaxed"
                )
                finished |= observed == swapped
                all_finished = tl.sum(finished.to(tl.int32)) == BLOCK
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, src_vals, mask=active, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, src_vals, mask=active, sem="relaxed")

    if USE_TOUCHED:
        touch_inc = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, touch_inc, mask=active, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL", "PRE", "POST", "N", "OUT_N"])
def _index_reduce_contiguous_flat_kernel(
    out,
    index,
    src,
    count,
    touched,
    TOTAL,
    PRE,
    POST,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    USE_COUNT: tl.constexpr,
    USE_TOUCHED: tl.constexpr,
    USE_CAS: tl.constexpr,
    INDEX_MAJOR: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """1D index-reduce addressing ``src``/``out`` as (PRE, N or OUT_N, POST).

    Each lane owns one flat work item that is decomposed into
    ``(pre, col, post)``. The source element lives at
    ``pre * N * POST + col * POST + post`` and is combined into
    ``out[pre * OUT_N * POST + index[col] * POST + post]``, so no
    permutation of the tensors is required.

    With ``INDEX_MAJOR`` consecutive work items share one ``col`` and walk
    the ``PRE * POST`` slice; otherwise consecutive items walk ``col``.

    ``REDUCE``: 1 adds (and counts hits in ``count``), 0 multiplies through
    a CAS loop, 2 takes the maximum and any other value the minimum.
    ``USE_TOUCHED`` additionally increments ``touched`` per written slot.
    ``USE_COUNT`` is part of the signature but is not read in the body.
    """
    work_ids = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    active = work_ids < TOTAL
    elems_per_col = PRE * POST

    if INDEX_MAJOR:
        col_ids = work_ids // elems_per_col
        slice_pos = work_ids - col_ids * elems_per_col
    else:
        slice_pos = work_ids // N
        col_ids = work_ids - slice_pos * N

    # Split the position inside the slice into its leading/trailing parts.
    pre_ids = slice_pos // POST
    post_ids = slice_pos - pre_ids * POST
    target_cols = tl.load(index + col_ids, mask=active, other=0).to(tl.int64)

    src_offsets = pre_ids * N * POST + col_ids * POST + post_ids
    out_offsets = pre_ids * OUT_N * POST + target_cols * POST + post_ids
    src_vals = tl.load(src + src_offsets, mask=active, other=0.0)

    if REDUCE == 1:
        tl.atomic_add(out + out_offsets, src_vals, mask=active, sem="relaxed")
        count_inc = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(count + out_offsets, count_inc, mask=active, sem="relaxed")
    elif REDUCE == 0:
        # Inactive lanes start finished; retry until all lanes are finished.
        finished = tl.where(active, 0, 1).to(tl.int1)
        all_finished = False
        while not all_finished:
            observed = tl.load(out + out_offsets, mask=active, other=0.0)
            proposed = tl.where(finished, observed, observed * src_vals)
            produced_nan = proposed != proposed
            # A NaN product is replaced by the source value before the swap.
            proposed = tl.where(produced_nan, src_vals, proposed)
            swapped = tl.atomic_cas(out + out_offsets, observed, proposed, sem="relaxed")
            finished |= (observed == swapped) | produced_nan
            all_finished = tl.sum(finished.to(tl.int32)) == BLOCK
    else:
        if USE_CAS:
            finished = tl.where(active, 0, 1).to(tl.int1)
            all_finished = False
            while not all_finished:
                observed = tl.load(out + out_offsets, mask=active, other=0.0)
                if REDUCE == 2:
                    proposed = tl.maximum(observed, src_vals)
                else:
                    proposed = tl.minimum(observed, src_vals)
                # Keep the swapped-in value in the dtype that was loaded.
                proposed = proposed.to(observed.dtype)
                swapped = tl.atomic_cas(
                    out + out_offsets, observed, proposed, sem="relaxed"
                )
                finished |= observed == swapped
                all_finished = tl.sum(finished.to(tl.int32)) == BLOCK
        else:
            if REDUCE == 2:
                tl.atomic_max(out + out_offsets, src_vals, mask=active, sem="relaxed")
            else:
                tl.atomic_min(out + out_offsets, src_vals, mask=active, sem="relaxed")

    if USE_TOUCHED:
        touch_inc = tl.full((BLOCK,), 1, dtype=tl.int32)
        tl.atomic_add(touched + out_offsets, touch_inc, mask=active, sem="relaxed")
@libentry()
@triton.heuristics({"BLOCK": _heur_flat_block})
@triton.jit(do_not_specialize=["TOTAL"])
def _index_reduce_mean_finalize_kernel(
    result,
    acc,
    original,
    count,
    TOTAL,
    INCLUDE_SELF: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Turn accumulated sums into means, element-wise over ``TOTAL`` entries.

    With ``INCLUDE_SELF`` the divisor is ``count + 1``. Without it the sum is
    divided by ``count`` where ``count > 0``, and entries that received no
    contribution take their value from ``original`` instead.

    NOTE(review): the sum is converted to float32 before dividing, whatever
    the dtype of ``acc`` - confirm this is acceptable for wider inputs.
    """
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    valid = offs < TOTAL

    hits = tl.load(count + offs, mask=valid, other=0)
    summed = tl.load(acc + offs, mask=valid, other=0.0).to(tl.float32)
    if INCLUDE_SELF:
        divisor = hits + 1
        result_val = summed / divisor.to(tl.float32)
    else:
        # Clamp to 1 so untouched entries never divide by zero; their
        # quotient is discarded by the select below.
        divisor = tl.maximum(hits, 1)
        averaged = summed / divisor.to(tl.float32)
        untouched_val = tl.load(original + offs, mask=valid, other=0.0)
        result_val = tl.where(hits > 0, averaged, untouched_val)
    tl.store(result + offs, result_val, mask=valid)
@libentry()
@triton.heuristics({"BLOCK_M": _heur_block_m, "BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["M", "N", "OUT_N"])
def _index_reduce_unique_kernel(
    out,
    index,
    src,
    M,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    INCLUDE_SELF: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tiled 2D index-reduce using plain loads and stores (no atomics).

    ``src[r * N + c]`` is combined with ``out[r * OUT_N + index[c]]`` and
    stored back to the same slot. With ``INCLUDE_SELF`` the stored value is
    the product (``REDUCE == 0``), the average of the two (``REDUCE == 1``),
    the maximum (``REDUCE == 2``) or the minimum (otherwise); without it the
    source value simply overwrites the destination.

    NOTE(review): there is no synchronisation between lanes, so this looks
    correct only when no two columns map to the same destination - confirm
    callers guarantee a duplicate-free ``index``.
    """
    row_ids = tl.program_id(axis=0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    col_ids = tl.program_id(axis=1) * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    in_tile = (row_ids < M) & (col_ids < N)

    target_cols = tl.load(index + col_ids, mask=col_ids < N, other=0).to(tl.int64)
    src_offsets = row_ids * N + col_ids
    out_offsets = row_ids * OUT_N + target_cols
    incoming = tl.load(src + src_offsets, mask=in_tile, other=0.0)

    if INCLUDE_SELF:
        existing = tl.load(out + out_offsets, mask=in_tile, other=0.0)
        if REDUCE == 0:
            combined = existing * incoming
        elif REDUCE == 1:
            combined = (existing + incoming) * 0.5
        elif REDUCE == 2:
            combined = tl.maximum(existing, incoming)
        else:
            combined = tl.minimum(existing, incoming)
    else:
        combined = incoming

    tl.store(out + out_offsets, combined, mask=in_tile)
@libentry()
@triton.heuristics({"BLOCK_N": _heur_block_n})
@triton.jit(do_not_specialize=["TOTAL", "N", "OUT_N"])
def _index_reduce_scan_kernel(
    out,
    index,
    src,
    inp,
    TOTAL,
    N,
    OUT_N,
    REDUCE: tl.constexpr,
    INCLUDE_SELF: tl.constexpr,
    USE_FP64: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Gather-style index-reduce: one program per output element, no atomics.

    Program ``pid`` owns output slot ``pid`` (row ``pid // OUT_N``, column
    ``pid % OUT_N``). It scans all ``N`` entries of ``index`` and folds in
    every ``src[row * N + col]`` whose index equals its own column.
    Accumulation is done in float64 when ``USE_FP64`` and float32 otherwise.

    ``REDUCE``: 0 product (scalar scan, one column per iteration), 1 mean,
    2 maximum, otherwise minimum (blocked scan, ``BLOCK_N`` columns per
    iteration). The accumulator starts from ``inp[pid]`` with
    ``INCLUDE_SELF`` and from the reduction identity without it. Slots that
    end with a hit count of zero keep the ``inp`` value.
    """
    pid = tl.program_id(axis=0)
    slot_valid = pid < TOTAL
    row = pid // OUT_N
    own_col = (pid - row * OUT_N).to(tl.int64)
    self_val = tl.load(inp + pid, mask=slot_valid, other=0.0)
    if USE_FP64:
        self_val = self_val.to(tl.float64)
    else:
        self_val = self_val.to(tl.float32)

    # Seed the accumulator: the existing value, or the reduction identity.
    if REDUCE == 0:
        if USE_FP64:
            acc = self_val if INCLUDE_SELF else tl.full((), 1.0, dtype=tl.float64)
        else:
            acc = self_val if INCLUDE_SELF else tl.full((), 1.0, dtype=tl.float32)
    elif REDUCE == 1:
        if USE_FP64:
            acc = self_val if INCLUDE_SELF else tl.full((), 0.0, dtype=tl.float64)
        else:
            acc = self_val if INCLUDE_SELF else tl.full((), 0.0, dtype=tl.float32)
    elif REDUCE == 2:
        if USE_FP64:
            acc = (
                self_val
                if INCLUDE_SELF
                else tl.full((), float("-inf"), dtype=tl.float64)
            )
        else:
            acc = (
                self_val
                if INCLUDE_SELF
                else tl.full((), float("-inf"), dtype=tl.float32)
            )
    else:
        if USE_FP64:
            acc = (
                self_val if INCLUDE_SELF else tl.full((), float("inf"), dtype=tl.float64)
            )
        else:
            acc = (
                self_val if INCLUDE_SELF else tl.full((), float("inf"), dtype=tl.float32)
            )

    # Number of values folded into this slot (the slot itself counts as one
    # when INCLUDE_SELF is set).
    n_hits = tl.full((), 1 if INCLUDE_SELF else 0, dtype=tl.int32)
    if REDUCE == 0:
        # Product: walk the index one column at a time.
        col = 0
        while col < N:
            col_target = tl.load(index + col).to(tl.int64)
            is_hit = col_target == own_col
            factor = tl.load(src + row * N + col, mask=is_hit, other=1.0)
            if USE_FP64:
                factor = factor.to(tl.float64)
            else:
                factor = factor.to(tl.float32)
            acc *= tl.where(is_hit, factor, 1.0)
            n_hits += is_hit.to(tl.int32)
            col += 1
    else:
        # Mean / max / min: walk the index BLOCK_N columns at a time.
        lane = tl.arange(0, BLOCK_N)
        base = 0
        while base < N:
            col_ids = base + lane
            in_range = col_ids < N
            targets = tl.load(index + col_ids, mask=in_range, other=-1).to(tl.int64)
            is_hit = in_range & (targets == own_col)
            chunk = tl.load(src + row * N + col_ids, mask=in_range, other=0.0)
            if USE_FP64:
                chunk = chunk.to(tl.float64)
            else:
                chunk = chunk.to(tl.float32)

            chunk_hits = tl.sum(is_hit.to(tl.int32), axis=0)
            n_hits += chunk_hits
            if REDUCE == 1:
                acc += tl.sum(tl.where(is_hit, chunk, 0.0), axis=0)
            elif REDUCE == 2:
                acc = tl.maximum(
                    acc, tl.max(tl.where(is_hit, chunk, float("-inf")), axis=0)
                )
            else:
                acc = tl.minimum(
                    acc, tl.min(tl.where(is_hit, chunk, float("inf")), axis=0)
                )
            base += BLOCK_N

    if REDUCE == 1:
        # Sum -> mean; the clamp avoids dividing by zero for untouched slots.
        acc = acc / tl.maximum(n_hits, 1).to(tl.float32)
    final_val = tl.where(n_hits > 0, acc, self_val)
    tl.store(out + pid, final_val, mask=slot_valid)
def _reduce_id(reduce):
    """Translate a reduce-mode name into the integer code the kernels expect.

    Raises:
        RuntimeError: if ``reduce`` is not "prod", "mean", "amax" or "amin".
    """
    known_modes = (
        ("prod", REDUCE_PROD),
        ("mean", REDUCE_MEAN),
        ("amax", REDUCE_AMAX),
        ("amin", REDUCE_AMIN),
    )
    for mode_name, mode_code in known_modes:
        if reduce == mode_name:
            return mode_code
    raise RuntimeError(f"Unsupported reduce: {reduce}")
def _identity_like(inp, reduce):
    """Return a tensor shaped like ``inp`` filled with the identity of ``reduce``.

    The fill is 1 for "prod", 0 for "mean", -inf for "amax" and +inf for
    "amin".

    Raises:
        RuntimeError: for any other ``reduce`` value.
    """
    if reduce == "prod":
        identity = torch.ones_like(inp)
    elif reduce == "mean":
        identity = torch.zeros_like(inp)
    elif reduce == "amax":
        identity = torch.full_like(inp, float("-inf"))
    elif reduce == "amin":
        identity = torch.full_like(inp, float("inf"))
    else:
        raise RuntimeError(f"Unsupported reduce: {reduce}")
    return identity
def _needs_cas(reduce, dtype):
    """Decide whether the kernels must emulate the reduction with a CAS loop.

    True on the "iluvatar" vendor backend regardless of mode, and for
    "amax"/"amin" on float16 or bfloat16 data.
    """
    if flag_gems.vendor_name in ("iluvatar",):
        return True
    is_min_or_max = reduce in ("amax", "amin")
    return is_min_or_max and dtype in (torch.float16, torch.bfloat16)
def _triton_version_at_least(major, minor):
    """Return True when the installed Triton is at least ``major.minor``.

    Only the first two dot-separated fields of ``triton.__version__`` are
    considered, after dropping any ``+local`` suffix. From each field only
    the leading run of digits is used, and a field without leading digits
    (or a missing field) counts as 0.
    """
    release = triton.__version__.split("+", 1)[0]
    fields = release.split(".")[:2]

    numbers = []
    for field in fields:
        digit_count = 0
        while digit_count < len(field) and field[digit_count].isdigit():
            digit_count += 1
        numbers.append(int(field[:digit_count] or 0))

    # Pad so that e.g. "3" compares as (3, 0).
    numbers.extend([0] * (2 - len(numbers)))
    return tuple(numbers) >= (major, minor)


# Triton 3.3.x rejects bf16 atomic_add during semantic type checking, so the
# bf16 code paths are chosen based on whether Triton is 3.4 or newer.
_TRITON_SUPPORTS_BF16_ATOMIC_ADD = _triton_version_at_least(3, 4)
def _should_scan_duplicate_index(index, out_dim, reduce, dtype):
    """Decide whether ``index_reduce_`` should use the scan kernel.

    The scan kernel is chosen only when all of the following hold: the vendor
    is not "ascend", Triton lacks bf16 ``atomic_add`` support, the reduction
    is "prod" or needs a CAS loop, and ``index`` is not duplicate-free. The
    uniqueness test is evaluated last, so it only runs once the other
    conditions have passed.
    """
    if flag_gems.vendor_name == "ascend" or _TRITON_SUPPORTS_BF16_ATOMIC_ADD:
        return False
    relies_on_cas_loop = reduce == "prod" or _needs_cas(reduce, dtype)
    if not relies_on_cas_loop:
        return False
    return not _index_is_unique(index, out_dim)
def _index_is_unique(index, out_dim):
    """Return True when ``index`` contains no repeated entry.

    More entries than ``out_dim`` are reported as not unique, and zero or one
    entry as unique, without inspecting the values. Otherwise
    ``Tensor.unique`` is used; on the "ascend" vendor the index is copied to
    the CPU first.
    """
    n_entries = index.numel()
    if n_entries > out_dim:
        return False
    if n_entries <= 1:
        return True
    candidates = index.cpu() if flag_gems.vendor_name == "ascend" else index
    return candidates.unique().numel() == candidates.numel()
def _prod(values):
    """Return the product of ``values``; an empty iterable yields 1."""
    # Function-scope import keeps this block independent of the module's
    # import list.
    import math

    return math.prod(values)
514def _validate_args(inp, dim, index, source, reduce):
515 assert reduce in ("prod", "mean", "amax", "amin"), f"Unsupported reduce: {reduce}"
516 assert inp.ndim > 0, "Expected self to have at least one dimension"
517 assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
518 assert index.ndim == 1, "Index is supposed to be a vector"
519 assert index.dtype in (
520 torch.int32,
521 torch.int64,
522 ), "Expected dtype int32/int64 for index"
523 assert (
524 inp.is_floating_point()
525 ), "index_reduce_(): Expected self to be floating point"
526 assert (
527 source.dtype == inp.dtype
528 ), "index_reduce_(): Expected self and source to have same dtype"
529 assert (
530 inp.ndim == source.ndim
531 ), "Self and source should have the same number of dimensions"
532 assert index.numel() == source.size(
533 dim
534 ), "The dimth dimension of source must have the same size as the length of index"
535 assert all(
536 inp.size(i) == source.size(i) or i == dim for i in range(inp.ndim)
537 ), "source.size(d) == self.size(d) for all dimensions d != dim"
def _restore_dim(out, inp, dim):
    """Write ``out`` back into ``inp`` and return ``inp``.

    ``out`` is expected to hold the result with the reduced dimension moved
    to the last position. If ``out`` already shares ``inp``'s storage
    pointer, shape and strides nothing is copied. Otherwise the last
    dimension of ``out`` is moved back to position ``dim`` (when ``dim`` is
    not already the last one) and the data is copied into ``inp``.
    """
    is_same_view = (
        out.data_ptr() == inp.data_ptr()
        and out.shape == inp.shape
        and out.stride() == inp.stride()
    )
    if is_same_view:
        return inp

    last_dim = inp.ndim - 1
    if dim != last_dim:
        # Build the permutation that moves the trailing axis back to `dim`.
        axes = list(range(out.ndim - 1))
        axes[dim:dim] = [last_dim]
        out = out.permute(axes).contiguous()
    inp.copy_(out)
    return inp
def index_reduce_(inp, dim, index, source, reduce, *, include_self=True):
    """Reduce slices of ``source`` into ``inp`` in place along ``dim``.

    For every position ``i`` of ``index``, the slice ``source.select(dim, i)``
    is combined into ``inp.select(dim, index[i])`` using ``reduce``.

    Args:
        inp: floating-point tensor that is updated in place.
        dim: dimension to index along; negative values are wrapped.
        index: 1D int32/int64 tensor whose length equals ``source.size(dim)``.
        source: tensor with the dtype and rank of ``inp``.
        reduce: one of "prod", "mean", "amax", "amin".
        include_self: when True the existing values of ``inp`` take part in
            the reduction; when False only ``source`` values do, and
            positions that receive no value keep their original content.

    Returns:
        ``inp`` itself.

    Raises:
        AssertionError: if ``_validate_args`` rejects the arguments.

    One of four strategies is used:
      1. scan kernel, when ``_should_scan_duplicate_index`` asks for it;
      2. contiguous flat kernel, for contiguous inputs off the "ascend"
         vendor when no float32 workspace is needed;
      3. on "ascend": the plain-store kernel for a duplicate-free index,
         otherwise the scan kernel in float32;
      4. generic flat kernel on dim-compressed copies (optionally through a
         float32 workspace for bf16 "mean" on Triton without bf16
         ``atomic_add``).
    """
    logger.debug("GEMS INDEX_REDUCE_")
    _validate_args(inp, dim, index, source, reduce)

    if index.numel() == 0:
        return inp

    dim = dim % inp.ndim
    index = index.contiguous()
    reduce_id = _reduce_id(reduce)
    is_ascend = flag_gems.vendor_name == "ascend"
    is_mean = reduce == "mean"
    use_fp32_workspace = (
        not is_ascend
        and is_mean
        and inp.dtype == torch.bfloat16
        and not _TRITON_SUPPORTS_BF16_ATOMIC_ADD
    )

    # Strategy 1: gather-style scan, one program per output element.
    if _should_scan_duplicate_index(index, inp.size(dim), reduce, inp.dtype):
        inp_work = dim_compress(inp, dim)
        source_work = dim_compress(source, dim)
        N = index.numel()
        out_n = inp_work.size(-1)
        compute_dtype = (
            torch.float64 if inp_work.dtype == torch.float64 else torch.float32
        )
        inp_compute = inp_work.to(compute_dtype)
        source_compute = source_work.to(compute_dtype)
        out = torch.empty_like(inp_compute)
        total = inp_compute.numel()
        grid = (total,)
        with torch_device_fn.device(inp.device):
            _index_reduce_scan_kernel[grid](
                out,
                index,
                source_compute,
                inp_compute,
                total,
                N,
                out_n,
                reduce_id,
                include_self,
                compute_dtype == torch.float64,
            )
        return _restore_dim(out.to(inp.dtype), inp, dim)

    # Strategy 2: contiguous tensors are addressed as (pre, dim, post)
    # directly, without permuting them first.
    if (
        not is_ascend
        and inp.is_contiguous()
        and source.is_contiguous()
        and not use_fp32_workspace
    ):
        pre = _prod(inp.shape[:dim])
        post = _prod(inp.shape[dim + 1 :])
        N = index.numel()
        out_n = inp.size(dim)
        total = pre * post * N

        if include_self:
            out = inp
            # Placeholder: the kernel does not touch it when USE_TOUCHED is off.
            touched = torch.empty(1, dtype=torch.int32, device=inp.device)
        else:
            out = _identity_like(inp, reduce)
            touched = torch.zeros_like(inp, dtype=torch.int32)

        if is_mean:
            count = torch.zeros_like(inp, dtype=torch.int32)
        else:
            # Placeholder: only the mean branch of the kernel writes counts.
            count = torch.empty(1, dtype=torch.int32, device=inp.device)

        use_cas = _needs_cas(reduce, inp.dtype)
        index_major = post > 1 or dim == 0
        with torch_device_fn.device(inp.device):
            _index_reduce_contiguous_flat_kernel[
                (lambda meta: (triton.cdiv(total, meta["BLOCK"]),))
            ](
                out,
                index,
                source,
                count,
                touched,
                total,
                pre,
                post,
                N,
                out_n,
                reduce_id,
                is_mean,
                not include_self,
                use_cas,
                index_major,
            )

        if is_mean:
            # Divide the accumulated sums by their counts, writing into inp.
            with torch_device_fn.device(inp.device):
                _index_reduce_mean_finalize_kernel[
                    (lambda meta: (triton.cdiv(inp.numel(), meta["BLOCK"]),))
                ](
                    inp,
                    out,
                    inp,
                    count,
                    inp.numel(),
                    include_self,
                )
        elif not include_self:
            # Keep the original value wherever nothing was reduced.
            inp.copy_(torch.where(touched == 0, inp, out))
        return inp

    inp_work = dim_compress(inp, dim)
    source_work = dim_compress(source, dim)

    M = source_work.numel() // index.numel()
    N = index.numel()
    out_n = inp_work.size(-1)

    # Strategy 3a: duplicate-free index on ascend -> plain loads and stores.
    if is_ascend and _index_is_unique(index, out_n):
        out = inp_work
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            # Pass exactly the kernel's eight non-heuristic arguments;
            # BLOCK_M / BLOCK_N are supplied by triton.heuristics.
            _index_reduce_unique_kernel[grid](
                out,
                index,
                source_work,
                M,
                N,
                out_n,
                reduce_id,
                include_self,
            )
        return _restore_dim(out, inp, dim)

    # Strategy 3b: remaining ascend cases use the scan kernel in float32.
    if is_ascend:
        inp_compute = inp_work.to(torch.float32)
        source_compute = source_work.to(torch.float32)
        out = torch.empty_like(inp_compute)
        total = inp_compute.numel()
        grid = (total,)
        with torch_device_fn.device(inp.device):
            _index_reduce_scan_kernel[grid](
                out,
                index,
                source_compute,
                inp_compute,
                total,
                N,
                out_n,
                reduce_id,
                include_self,
                False,
            )
        return _restore_dim(out.to(inp.dtype), inp, dim)

    # Strategy 4: generic flat kernel on the dim-compressed tensors.
    if use_fp32_workspace:
        inp_compute = inp_work.to(torch.float32)
        source_compute = source_work.to(torch.float32)
    else:
        inp_compute = inp_work
        source_compute = source_work

    if include_self:
        out = inp_compute
        touched = torch.empty(1, dtype=torch.int32, device=inp.device)
    else:
        out = _identity_like(inp_compute, reduce)
        touched = torch.zeros_like(inp_compute, dtype=torch.int32)

    if is_mean:
        count = torch.zeros_like(out, dtype=torch.int32)
    else:
        count = torch.empty(1, dtype=torch.int32, device=inp.device)

    use_cas = _needs_cas(reduce, inp_work.dtype)
    total = M * N
    index_major = dim == 0

    with torch_device_fn.device(inp.device):
        _index_reduce_flat_kernel[(lambda meta: (triton.cdiv(total, meta["BLOCK"]),))](
            out,
            index,
            source_compute,
            count,
            touched,
            total,
            M,
            N,
            out_n,
            reduce_id,
            is_mean,
            not include_self,
            use_cas,
            index_major,
        )

    if is_mean:
        # Finalise in place: `out` is both the accumulator and the result.
        with torch_device_fn.device(inp.device):
            _index_reduce_mean_finalize_kernel[
                (lambda meta: (triton.cdiv(inp_compute.numel(), meta["BLOCK"]),))
            ](
                out,
                out,
                inp_compute,
                count,
                inp_compute.numel(),
                include_self,
            )
    elif not include_self:
        out = torch.where(touched == 0, inp_compute, out)

    return _restore_dim(out.to(inp.dtype), inp, dim)