Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_collapsed.py: 0%
92 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1from typing import List, Tuple
3import torch
4import triton
5import triton.language as tl
# Autotuning search space for ``gather_kernel_collapsed``: every combination
# of the four lists below is turned into one ``triton.Config``.
WARP_LIST = [1 << shift for shift in range(3, 7)]  # 8, 16, 32, 64
REORDER_LIST = [True, False]
# 120 * 1024 and 216 * 1024 -- presumably bytes of dynamic shared/local memory
# on the Ascend backend; TODO confirm the unit against the backend docs.
MEM_LIST = [120 << 10, 216 << 10]
BLOCK_SIZE_LIST = [1 << shift for shift in range(5, 12)]  # 32, 64, ..., 2048
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative axis index into the range ``[0, ndim)``.

    Args:
        dim: Axis index. Negative values count from the end (``-1`` is the
            last axis).
        ndim: Number of dimensions of the tensor being indexed.

    Returns:
        The equivalent non-negative axis index.

    Raises:
        ValueError: If ``dim`` lies outside ``[-ndim, ndim)``. The message
            reports the value the caller passed in, not the shifted one.
    """
    # Keep the caller's ``dim`` untouched so the error message is accurate.
    normalized = dim + ndim if dim < 0 else dim
    if not 0 <= normalized < ndim:
        raise ValueError(f"dim={dim} out of range for ndim={ndim}")
    return normalized
def apply_prefix_narrows(
    inp: torch.Tensor, narrows: List[Tuple[int, int]]
) -> torch.Tensor:
    """Shrink ``inp`` to a leading slice along each listed axis.

    Each ``(axis, new_size)`` pair keeps elements ``[0, new_size)`` of that
    axis via ``Tensor.narrow``. Pairs whose size already matches the current
    extent are skipped, so when nothing needs shrinking the very same tensor
    object is returned.
    """
    result = inp
    for axis, keep in narrows:
        if keep != result.shape[axis]:
            result = result.narrow(axis, 0, keep)
    return result
def can_collapse_axes(
    inp: torch.Tensor, index: torch.Tensor, dim: int
) -> Tuple[bool, List[Tuple[int, int]]]:
    """Decide whether a gather can run on the collapsed 3-D kernel.

    Gather along axis ``d`` is defined as::

        Y[t_0, ..., t_{N-1}] =
            inp[t_0, ..., t_{d-1}, index[t_0, ..., t_{N-1}], t_{d+1}, ..., t_{N-1}]

    so for every axis ``i != d`` the output only reads ``inp`` at
    coordinates ``0 <= t_i < index.shape[i]``.

    The collapsed kernel views each tensor as ``(Outer, Dim, Inner)`` with
    ``Outer = prod(shape[:d])`` and ``Inner = prod(shape[d+1:])``. The same
    ``(outer, inner)`` pair must address corresponding elements in ``inp``
    and in ``index``/``out``. Hence:

    * axes before ``d`` may be shorter in ``index`` than in ``inp``. Such
      axes are reported so ``inp`` can be prefix-narrowed to match;
    * axes after ``d`` must have identical sizes.

    Args:
        inp: Source tensor.
        index: Index tensor.
        dim: Gather axis; negative values count from the end.

    Returns:
        ``(ok, narrows)``. ``ok`` is False, with an empty ``narrows``, when:

        * the ranks differ;
        * an outer axis of ``index`` is larger than ``inp``'s; or
        * an inner axis differs in size.

        Otherwise ``narrows`` lists ``(axis, new_size)`` for every outer axis
        on which ``index`` is shorter than ``inp``.

    Raises:
        ValueError: From ``normalize_dim`` when ``dim`` is out of range. This
            is only checked once the ranks are known to match.
    """
    if inp.ndim != index.ndim:
        return False, []

    gather_axis = normalize_dim(dim, inp.ndim)
    shrink: List[Tuple[int, int]] = []

    for axis, (inp_len, idx_len) in enumerate(zip(inp.shape, index.shape)):
        if axis == gather_axis:
            continue
        inp_len, idx_len = int(inp_len), int(idx_len)
        if axis > gather_axis:
            # Inner side: the flat inner offset must mean the same thing in
            # both tensors, so sizes must agree exactly.
            if idx_len != inp_len:
                return False, []
        elif idx_len > inp_len:
            return False, []
        elif idx_len < inp_len:
            shrink.append((axis, idx_len))

    return True, shrink
# Triton kernel: gather along the middle axis of a tensor collapsed to
# (outer, dim, inner):
#     out[o, j, i] = inp[o, index[o, j, i], i]
# ``index`` and ``out`` are read/written at flat element offsets, so both are
# assumed to be laid out contiguously in (outer, dim, inner) order; only the
# ``inp`` strides are actually used. ``SIZE_OUTER`` and the ``stride_idx_*`` /
# ``stride_out_*`` arguments are accepted but not referenced in the body.
@triton.autotune(
    # One config per combination of WARP_LIST x MEM_LIST x BLOCK_SIZE_LIST x
    # REORDER_LIST.
    # NOTE(review): "shared_mem_dynamic_size" and
    # "enable_simt_reorder_instruction" look like Ascend-backend compile
    # options -- confirm their exact meaning against the triton-ascend docs.
    configs=[
        triton.Config(
            kwargs={
                "BLOCK_SIZE": size,
                "shared_mem_dynamic_size": localmem,
                "enable_simt_reorder_instruction": is_reorder,
            },
            num_warps=warp,
        )
        for warp in WARP_LIST
        for localmem in MEM_LIST
        for size in BLOCK_SIZE_LIST
        for is_reorder in REORDER_LIST
    ],
    # The tuned config is cached per distinct total element count.
    key=["num_elements"],
    warmup=25,
    rep=100,
)
@triton.jit
def gather_kernel_collapsed(
    inp_ptr,
    index_ptr,
    out_ptr,
    SIZE_OUTER,
    SIZE_DIM,
    SIZE_INNER,
    stride_inp_outer,
    stride_inp_dim,
    stride_inp_inner,
    stride_idx_outer,
    stride_idx_dim,
    stride_idx_inner,
    stride_out_outer,
    stride_out_dim,
    stride_out_inner,
    num_elements,
    with_negative_index: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Split the flat element range into one contiguous chunk per program:
    # program ``pid`` owns [prog_start, prog_end), processed BLOCK_SIZE at a time.
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    elements_per_prog = tl.cdiv(num_elements, num_programs)
    prog_start = pid * elements_per_prog
    prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)

    for block_start in range(prog_start, prog_end, BLOCK_SIZE):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Lanes past this program's chunk (or past num_elements) are masked off.
        mask = offsets < prog_end

        # Flat read of the index tensor; widened to int64 before address math.
        idx_val = tl.load(index_ptr + offsets, mask=mask, other=0).to(tl.int64)

        if with_negative_index:
            # Negative positions are shifted by SIZE_DIM. NOTE(review):
            # SIZE_DIM is also the index/out middle extent used below, so
            # this matches gather semantics only when that equals the input's
            # size along the gather axis. No bounds check is done on idx_val.
            idx_val = tl.where(idx_val < 0, idx_val + SIZE_DIM, idx_val)
        # Coordinate Reconstruction: (outer, dim, inner)
        # The middle coordinate of the output position is not needed: it is
        # replaced by the gathered idx_val when addressing ``inp``.
        off_inner = offsets % SIZE_INNER
        tmp = offsets // SIZE_INNER
        off_outer = tmp // SIZE_DIM

        # Input Offset Calculation
        inp_off = (
            off_outer * stride_inp_outer
            + idx_val * stride_inp_dim
            + off_inner * stride_inp_inner
        )
        val = tl.load(inp_ptr + inp_off, mask=mask, other=0.0)

        # Output Store
        # Flat store: ``out`` is written at the same offsets as ``index`` was read.
        tl.store(out_ptr + offsets, val, mask=mask)
def _collapsed_3d_views(
    inp: torch.Tensor, dim: int, index: torch.Tensor, out: torch.Tensor
):
    """View ``inp``, ``index`` and ``out`` as 3-D ``(outer, dim, inner)`` tensors.

    ``outer`` is the product of the sizes before ``dim`` and ``inner`` the
    product of the sizes after it. Each tensor is collapsed using its own
    shape. ``inp`` and ``index`` go through ``.contiguous()`` first; ``out``
    is viewed directly, so ``out.view`` raises if its strides cannot express
    the 3-D shape.

    Returns:
        ``(inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER)``, where
        the three sizes describe the collapsed *index* shape.
    """
    axis = normalize_dim(dim, inp.ndim)
    idx_shape = index.shape
    inp_shape = inp.shape

    # torch.Size.numel() of an empty slice is 1, i.e. the empty product.
    idx_outer = idx_shape[:axis].numel()
    idx_inner = idx_shape[axis + 1 :].numel()
    inp_outer = inp_shape[:axis].numel()
    inp_inner = inp_shape[axis + 1 :].numel()
    idx_mid = idx_shape[axis]

    inp_3d = inp.contiguous().view(inp_outer, inp_shape[axis], inp_inner)
    idx_3d = index.contiguous().view(idx_outer, idx_mid, idx_inner)
    out_3d = out.view(idx_outer, idx_mid, idx_inner)

    return inp_3d, idx_3d, out_3d, idx_outer, idx_3d.shape[1], idx_inner
def gather_collapsed(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    return_run_kernel: bool = True,
    with_negative_index=False,
):
    """Gather ``inp`` along ``dim`` with ``index`` into ``out`` via the collapsed kernel.

    All three tensors are viewed as ``(outer, dim, inner)`` and
    ``gather_kernel_collapsed`` computes
    ``out[o, j, i] = inp[o, index[o, j, i], i]``.
    NOTE(review): ``inp`` is presumably expected to already satisfy
    ``can_collapse_axes`` (narrowed with ``apply_prefix_narrows`` if needed);
    confirm at call sites.

    Args:
        inp: Source tensor.
        dim: Gather axis; negative values count from the end.
        index: Integer positions along ``dim``; must have ``out``'s shape.
        out: Destination, written in place. If it is not contiguous, the
            kernel writes into a contiguous scratch buffer, which is then
            copied into ``out``.
        grid_fn: Launch grid forwarded to ``gather_kernel_collapsed[...]``.
        return_run_kernel: If True, return a zero-argument callable that
            performs the launch (and any copy-back). If False, launch now.
        with_negative_index: If True, negative ``index`` values count from the
            end of ``inp``'s ``dim`` axis.

    Returns:
        The launch callable when ``return_run_kernel`` is True, else None.

    Raises:
        ValueError: If ``out.shape != index.shape`` or ``dim`` is out of range.
    """
    if out.shape != index.shape:
        raise ValueError(f"out.shape {out.shape} must equal index.shape {index.shape}")

    dim = normalize_dim(dim, inp.ndim)

    # The kernel wraps negative positions by adding SIZE_DIM, which is
    # index.shape[dim]. Gather semantics need inp.shape[dim], so when the two
    # differ, wrap on the host and switch the in-kernel wrap off.
    kernel_wraps_negative = with_negative_index
    if with_negative_index and index.shape[dim] != inp.shape[dim]:
        index = torch.where(index < 0, index + inp.shape[dim], index)
        kernel_wraps_negative = False

    # The kernel stores to ``out_ptr + flat_offset`` and ignores output strides,
    # so it must write into a contiguous buffer.
    if out.is_contiguous():
        out_buf = out
    else:
        out_buf = torch.empty_like(out, memory_format=torch.contiguous_format)

    inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER = _collapsed_3d_views(
        inp, dim, index, out_buf
    )
    num_elements = out_3d.numel()

    def _run_kernel():
        gather_kernel_collapsed[grid_fn](
            inp_3d,
            idx_3d,
            out_3d,
            # Shapes
            SIZE_OUTER,
            SIZE_DIM,
            SIZE_INNER,
            # Strides
            inp_3d.stride(0),
            inp_3d.stride(1),
            inp_3d.stride(2),
            idx_3d.stride(0),
            idx_3d.stride(1),
            idx_3d.stride(2),
            out_3d.stride(0),
            out_3d.stride(1),
            out_3d.stride(2),
            # Meta
            num_elements,
            kernel_wraps_negative,
            force_simt_only=False,
        )
        if out_buf is not out:
            # Scatter the contiguous result back into the caller's layout.
            out.copy_(out_buf)

    if return_run_kernel:
        return _run_kernel

    _run_kernel()