Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils.shape_utils import restride_dim
# Module logger, named after the last dotted component of this module's name.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rpartition(".")[2]}'
)
# 192 KiB expressed in bytes.
# NOTE(review): presumably the on-chip unified-buffer capacity — confirm.
UB_SIZE_BYTES = 192 * 1024
def compute_base_offset(shape, strides, dim):
    """Return, for every flat position in ``shape``, the offset ignoring ``dim``.

    Positions are enumerated in row-major order over ``shape``. For each one
    the result holds ``sum(coord[i] * strides[i] for i != dim)``, i.e. the
    element offset contributed by every dimension except ``dim``.

    Args:
        shape: Sequence of dimension sizes to enumerate.
        strides: Sequence of per-dimension strides, same length as ``shape``.
        dim: Dimension to leave out of the sum. A negative value counts from
            the last dimension.

    Returns:
        A 1-D ``torch.int64`` CPU tensor with one entry per position
        (a single zero for an empty ``shape``).
    """
    ndim = len(shape)
    # Without this, a negative ``dim`` never equals a loop index and the
    # excluded dimension's stride would wrongly be added in.
    if dim < 0:
        dim += ndim

    numel = 1
    for extent in shape:
        numel *= int(extent)

    remaining = torch.arange(numel, dtype=torch.long, device="cpu")
    offset = torch.zeros(numel, dtype=torch.long, device="cpu")

    # Peel coordinates off the flat position, last dimension first, and
    # accumulate their contribution directly instead of materialising a
    # (ndim, numel) coordinate matrix.
    for axis in reversed(range(ndim)):
        extent = shape[axis]
        if axis != dim:
            offset += (remaining % extent) * strides[axis]
        remaining = remaining // extent
    return offset
# Flat gather kernel. Each program instance handles BLOCK_SIZE consecutive
# linear positions k < N and performs
#     out[k] = inp[base_offset[k] + index[k] * inp_dim_stride]
# where index, base_offset and out are all addressed linearly by k.
@libentry()
@triton.heuristics({"BLOCK_SIZE": lambda args: 1024})
@triton.jit
def _gather_flat_kernel_fixed(
    inp,
    index,
    out,
    base_offset,
    inp_dim_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    # Linear positions owned by this program; lanes at or beyond N are masked.
    block_start = tl.program_id(0) * BLOCK_SIZE
    positions = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < N

    # Masked-off lanes read 0 for both the gather index and the base offset.
    gather_idx = tl.load(index + positions, mask=in_bounds, other=0)
    fixed_part = tl.load(base_offset + positions, mask=in_bounds, other=0)

    # Source element offset: precomputed base plus the step along the
    # gathered dimension.
    src_offset = fixed_part + gather_idx * inp_dim_stride

    gathered = tl.load(inp + src_offset, mask=in_bounds, other=0)
    tl.store(out + positions, gathered, mask=in_bounds)
def gather_flat_fixed(inp: torch.Tensor, dim: int, index: torch.Tensor, out=None):
    """Gather values of ``inp`` along ``dim`` at ``index`` using the flat kernel.

    Args:
        inp: Source tensor.
        dim: Dimension to gather along. A negative value counts from the
            last dimension.
        index: Tensor of positions along ``dim``; the result takes its shape.
        out: Optional destination tensor shaped like ``index``. A new tensor
            with ``inp``'s dtype and device is allocated when ``None``.

    Returns:
        ``out``, filled with the gathered values.
    """
    logger.debug("GEMS_ASCEND GATHER (fixed version)")

    if out is None:
        # torch.empty (not empty_like) so the buffer is contiguous even when
        # ``index`` is not; the kernel writes ``out`` linearly.
        out = torch.empty(index.shape, dtype=inp.dtype, device=inp.device)

    N = index.numel()
    if N == 0:
        # Nothing to gather; skip the offset table and the kernel launch.
        return out

    # Normalise every negative dim (not only -1): the base-offset computation
    # compares ``dim`` against non-negative loop indices.
    if dim < 0:
        dim += inp.dim()

    # The kernel reads ``index`` and writes the destination by linear
    # position, which is only valid for contiguous memory.
    index = index.contiguous()
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    dim_stride = inp.stride(dim)
    inp_strided = restride_dim(inp, dim, index.shape)
    base_offset = compute_base_offset(index.shape, inp_strided.stride(), dim)
    # Place the table on the input's device rather than the current default
    # device, so multi-device setups address the right memory.
    base_offset = base_offset.to(device=inp.device, dtype=torch.int64)

    grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
    _gather_flat_kernel_fixed[grid](
        inp_strided,
        index,
        target,
        base_offset,
        dim_stride,
        N,
    )

    if target is not out:
        # Scatter the contiguous result back into the caller's strided tensor.
        out.copy_(target)
    return out
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Validate the arguments and gather ``inp`` along ``dim`` at ``index``.

    Args:
        inp: Source tensor.
        dim: Dimension to gather along; reduced modulo ``inp.dim()`` before
            being forwarded.
        index: Index tensor; must have the same number of dimensions as
            ``inp``.
        out: Optional destination tensor. Allocated like ``index`` with
            ``inp``'s dtype and device when ``None``.
        sparse_grad: Accepted for signature compatibility; not used here.

    Returns:
        The tensor returned by ``gather_flat_fixed``.

    Raises:
        IndexError: If ``inp`` and ``index`` differ in number of dimensions.
    """
    logger.debug("GEMS_ASCEND GATHER")

    inp_rank, index_rank = inp.ndim, index.ndim
    if inp_rank != index_rank:
        raise IndexError(
            f"self and index must have the same number of dimensions, "
            f"got self.ndim = {inp.ndim} and index.ndim = {index.ndim}"
        )

    destination = out
    if destination is None:
        destination = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    return gather_flat_fixed(inp, dim % inp.dim(), index, destination)
def gather_backward(grad, self, dim, index, sparse_grad):
    """Compute the gradient of gather with respect to its input.

    Builds a zero tensor shaped like ``self`` (taking dtype and device from
    ``grad``) and passes it to ``scatter_`` together with ``grad`` and
    ``index`` along ``dim``, using ``reduce="add"``.

    Args:
        grad: Gradient flowing in from the gather output.
        self: The tensor that was gathered from; only its shape is read.
        dim: Dimension the gather ran along.
        index: Index tensor used by the forward gather.
        sparse_grad: Accepted for signature compatibility; not used here.

    Returns:
        Whatever ``scatter_`` returns for the zero-initialised buffer.
    """
    logger.debug("GEMS_ASCEND GATHER BACKWARD")
    # Imported lazily inside the function, as in the original.
    from .scatter import scatter_

    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")