Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_collapsed_uintdiv.py: 0%
255 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import importlib.util
2import os
3import sys
4from typing import List, Tuple
6import torch
8from flag_gems.utils.code_utils import IndentedBuffer
# Module-level tuning candidates: warp counts, memory budgets (in bytes) and
# block sizes.
WARP_LIST = [8, 16, 32, 64]
MEM_LIST = [kib * 1024 for kib in (120, 216)]
BLOCK_SIZE_LIST = [1 << exponent for exponent in range(5, 12)]


# Generated kernel modules are written below the current working directory.
# The directory is created at import time and appended to the module search
# path.
CACHE_DIR = os.path.join(os.getcwd(), "__triton_cache__")
if not os.path.exists(CACHE_DIR):
    os.makedirs(CACHE_DIR, exist_ok=True)
sys.path.append(CACHE_DIR)
def normalize_dim(dim: int, ndim: int) -> int:
    """Convert a possibly negative axis index into its non-negative form.

    Args:
        dim: Axis index; negative values count back from the last axis.
        ndim: Number of axes the index refers to.

    Returns:
        The equivalent axis index in ``[0, ndim)``.

    Raises:
        ValueError: If ``dim`` lies outside ``[-ndim, ndim)``.
    """
    wrapped = dim + ndim if dim < 0 else dim
    if not 0 <= wrapped < ndim:
        # Report the value the caller passed, not the wrapped intermediate,
        # so the message matches what appears at the call site.
        raise ValueError(f"dim={dim} out of range for ndim={ndim}")
    return wrapped
def apply_prefix_narrows(
    inp: torch.Tensor, narrows: List[Tuple[int, int]]
) -> torch.Tensor:
    """Trim ``inp`` to a leading slice along each requested axis.

    Args:
        inp: Tensor to trim.
        narrows: ``(axis, new_size)`` pairs; each keeps elements
            ``[0, new_size)`` of ``axis``.

    Returns:
        The narrowed tensor. Axes whose size already equals ``new_size`` are
        left alone, so ``inp`` itself is returned when nothing changes.
    """
    result = inp
    for axis, length in narrows:
        if result.shape[axis] != length:
            result = result.narrow(axis, 0, length)
    return result
def can_collapse_axes(
    inp: torch.Tensor, index: torch.Tensor, dim: int
) -> Tuple[bool, List[Tuple[int, int]]]:
    """Decide whether the collapsed (3D) gather kernel is applicable.

    Gather along ``dim = d`` is defined as
        Y[t0..tN-1] = inp[t0..t_{d-1}, index[t0..tN-1], t_{d+1}..t_{N-1}]
    with ``index.shape[i] <= inp.shape[i]`` for every ``i != d``; the output
    only touches ``inp`` at coordinates ``0 <= t_i < index.shape[i]``.

    The collapsed kernel folds a tensor into ``(Outer, Dim, Inner)`` where
    ``Outer`` is the product of the sizes before ``d`` and ``Inner`` the
    product of the sizes after ``d``. The same ``(off_outer, off_inner)``
    pair must address matching positions in ``inp`` and in ``index``/``out``.

    Policy:
        * Axes before ``dim``: ``index`` may be smaller than ``inp``; the
          axis is then recorded so ``inp`` can be prefix-narrowed to match.
        * Axes after ``dim``: sizes must be equal, otherwise the inner linear
          mapping would not be preserved.

    Returns:
        ``(ok, narrows)``. ``narrows`` lists ``(axis, new_size)`` pairs for
        the outer axes that need narrowing; it is empty whenever ``ok`` is
        False (including when the ranks of ``inp`` and ``index`` differ).
    """
    if inp.ndim != index.ndim:
        return False, []

    dim = normalize_dim(dim, inp.ndim)
    pending: List[Tuple[int, int]] = []

    for axis in range(inp.ndim):
        if axis == dim:
            continue

        inp_len = int(inp.shape[axis])
        idx_len = int(index.shape[axis])
        if idx_len == inp_len:
            continue

        # Sizes differ: only an outer axis with a smaller index is tolerated.
        if axis > dim or idx_len > inp_len:
            return False, []
        pending.append((axis, idx_len))

    return True, pending
94LIBDIVIDE_ADD_MARKER = 0x40
97def _clz32(x: int) -> int:
98 return 32 - x.bit_length() if x else 32
101def calc_magic_u32_libdivide(d: int):
102 """return (magic:uint32, more:uint8)"""
103 assert 1 <= d <= 0xFFFFFFFF
104 if (d & (d - 1)) == 0:
105 shift = d.bit_length() - 1
106 return 0, shift & 0xFF
107 floor_log_2_d = 31 - _clz32(d)
108 two_to = 1 << (32 + floor_log_2_d)
109 proposed_m = two_to // d
110 rem = two_to - proposed_m * d
111 e = d - rem
112 two_power = 1 << floor_log_2_d
113 if e < two_power:
114 magic = (proposed_m + 1) & 0xFFFFFFFF
115 more = floor_log_2_d & 0xFF
116 return magic, more
117 else:
118 proposed_m2 = proposed_m * 2
119 twice_rem = rem * 2
120 if twice_rem >= d or twice_rem < rem:
121 proposed_m2 += 1
122 magic = (proposed_m2 + 1) & 0xFFFFFFFF
123 more = (floor_log_2_d | LIBDIVIDE_ADD_MARKER) & 0xFF
124 return magic, more
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of a generated module into ``code``.

    Emits the ``torch``, ``triton`` and ``triton.language`` imports followed
    by a blank line, then returns the same buffer.
    """
    for statement in (
        "import torch",
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(statement)
    code.newline()
    return code
def generate_collapsed_device_functions(code: IndentedBuffer) -> IndentedBuffer:
    """Write the ``@triton.jit`` fast-division helpers into ``code``.

    Three helpers are emitted, one per division strategy:

    * ``fast_divide_shift``: plain right shift.
    * ``fast_divide_mul_noadd``: high half of ``n * magic``, then shift.
    * ``fast_divide_mul_add``: high half of ``n * magic`` combined with the
      ``((n - q0) >> 1) + q0`` correction, then shift.

    Returns the same buffer.
    """
    code.writeline("# Device Functions for Fast Division (collapsed path)")
    code.newline()

    # (signature, body statements) for each emitted helper.
    helpers = (
        (
            "def fast_divide_shift(n, shift):",
            ("return n >> shift",),
        ),
        (
            "def fast_divide_mul_noadd(n, magic, shift):",
            ("return (tl.umulhi(n, magic) >> shift)",),
        ),
        (
            "def fast_divide_mul_add(n, magic, shift):",
            (
                "q0 = tl.umulhi(n, magic)",
                "t = ((n - q0) >> 1) + q0",
                "return (t >> shift)",
            ),
        ),
    )

    for signature, statements in helpers:
        code.writeline("@triton.jit")
        code.writeline(signature)
        with code.indent():
            for statement in statements:
                code.writeline(statement)
        code.newline()

    return code
def _collapsed_3d_views(
    inp: torch.Tensor, dim: int, index: torch.Tensor, out: torch.Tensor
):
    """Fold ``inp``, ``index`` and ``out`` into ``(Outer, Dim, Inner)`` views.

    ``Outer`` is the product of the sizes before ``dim`` and ``Inner`` the
    product of the sizes after it, computed separately for ``inp`` and for
    ``index``. ``inp`` and ``index`` are made contiguous before being viewed;
    ``out`` is viewed directly with the index-derived sizes.

    Returns:
        ``(inp_3d, idx_3d, out_3d, SIZE_OUTER, SIZE_DIM, SIZE_INNER)`` where
        the three sizes are those of the folded ``index``.
    """
    dim = normalize_dim(dim, inp.ndim)

    def _extent(shape, start, stop):
        # Product of shape[start:stop]; 1 for an empty range.
        total = 1
        for axis in range(start, stop):
            total *= shape[axis]
        return total

    outer_idx = _extent(index.shape, 0, dim)
    inner_idx = _extent(index.shape, dim + 1, index.ndim)
    outer_inp = _extent(inp.shape, 0, dim)
    inner_inp = _extent(inp.shape, dim + 1, inp.ndim)

    inp_3d = inp.contiguous().view(outer_inp, inp.shape[dim], inner_inp)
    idx_3d = index.contiguous().view(outer_idx, index.shape[dim], inner_idx)
    out_3d = out.view(outer_idx, index.shape[dim], inner_idx)

    return inp_3d, idx_3d, out_3d, outer_idx, idx_3d.shape[1], inner_idx
def generate_collapsed_kernel(
    kernel_name: str, div_kinds: List[str], code: IndentedBuffer
) -> IndentedBuffer:
    """Write the autotuned Triton gather kernel for the folded 3D layout.

    Args:
        kernel_name: Name given to the generated ``@triton.jit`` function.
        div_kinds: Two strategy codes, first for the division by
            ``SIZE_INNER`` and second for the division by ``SIZE_DIM``.
            ``"S"`` selects ``fast_divide_shift``, ``"A"`` selects
            ``fast_divide_mul_add`` and any other value selects
            ``fast_divide_mul_noadd``.
        code: Buffer that receives the generated source.

    Returns:
        The same buffer.

    The emitted kernel gives each program a contiguous slice of the linear
    output range and walks it in ``BLOCK_SIZE`` steps. For every offset it
    loads the index value from ``index_ptr + offsets``, optionally wraps
    negative indices by ``SIZE_DIM``, splits the offset into outer/dim/inner
    parts with the fast-division helpers, loads from ``inp_ptr`` using the
    three ``stride_inp_*`` arguments and stores to ``out_ptr + offsets``.
    The ``stride_idx_*`` and ``stride_out_*`` arguments are part of the
    signature but are not referenced by the emitted body.
    """
    code.newline()

    # Autotune search space, emitted as lists in the generated module.
    for line in (
        "# Autotune Configuration Lists",
        "WARP_LIST = [8, 16, 32]",
        "MEM_LIST = [120 * 1024, 216 * 1024]",
        "BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024, 2048]",
        "REORDER_LIST = [True, False]",
    ):
        code.writeline(line)
    code.newline()

    code.writeline("@triton.autotune(configs=[")
    with code.indent():
        config_expr = (
            "triton.Config("
            "kwargs={'BLOCK_SIZE': size, 'shared_mem_dynamic_size': localmem, "
            "'enable_simt_reorder_instruction': is_reorder}, num_warps=warp)"
        )
        for line in (
            config_expr,
            "for warp in WARP_LIST",
            "for localmem in MEM_LIST",
            "for size in BLOCK_SIZE_LIST",
            "for is_reorder in REORDER_LIST",
        ):
            code.writeline(line)
    for line in (
        "],",
        "key=['num_elements'], ",
        "warmup=25, ",
        "rep=100)",
    ):
        code.writeline(line)

    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        pointer_args = ["inp_ptr, ", "index_ptr, ", "out_ptr, "]
        size_args = ["SIZE_OUTER, ", "SIZE_DIM, ", "SIZE_INNER, "]
        stride_args = [
            "stride_inp_outer, ",
            "stride_inp_dim, ",
            "stride_inp_inner, ",
            "stride_idx_outer, ",
            "stride_idx_dim, ",
            "stride_idx_inner, ",
            "stride_out_outer, ",
            "stride_out_dim, ",
            "stride_out_inner, ",
        ]
        division_args = [
            "inner_magic: tl.uint32, ",
            "inner_shift: tl.uint32, ",
            "dim_magic: tl.uint32, ",
            "dim_shift: tl.uint32, ",
        ]
        trailing_args = [
            "num_elements, ",
            "with_negative_index: tl.constexpr, ",
            "BLOCK_SIZE: tl.constexpr, ",
        ]
        code.writelines(
            pointer_args + size_args + stride_args + division_args + trailing_args
        )
    code.writeline("):")

    with code.indent():
        # Each program handles one contiguous slice of the output range.
        for line in (
            "pid = tl.program_id(0)",
            "num_programs = tl.num_programs(0)",
            "elements_per_prog = tl.cdiv(num_elements, num_programs)",
            "prog_start = pid * elements_per_prog",
            "prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)",
        ):
            code.writeline(line)
        code.newline()

        code.writeline("for block_start in range(prog_start, prog_end, BLOCK_SIZE):")
        with code.indent():
            code.writeline("offsets = block_start + tl.arange(0, BLOCK_SIZE)")
            code.writeline("mask = offsets < prog_end")
            code.newline()
            code.writeline("idx_val = tl.load(index_ptr + offsets, mask=mask, other=0)")
            code.newline()

            code.writeline("if with_negative_index:")
            with code.indent():
                code.writeline(
                    "idx_val = tl.where(idx_val < 0, idx_val + SIZE_DIM, idx_val)"
                )
            code.newline()

            # First split: q1 = offsets // SIZE_INNER, remainder is off_inner.
            inner_division = {
                "S": "q1 = fast_divide_shift(offsets, inner_shift)",
                "A": "q1 = fast_divide_mul_add(offsets, inner_magic, inner_shift)",
            }.get(
                div_kinds[0],
                "q1 = fast_divide_mul_noadd(offsets, inner_magic, inner_shift)",
            )
            code.writeline(inner_division)
            code.writeline("off_inner = offsets - q1 * SIZE_INNER")
            code.writeline("tmp = q1")
            code.newline()

            # Second split: q2 = tmp // SIZE_DIM, which is off_outer.
            dim_division = {
                "S": "q2 = fast_divide_shift(tmp, dim_shift)",
                "A": "q2 = fast_divide_mul_add(tmp, dim_magic, dim_shift)",
            }.get(
                div_kinds[1],
                "q2 = fast_divide_mul_noadd(tmp, dim_magic, dim_shift)",
            )
            code.writeline(dim_division)
            code.writeline("off_dim = tmp - q2 * SIZE_DIM")
            code.writeline("off_outer = q2")
            code.newline()

            code.writeline("inp_off = (")
            with code.indent():
                for term in (
                    "off_outer * stride_inp_outer",
                    "+ idx_val * stride_inp_dim",
                    "+ off_inner * stride_inp_inner",
                ):
                    code.writeline(term)
            code.writeline(")")
            code.writeline("val = tl.load(inp_ptr + inp_off, mask=mask, other=0.0)")
            code.newline()

            code.writeline("tl.store(out_ptr + offsets, val, mask=mask)")

    code.newline()
    return code
def generate_collapsed_wrapper(
    wrapper_name: str, kernel_name: str, code: IndentedBuffer
) -> IndentedBuffer:
    """Write the host-side launcher for the collapsed gather kernel.

    The generated function takes ``inp``, ``index`` and ``out`` plus the grid
    and the division constants. It reads shapes and strides, launches
    ``kernel_name[grid]`` with ``index``'s three sizes as
    ``SIZE_OUTER``/``SIZE_DIM``/``SIZE_INNER`` and ``out.numel()`` as
    ``num_elements``, then returns ``out``.

    Args:
        wrapper_name: Name of the generated launcher function.
        kernel_name: Name of the kernel the launcher invokes.
        code: Buffer that receives the generated source.

    Returns:
        The same buffer.
    """
    signature = (
        f"def {wrapper_name}("
        "inp, index, out, grid, inner_magic, inner_shift, dim_magic, dim_shift, with_negative_index):"
    )
    code.writeline(signature)

    with code.indent():
        for line in (
            "inp_shape = inp.shape",
            "inp_stride = inp.stride()",
            "index_shape = index.shape",
            "index_stride = index.stride()",
            "out_stride = out.stride()",
            "num_elements = out.numel()",
        ):
            code.writeline(line)
        code.newline()

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            tensor_args = ["inp, ", "index, ", "out, "]
            # SIZE_OUTER, SIZE_DIM, SIZE_INNER come from the folded index.
            size_args = ["index_shape[0], ", "index_shape[1], ", "index_shape[2], "]
            stride_args = [
                "inp_stride[0], ",
                "inp_stride[1], ",
                "inp_stride[2], ",
                "index_stride[0], ",
                "index_stride[1], ",
                "index_stride[2], ",
                "out_stride[0], ",
                "out_stride[1], ",
                "out_stride[2], ",
            ]
            division_args = [
                "inner_magic, ",
                "inner_shift, ",
                "dim_magic, ",
                "dim_shift, ",
            ]
            trailing_args = [
                "num_elements, ",
                "with_negative_index, ",
                "force_simt_only=False, ",
            ]
            code.writelines(
                tensor_args + size_args + stride_args + division_args + trailing_args
            )
        code.writeline(")")
        code.writeline("return out")

    code.newline()
    return code
def generate_collapsed_code(
    wrapper_name: str, kernel_name: str, div_kinds: List[str]
) -> str:
    """Assemble the full source text of one generated gather module.

    The module consists, in order, of the import header, the fast-division
    device helpers, the kernel specialised for ``div_kinds`` and the
    host-side wrapper.

    Returns:
        The generated source as a single string.
    """
    buffer = generate_collapsed_device_functions(generate_imports(IndentedBuffer()))
    buffer = generate_collapsed_kernel(kernel_name, div_kinds, buffer)
    return generate_collapsed_wrapper(wrapper_name, kernel_name, buffer).getvalue()
class CollapsedGatherFunction:
    """Callable that runs the collapsed gather through a generated kernel.

    One kernel/wrapper pair is generated per combination of division
    strategies (inner divisor, dim divisor). The generated source is written
    to ``CACHE_DIR``, loaded as a module, and the wrapper is kept in
    ``self.overloads`` so later calls with the same combination reuse it.
    """

    def __init__(self):
        # Maps "collapsed_pat_<inner kind><dim kind>" to the loaded wrapper.
        self.overloads = {}

    @staticmethod
    def _division_kind(magic, more) -> str:
        """Classify a ``(magic, more)`` pair as shift, add or multiply."""
        if int(magic) == 0:
            return "S"
        return "A" if int(more) & LIBDIVIDE_ADD_MARKER else "M"

    def _build_overload(self, key: str, pattern: str, kinds):
        """Generate, write and load the wrapper for ``pattern``."""
        kernel_name = f"_gather_collapsed_kernel_{pattern}"
        wrapper_name = f"_gather_collapsed_wrapper_{pattern}"
        src_code = generate_collapsed_code(wrapper_name, kernel_name, kinds)

        file_path = os.path.join(CACHE_DIR, f"{key}.py")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(src_code)

        spec = importlib.util.spec_from_file_location(
            f"dynamic_collapsed_mod_{key}", file_path
        )
        # Validate before use: module_from_spec(None) would otherwise fail
        # with an unrelated AttributeError.
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load generated module from {file_path}")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return getattr(mod, wrapper_name)

    def __call__(
        self, inp, index, out, grid, magic_shift_map=None, with_negative_index=False
    ):
        """Gather from ``inp`` into ``out`` using the folded 3D tensors.

        Args:
            inp: 3D source tensor.
            index: 3D index tensor.
            out: 3D destination tensor.
            grid: Launch grid forwarded to the generated wrapper.
            magic_shift_map: Optional precomputed
                ``((inner_magic, inner_more), (dim_magic, dim_more))``. When
                omitted, the constants are derived from ``index.shape[2]``
                and ``index.shape[1]``.
            with_negative_index: Forwarded to the kernel; enables wrapping of
                negative index values.

        Returns:
            ``out``.
        """
        assert inp.ndim == 3
        assert index.ndim == 3
        assert out.ndim == 3

        # Nothing to write. Also avoids asking for division constants of a
        # zero-sized axis, which is not a valid divisor.
        if out.numel() == 0:
            return out

        if magic_shift_map is None:
            # two divisors only: SIZE_INNER, SIZE_DIM
            inner_magic, inner_more = calc_magic_u32_libdivide(int(index.shape[2]))
            dim_magic, dim_more = calc_magic_u32_libdivide(int(index.shape[1]))
        else:
            (inner_magic, inner_more), (dim_magic, dim_more) = magic_shift_map

        # The low five bits of ``more`` hold the shift amount.
        inner_shift = int(inner_more) & 0x1F
        dim_shift = int(dim_more) & 0x1F

        inner_kind = self._division_kind(inner_magic, inner_more)
        dim_kind = self._division_kind(dim_magic, dim_more)

        pattern = inner_kind + dim_kind
        key = f"collapsed_pat_{pattern}"

        if key not in self.overloads:
            self.overloads[key] = self._build_overload(
                key, pattern, [inner_kind, dim_kind]
            )

        return self.overloads[key](
            inp,
            index,
            out,
            grid,
            inner_magic,
            inner_shift,
            dim_magic,
            dim_shift,
            with_negative_index,
        )


# Shared instance, so generated wrappers are cached across calls.
collapsed_gather = CollapsedGatherFunction()
def gather_collapsed(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    return_run_kernel: bool = True,
    with_negative_index: bool = False,
):
    """Prepare, and optionally run, the collapsed gather of ``inp`` into ``out``.

    Args:
        inp: Source tensor.
        dim: Axis to gather along; may be negative.
        index: Index tensor.
        out: Destination tensor; must have the same shape as ``index``.
        grid_fn: Launch grid passed through to the kernel wrapper.
        return_run_kernel: When True (the default) the kernel is not launched;
            a zero-argument callable that launches it is returned instead.
            When False the kernel runs immediately and None is returned.
        with_negative_index: Forwarded to the kernel.

    Raises:
        ValueError: If ``out.shape`` differs from ``index.shape``.
    """
    if out.shape != index.shape:
        raise ValueError(f"out.shape {out.shape} must equal index.shape {index.shape}")

    folded = _collapsed_3d_views(inp, normalize_dim(dim, inp.ndim), index, out)
    # Only the three tensor views are needed; the sizes are re-read from them.
    inp_3d, idx_3d, out_3d = folded[:3]

    def _run_kernel():
        collapsed_gather(
            inp_3d, idx_3d, out_3d, grid_fn, with_negative_index=with_negative_index
        )

    if not return_run_kernel:
        _run_kernel()
        return None
    return _run_kernel