Coverage for src/flag_gems/runtime/backend/_ascend/ops/scatter.py: 0%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
5from flag_gems.ops.scatter import scatter as _scatter
6from flag_gems.ops.scatter import scatter_ as _scatter_
7from flag_gems.utils.shape_utils import restride_dim
# Child logger under the Ascend ops namespace, named after this module's last
# dotted name component.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rpartition(".")[2]}'
)
12def _compute_flat_offset(shape, strides, dim, N):
13 idx = torch.arange(N, device="cpu", dtype=torch.int64)
14 coord = torch.empty((len(shape), N), dtype=torch.int64, device="cpu")
15 for i in reversed(range(len(shape))):
16 coord[i] = idx % shape[i]
17 idx = idx // shape[i]
18 offset = torch.zeros(N, dtype=torch.int64, device="cpu")
19 for i in range(len(shape)):
20 if i != dim:
21 offset += coord[i] * strides[i]
22 return offset
def _scatter_reduce_on_host(inp, dim, index, src, reduce):
    """Scatter-reduce ``src`` into a copy of ``inp`` on the CPU.

    Shared implementation of the ``*_no_atomic`` entry points. The reduction
    runs on the host, so no device-side atomic operations are involved. The
    result is moved back to ``inp.device``.

    Args:
        inp: Destination tensor. It is never modified.
        dim: Axis along which ``index`` selects destination positions.
            Negative values count from the end.
        index: Integer tensor of destination positions along ``dim``.
        src: Values combined into the destination.
        reduce: ``"sum"`` or ``"prod"`` (``torch.Tensor.scatter_reduce_``
            naming). The existing value of ``inp`` is always included.

    Returns:
        A new tensor on ``inp.device`` holding the reduced result.

    Errors raised by ``scatter_reduce_`` (for example for an out-of-range
    index or mismatched shapes/dtypes) propagate unchanged.
    """
    if index.numel() == 0:
        # Nothing to scatter: return an untouched copy without a host round trip.
        return inp.clone()
    # ``copy=True`` guarantees a fresh buffer even when ``inp`` already lives on
    # the CPU, where ``.cpu()`` would hand back ``inp`` itself and the in-place
    # reduction below would clobber the caller's tensor.
    out_cpu = inp.to(device="cpu", copy=True)
    index_cpu = index.to(device="cpu", dtype=torch.int64)
    src_cpu = src.to(device="cpu")
    # Letting torch's CPU kernel do the scatter makes the result independent of
    # ``inp``'s memory layout (transposed or sliced views included) and rejects
    # out-of-range indices instead of writing to some other element.
    out_cpu.scatter_reduce_(
        dim, index_cpu, src_cpu, reduce=reduce, include_self=True
    )
    return out_cpu.to(inp.device)


def scatter_add_no_atomic(inp, dim, index, src):
    """Return ``inp`` with ``src`` summed into the positions named by ``index``.

    Out-of-place counterpart of ``inp.clone().scatter_add_(dim, index, src)``,
    computed on the host. ``inp`` itself is not modified.
    """
    return _scatter_reduce_on_host(inp, dim, index, src, "sum")


def scatter_reduce_multiply_no_atomic(inp, dim, index, src):
    """Return ``inp`` multiplied by ``src`` at the positions named by ``index``.

    Out-of-place scatter with ``reduce="multiply"`` semantics: every destination
    element is multiplied by each ``src`` value scattered onto it. The work is
    done on the host, and ``inp`` itself is not modified.
    """
    return _scatter_reduce_on_host(inp, dim, index, src, "prod")
def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place ``scatter`` entry point for the Ascend backend.

    ``reduce="add"`` and ``reduce="multiply"`` are served by this module's
    ``*_no_atomic`` helpers. Every other ``reduce`` value, including ``None``,
    is forwarded unchanged to the generic FlagGems ``scatter``.
    """
    logger.debug("GEMS_ASCEND SCATTER")
    if reduce == "add":
        reducer = scatter_add_no_atomic
    elif reduce == "multiply":
        reducer = scatter_reduce_multiply_no_atomic
    else:
        return _scatter(inp, dim, index, src, reduce)
    return reducer(inp, dim, index, src)
def scatter_(inp, dim, index, src, reduce=None):
    """In-place ``scatter_`` entry point for the Ascend backend. Returns ``inp``.

    For ``reduce="add"`` / ``reduce="multiply"``, the matching ``*_no_atomic``
    helper computes the result out of place, and that result is copied back
    into ``inp``. Other ``reduce`` values are forwarded to the generic FlagGems
    ``scatter_``.
    """
    logger.debug("GEMS_ASCEND SCATTER_")
    if reduce == "add":
        reducer = scatter_add_no_atomic
    elif reduce == "multiply":
        reducer = scatter_reduce_multiply_no_atomic
    else:
        return _scatter_(inp, dim, index, src, reduce)
    inp.copy_(reducer(inp, dim, index, src))
    return inp