Coverage for src/flag_gems/runtime/backend/_ascend/ops/scatter.py: 0%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.ops.scatter import scatter as _scatter 

6from flag_gems.ops.scatter import scatter_ as _scatter_ 

7from flag_gems.utils.shape_utils import restride_dim 

8 

9logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

10 

11 

12def _compute_flat_offset(shape, strides, dim, N): 

13 idx = torch.arange(N, device="cpu", dtype=torch.int64) 

14 coord = torch.empty((len(shape), N), dtype=torch.int64, device="cpu") 

15 for i in reversed(range(len(shape))): 

16 coord[i] = idx % shape[i] 

17 idx = idx // shape[i] 

18 offset = torch.zeros(N, dtype=torch.int64, device="cpu") 

19 for i in range(len(shape)): 

20 if i != dim: 

21 offset += coord[i] * strides[i] 

22 return offset 

23 

24 

def scatter_add_no_atomic(inp, dim, index, src):
    """Return a copy of ``inp`` with ``src`` added in at ``index`` along ``dim``.

    Host-side fallback for ``scatter(..., reduce="add")``: the operands are
    copied to the CPU, combined there with ``Tensor.scatter_add_``, and the
    result is moved back to ``inp.device``. ``inp`` itself is not modified.

    Doing the accumulation with the native CPU op (rather than walking a
    flattened buffer with hand-computed offsets) keeps the result correct for
    non-contiguous ``inp``: flattening such a tensor with ``reshape(-1)``
    yields a copy, so element-wise writes through it would be discarded.

    Args:
        inp: Tensor providing the initial values of the result.
        dim: Axis along which to index. Negative values are wrapped with
            ``dim % inp.ndim``.
        index: Integer tensor of target positions along ``dim``; converted to
            ``torch.int64`` before use. Positions that occur more than once
            accumulate.
        src: Tensor of values to add. Only the region addressed by
            ``index`` is read.

    Returns:
        A new tensor on ``inp.device`` holding the accumulated result. When
        ``index`` is empty this is simply ``inp.clone()``.

    Raises:
        RuntimeError: Propagated from ``Tensor.scatter_add_`` for operands it
            rejects, e.g. an index value outside ``inp``'s extent along
            ``dim``.
    """
    dim = dim % inp.ndim

    if index.numel() == 0:
        return inp.clone()

    # copy=True guarantees a private buffer even when ``inp`` already lives on
    # the CPU, so the in-place update below can never leak into ``inp``.
    out_cpu = inp.to("cpu", copy=True)
    index_cpu = index.to(device="cpu", dtype=torch.int64)
    src_cpu = src.to("cpu")

    # Stride-aware and bounds-checked; duplicates in ``index`` are summed.
    out_cpu.scatter_add_(dim, index_cpu, src_cpu)

    return out_cpu.to(inp.device)

52 

53 

def scatter_reduce_multiply_no_atomic(inp, dim, index, src):
    """Return a copy of ``inp`` with ``src`` multiplied in at ``index`` along ``dim``.

    Host-side fallback for ``scatter(..., reduce="multiply")``: the operands
    are copied to the CPU, combined there with ``Tensor.scatter_reduce_``
    (``reduce="prod"``, existing values included), and the result is moved
    back to ``inp.device``. ``inp`` itself is not modified.

    Doing the reduction with the native CPU op (rather than walking a
    flattened buffer with hand-computed offsets) keeps the result correct for
    non-contiguous ``inp``: flattening such a tensor with ``reshape(-1)``
    yields a copy, so element-wise writes through it would be discarded.

    Args:
        inp: Tensor providing the initial values of the result.
        dim: Axis along which to index. Negative values are wrapped with
            ``dim % inp.ndim``.
        index: Integer tensor of target positions along ``dim``; converted to
            ``torch.int64`` before use. Positions that occur more than once
            are multiplied by every matching ``src`` value.
        src: Tensor of factors. Only the region addressed by ``index`` is
            read.

    Returns:
        A new tensor on ``inp.device`` holding the product result. When
        ``index`` is empty this is simply ``inp.clone()``.

    Raises:
        RuntimeError: Propagated from ``Tensor.scatter_reduce_`` for operands
            it rejects, e.g. an index value outside ``inp``'s extent along
            ``dim``.
    """
    dim = dim % inp.ndim

    if index.numel() == 0:
        return inp.clone()

    # copy=True guarantees a private buffer even when ``inp`` already lives on
    # the CPU, so the in-place update below can never leak into ``inp``.
    out_cpu = inp.to("cpu", copy=True)
    index_cpu = index.to(device="cpu", dtype=torch.int64)
    src_cpu = src.to("cpu")

    # include_self=True keeps the original value as one of the factors, which
    # is what an in-place ``*=`` per scattered element amounts to.
    out_cpu.scatter_reduce_(dim, index_cpu, src_cpu, reduce="prod", include_self=True)

    return out_cpu.to(inp.device)

81 

82 

def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place scatter entry point for the Ascend backend.

    ``reduce == "add"`` and ``reduce == "multiply"`` are served by the
    no-atomic fallbacks defined in this module. Every other ``reduce`` value
    (including ``None``) is forwarded unchanged to the generic
    ``flag_gems.ops.scatter.scatter`` implementation.

    Args:
        inp: Tensor providing the initial values.
        dim: Axis along which to index.
        index: Tensor of target positions along ``dim``.
        src: Values to scatter.
        reduce: Optional reduction name selecting the code path.

    Returns:
        Whatever the selected implementation returns.
    """
    logger.debug("GEMS_ASCEND SCATTER")

    # Pick the fallback first, then make a single call.
    if reduce == "add":
        fallback = scatter_add_no_atomic
    elif reduce == "multiply":
        fallback = scatter_reduce_multiply_no_atomic
    else:
        return _scatter(inp, dim, index, src, reduce)

    return fallback(inp, dim, index, src)

90 

91 

def scatter_(inp, dim, index, src, reduce=None):
    """In-place scatter entry point for the Ascend backend.

    For ``reduce == "add"`` and ``reduce == "multiply"`` the matching
    no-atomic fallback from this module computes the result out of place,
    which is then copied into ``inp`` with ``copy_``. Every other ``reduce``
    value (including ``None``) is forwarded unchanged to the generic
    ``flag_gems.ops.scatter.scatter_`` implementation.

    Args:
        inp: Tensor that is updated.
        dim: Axis along which to index.
        index: Tensor of target positions along ``dim``.
        src: Values to scatter.
        reduce: Optional reduction name selecting the code path.

    Returns:
        ``inp`` on the fallback paths; otherwise whatever the generic
        implementation returns.
    """
    logger.debug("GEMS_ASCEND SCATTER_")

    if reduce == "add":
        fallback = scatter_add_no_atomic
    elif reduce == "multiply":
        fallback = scatter_reduce_multiply_no_atomic
    else:
        return _scatter_(inp, dim, index, src, reduce)

    # The fallbacks never touch ``inp``; write their result back explicitly.
    inp.copy_(fallback(inp, dim, index, src))
    return inp

102 return _scatter_(inp, dim, index, src, reduce)