Coverage for src/flag_gems/runtime/backend/_ascend/ops/scatter_add_.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.utils.shape_utils import restride_dim 

6 

7logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

8 

9 

10def _compute_flat_offset(shape, strides, dim, N): 

11 idx = torch.arange(N, device="cpu", dtype=torch.int64) 

12 coord = torch.empty((len(shape), N), dtype=torch.int64, device="cpu") 

13 for i in reversed(range(len(shape))): 

14 coord[i] = idx % shape[i] 

15 idx = idx // shape[i] 

16 offset = torch.zeros(N, dtype=torch.int64, device="cpu") 

17 for i in range(len(shape)): 

18 if i != dim: 

19 offset += coord[i] * strides[i] 

20 return offset 

21 

22 

23def scatter_add_(inp, dim, index, src): 

24 logger.debug("GEMS_ASCEND SCATTER_ADD_") 

25 out = inp 

26 dim = dim % inp.ndim 

27 dim_stride = inp.stride(dim) 

28 

29 src_strided = src.as_strided(index.shape, src.stride()) 

30 inp_restrided = restride_dim(inp, dim, index.shape) 

31 

32 N = index.numel() 

33 if N == 0: 

34 return out 

35 

36 flat_index = index.reshape(-1).to(torch.int64).cpu() 

37 flat_src = src_strided.reshape(-1).contiguous().cpu() 

38 out_cpu = out.cpu() 

39 

40 base_offset = _compute_flat_offset(index.shape, inp_restrided.stride(), dim, N).to( 

41 torch.int64 

42 ) 

43 

44 flat_out = out_cpu.reshape(-1) 

45 for i in range(N): 

46 idx = flat_index[i].item() 

47 out_offset = base_offset[i].item() + idx * dim_stride 

48 flat_out[out_offset] += flat_src[i].item() 

49 

50 inp.copy_(out_cpu.to(inp.device)) 

51 return inp