Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils.shape_utils import restride_dim 

9 

# Per-op logger: a fixed backend prefix followed by the last dotted component
# of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

# 192 KiB. Presumably the size of the Ascend on-chip unified buffer -- TODO
# confirm; nothing in the visible part of this file reads the constant.
UB_SIZE_BYTES = 192 * 1024

12 

13 

def compute_base_offset(shape, strides, dim):
    """Return the per-element linear offset contributed by every axis but ``dim``.

    The elements of a grid of extent ``shape`` are enumerated in row-major
    order (last axis fastest).  For the k-th element with coordinates
    ``c[0], ..., c[n-1]`` the k-th entry of the result is
    ``sum(c[i] * strides[i] for i != dim)``.

    Args:
        shape: Sequence of axis extents.
        strides: Sequence of per-axis strides, same length as ``shape``.
        dim: Axis to leave out of the sum.  Negative values count from the
            last axis.  A value that still matches no axis after that
            adjustment excludes nothing.

    Returns:
        A 1-D ``torch.int64`` CPU tensor with ``prod(shape)`` entries (a single
        zero for an empty ``shape``).
    """
    ndim = len(shape)
    # Without this a negative dim never equals any axis number below, so the
    # axis that should be skipped would silently be included.
    if dim < 0:
        dim += ndim

    numel = 1
    for extent in shape:
        numel *= int(extent)

    # `remaining` holds the flat position with the already-decoded trailing
    # axes divided out; peeling axes from last to first yields each coordinate
    # without materialising an (ndim, numel) coordinate matrix.
    remaining = torch.arange(numel, dtype=torch.long, device="cpu")
    offset = torch.zeros(numel, dtype=torch.long, device="cpu")
    for axis in reversed(range(ndim)):
        extent = int(shape[axis])
        if axis != dim:
            offset += (remaining % extent) * strides[axis]
        remaining = remaining // extent
    return offset

26 

27 

# Flat gather kernel.  For every flat position p < N it reads index[p] and
# base_offset[p], fetches inp[base_offset[p] + index[p] * inp_dim_stride] and
# stores the value at out[p].  `index`, `base_offset` and `out` are addressed
# purely by the flat position, i.e. as densely packed 1-D buffers.
@libentry()
@triton.heuristics({"BLOCK_SIZE": lambda args: 1024})
@triton.jit
def _gather_flat_kernel_fixed(
    inp,
    index,
    out,
    base_offset,
    inp_dim_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance covers one BLOCK_SIZE-wide slice of the N positions.
    lanes = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = lanes < N  # masks off the tail of the final, partial block

    gather_idx = tl.load(index + lanes, mask=in_range, other=0)
    row_base = tl.load(base_offset + lanes, mask=in_range, other=0)

    # Source element: precomputed base plus the gathered index scaled by the
    # stride passed in as inp_dim_stride.
    src_offset = row_base + gather_idx * inp_dim_stride

    picked = tl.load(inp + src_offset, mask=in_range, other=0)
    tl.store(out + lanes, picked, mask=in_range)

51 

52 

def gather_flat_fixed(inp: torch.Tensor, dim: int, index: torch.Tensor, out=None):
    """Gather values from ``inp`` along ``dim`` using a flat Triton kernel.

    A table of per-element base offsets (the contribution of every axis other
    than ``dim``) is built on the host with ``compute_base_offset`` and copied
    to the input's device; the kernel then adds ``index * inp.stride(dim)`` to
    it to locate each source element.

    Args:
        inp: Source tensor.
        dim: Axis to gather along; negative values count from the last axis.
        index: Indices to gather, with the same number of dimensions as
            ``inp``.  Any memory layout is accepted.
        out: Optional destination with the shape of ``index``.  A fresh
            tensor with ``inp``'s dtype and device is allocated when omitted.

    Returns:
        ``out`` (or the newly allocated tensor), filled with gathered values.
    """
    logger.debug("GEMS_ASCEND GATHER (fixed version)")

    if out is None:
        # Explicit shape -> guaranteed contiguous, which the kernel relies on.
        out = torch.empty(index.shape, dtype=inp.dtype, device=inp.device)

    N = index.numel()
    if N == 0:
        # Nothing to gather; avoid launching a zero-sized grid.
        return out

    if dim < 0:
        dim += inp.dim()

    # The kernel walks `index` and the destination as flat, densely packed
    # buffers, so expanded / transposed / sliced layouts must be packed first.
    index = index.contiguous()
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(index.shape, dtype=out.dtype, device=out.device)

    dim_stride = inp.stride(dim)
    # restride_dim is defined in flag_gems.utils.shape_utils; presumably it
    # returns a view of `inp` shaped like `index` with the stride of `dim`
    # zeroed -- TODO confirm against that helper.
    inp_strided = restride_dim(inp, dim, index.shape)
    base_offset = compute_base_offset(index.shape, inp_strided.stride(), dim)
    # Follow the input's device instead of whichever NPU is current.
    base_offset = base_offset.to(device=inp.device, dtype=torch.int64)

    grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
    _gather_flat_kernel_fixed[grid](
        inp_strided,
        index,
        dst,
        base_offset,
        dim_stride,
        N,
    )

    if dst is not out:
        # Scatter the packed result back into the caller's strided buffer.
        out.copy_(dst)
    return out

78 

79 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values of ``inp`` along ``dim`` at the positions in ``index``.

    Args:
        inp: Source tensor.
        dim: Axis to gather along, in ``[-inp.ndim, inp.ndim - 1]``
            (``[-1, 0]`` for a 0-d tensor).
        index: Index tensor with the same number of dimensions as ``inp``.
        out: Optional destination tensor shaped like ``index``.
        sparse_grad: Accepted for signature compatibility; not used here.

    Returns:
        ``out`` (allocated with ``inp``'s dtype and device when not supplied),
        filled by ``gather_flat_fixed``.

    Raises:
        IndexError: If ``inp`` and ``index`` differ in number of dimensions,
            or if ``dim`` is outside the valid range.
    """
    logger.debug("GEMS_ASCEND GATHER")
    if inp.ndim != index.ndim:
        raise IndexError(
            f"self and index must have the same number of dimensions, "
            f"got self.ndim = {inp.ndim} and index.ndim = {index.ndim}"
        )

    # A 0-d tensor is treated as having one axis for the purpose of `dim`.
    # Plain `dim % inp.dim()` would divide by zero there and would silently
    # wrap an out-of-range dim for every other rank.
    ndim = max(inp.ndim, 1)
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim = dim % ndim

    if out is None:
        # Allocate by shape rather than empty_like(index): the latter can copy
        # a non-contiguous layout from `index`, which the flat kernel cannot
        # write into correctly.
        out = torch.empty(index.shape, dtype=inp.dtype, device=inp.device)

    if inp.ndim == 0:
        # Run the scalar case through one-element views; out.reshape(1)
        # aliases `out`, so the result lands in the tensor we return.
        gather_flat_fixed(inp.reshape(1), 0, index.reshape(1), out.reshape(1))
        return out

    return gather_flat_fixed(inp, dim, index, out)

92 

93 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of gather: scatter-add ``grad`` into zeros shaped like ``self``.

    Args:
        grad: Incoming gradient; also supplies dtype/device of the zero buffer
            via ``new_zeros``.
        self: Forward input; only its ``shape`` is read.
        dim: Axis passed through to ``scatter_``.
        index: Index tensor passed through to ``scatter_``.
        sparse_grad: Not used by this implementation.

    Returns:
        Whatever ``scatter_`` returns for the zero-initialised buffer
        (presumably that buffer itself -- verify against ``.scatter``).
    """
    logger.debug("GEMS_ASCEND GATHER BACKWARD")
    # Imported at call time rather than at module import, as in the original.
    from .scatter import scatter_

    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")

99 return scatter_(result, dim, index, grad, reduce="add")