# Coverage for src/flag_gems/runtime/backend/_ascend/ops/scatter_add_.py: 0% (36 statements)
# Report generated by coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
import logging

import torch

from flag_gems.utils.shape_utils import restride_dim
# Per-operator logger: the name is the fixed Ascend ops prefix followed by the
# last component of this module's dotted name.
logger = logging.getLogger(
    f"flag_gems.runtime._ascend.ops.{__name__.rpartition('.')[2]}"
)
def _compute_flat_offset(shape, strides, dim, N):
    """Return the strided offset of each flat position, ignoring axis ``dim``.

    The integers ``0 .. N-1`` are decoded as row-major (last axis fastest)
    coordinates of a tensor with the given ``shape``.  For every position the
    result holds ``sum(coord[axis] * strides[axis])`` over all axes except
    ``dim``; the contribution of ``dim`` is left for the caller to add.

    Args:
        shape: Sequence of axis sizes used to decode the flat positions.
        strides: Per-axis strides, indexable by the same axis numbers as
            ``shape``.
        dim: Axis whose coordinate is excluded from the sum.
        N: Number of flat positions to decode.  The caller is expected to pass
            the number of elements described by ``shape``.

    Returns:
        A 1-D ``torch.int64`` CPU tensor of length ``N``.
    """
    remaining = torch.arange(N, device="cpu", dtype=torch.int64)
    offset = torch.zeros(N, dtype=torch.int64, device="cpu")
    # Peel coordinates off from the fastest-varying (last) axis to the first,
    # accumulating as we go so no per-axis coordinate table has to be stored.
    for axis in reversed(range(len(shape))):
        if axis != dim:
            offset += (remaining % shape[axis]) * strides[axis]
        remaining = remaining // shape[axis]
    return offset
def scatter_add_(inp, dim, index, src):
    """Add the elements of ``src`` into ``inp`` in place, routed by ``index``.

    For every position ``p`` of ``index``, the value of ``src`` at ``p`` is
    added to the element of ``inp`` whose coordinate along ``dim`` is
    ``index[p]`` and whose remaining coordinates are those of ``p``.  Several
    positions may target the same element; their contributions accumulate.

    The accumulation runs on a contiguous CPU copy of ``inp`` and the result is
    copied back, so ``inp`` may live on any device and have any strides.

    Args:
        inp: Tensor updated in place.
        dim: Axis along which ``index`` selects; negative values wrap around
            ``inp.ndim``.
        index: Integer tensor of target coordinates along ``dim``.  Every value
            must lie in ``[0, inp.size(dim))``.
        src: Tensor supplying the values; it is read through ``index``'s shape
            with its own strides.  It must have the same dtype as ``inp``
            (``index_add_`` requires this).

    Returns:
        ``inp`` itself, after the update.

    Raises:
        IndexError: If any value of ``index`` is outside
            ``[0, inp.size(dim))``.
    """
    logger.debug("GEMS_ASCEND SCATTER_ADD_")
    dim = dim % inp.ndim

    num_updates = index.numel()
    if num_updates == 0:
        return inp

    # Read src through index's shape so that only the region addressed by
    # index is consumed, then flatten both in the same logical order.
    src_strided = src.as_strided(index.shape, src.stride())
    flat_index = index.reshape(-1).to(torch.int64).cpu()
    flat_src = src_strided.reshape(-1).contiguous().cpu()

    # Reject bad indices up front: an out-of-range value would otherwise map
    # to a valid flat offset belonging to a different element and corrupt it
    # silently.
    dim_size = inp.size(dim)
    out_of_range = (flat_index < 0) | (flat_index >= dim_size)
    if bool(out_of_range.any()):
        bad = int(flat_index[out_of_range][0])
        raise IndexError(
            f"scatter_add_: index {bad} is out of bounds for dimension {dim} "
            f"with size {dim_size}"
        )

    # Work on a contiguous CPU buffer so that strided offsets computed below
    # address the flattened view directly; with a non-contiguous buffer the
    # flattened tensor would be a reordered copy and the updates would be lost.
    out_cpu = inp.cpu().contiguous()
    flat_out = out_cpu.view(-1)
    dim_stride = out_cpu.stride(dim)

    # restride_dim is a project helper; only the strides of its result are
    # consumed here, and the entry for `dim` is ignored by
    # _compute_flat_offset.
    out_restrided = restride_dim(out_cpu, dim, index.shape)
    base_offset = _compute_flat_offset(
        index.shape, out_restrided.stride(), dim, num_updates
    ).to(torch.int64)

    # One vectorised accumulation; index_add_ sums contributions that share a
    # target offset.
    target_offset = base_offset + flat_index * dim_stride
    flat_out.index_add_(0, target_offset, flat_src)

    inp.copy_(out_cpu.to(inp.device))
    return inp