Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_ascend.py: 0%
234 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import functools
2import importlib.util
3import logging
4import os
5import sys
6from typing import List, Tuple
8import torch
10from flag_gems.utils.code_utils import IndentedBuffer
logger = logging.getLogger(__name__)

# Layout of the libdivide `more` byte produced by calc_magic_u32_libdivide:
# the low 5 bits hold the shift amount ...
LIBDIVIDE_32_SHIFT_MASK = 0x1F
# ... and bit 6 flags the "add" variant of the multiply-shift division.
LIBDIVIDE_ADD_MARKER = 0x40

# Directory (under the current working directory at import time) where the
# generated gather modules are written before being imported.
CACHE_DIR = os.path.join(os.getcwd(), "__triton_cache__")
if not os.path.exists(CACHE_DIR):
    os.makedirs(CACHE_DIR, exist_ok=True)
# Make the cache directory importable for the rest of the process.
sys.path.append(CACHE_DIR)
22def _clz32(x: int) -> int:
23 return 32 - x.bit_length() if x else 32
def calc_magic_u32_libdivide(d: int) -> Tuple[int, int]:
    """Derive libdivide-style unsigned 32-bit division constants for ``d``.

    Args:
        d: Divisor, must satisfy ``1 <= d <= 2**32 - 1``.

    Returns:
        ``(magic, more)`` where

        - ``magic == 0`` means ``d`` is a power of two and division is a pure
          right shift by ``more``;
        - otherwise the low 5 bits of ``more`` are the shift amount, and
          bit 6 (``0x40``) is set when the "add" variant of the
          multiply-shift sequence is required.

    Raises:
        ValueError: if ``d`` is outside ``[1, 2**32 - 1]``.
    """
    if not (1 <= d <= 0xFFFFFFFF):
        raise ValueError(f"d must be in [1, 2^32-1], got {d}")

    # floor(log2(d)); equals 31 - clz32(d) for a valid 32-bit divisor.
    log2_floor = d.bit_length() - 1

    # Power of two: shift-only path, signalled by magic == 0.
    if d & (d - 1) == 0:
        return 0, log2_floor

    # quotient, remainder of 2^(32 + floor(log2(d))) divided by d
    quotient, remainder = divmod(1 << (32 + log2_floor), d)

    if d - remainder < (1 << log2_floor):
        # The 32-bit magic is precise enough: no add marker.
        return (quotient + 1) & 0xFFFFFFFF, log2_floor

    # Otherwise use one more bit of precision (the 33-bit magic, truncated to
    # 32 bits) and flag the add variant. Python ints cannot wrap, so only the
    # "doubled remainder reaches d" carry has to be checked.
    doubled = 2 * quotient
    if 2 * remainder >= d:
        doubled += 1
    return (doubled + 1) & 0xFFFFFFFF, log2_floor | LIBDIVIDE_ADD_MARKER
@functools.lru_cache(maxsize=128)
def get_all_magics(shape_tuple: Tuple[int, ...]) -> Tuple[List[int], List[int]]:
    """Compute the fast-division constants for every extent of a shape.

    Args:
        shape_tuple: Dimension sizes; must be hashable because results are
            memoized with ``functools.lru_cache``.

    Returns:
        ``(magic_list, more_list)``, two lists with one entry per dimension,
        as produced by ``calc_magic_u32_libdivide``. Because of the cache the
        same list objects are handed back for a repeated shape, so callers
        should treat them as read-only.
    """
    params = [calc_magic_u32_libdivide(int(extent)) for extent in shape_tuple]
    magics = [magic for magic, _ in params]
    mores = [more for _, more in params]
    return magics, mores
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated module into ``code``.

    Emits the ``torch``, ``triton`` and ``triton.language`` imports followed
    by a blank line, and returns the same buffer.
    """
    for statement in (
        "import torch",
        "import triton",
        "import triton.language as tl",
    ):
        code.writeline(statement)
    code.newline()
    return code
def generate_device_functions(code: IndentedBuffer) -> IndentedBuffer:
    """Write the three ``@triton.jit`` fast-division helpers into ``code``.

    The helpers correspond to the three divisor kinds:

    - ``fast_divide_shift``: shift-only (magic == 0),
    - ``fast_divide_mul_noadd``: multiply-high then shift (no add marker),
    - ``fast_divide_mul_add``: multiply-high, add correction, then shift.

    Returns the same buffer.
    """
    # (signature line, body lines) for each emitted helper, in output order.
    helpers = (
        (
            "def fast_divide_shift(n, shift):",
            ("return n >> shift",),
        ),
        (
            "def fast_divide_mul_noadd(n, magic, shift):",
            ("return (tl.umulhi(n, magic) >> shift)",),
        ),
        (
            "def fast_divide_mul_add(n, magic, shift):",
            (
                "q0 = tl.umulhi(n, magic)",
                "t = ((n - q0) >> 1) + q0",
                "return (t >> shift)",
            ),
        ),
    )

    code.writeline(
        "# Device Functions for Fast Division (assume uint32 inputs, no casts/masks)"
    )
    code.newline()

    for signature, body in helpers:
        code.writeline("@triton.jit")
        code.writeline(signature)
        with code.indent():
            for statement in body:
                code.writeline(statement)
        code.newline()

    return code
def generate_gather_kernel(
    rank: int, kernel_name: str, div_kinds: List[str], code: IndentedBuffer
) -> IndentedBuffer:
    """Write the autotuned Triton gather kernel for ``rank``-dimensional tensors.

    The emitted kernel splits ``num_elements`` evenly across programs; each
    program walks its range in ``BLOCK_SIZE`` chunks. For every flat offset it
    reconstructs the per-dimension coordinates of the index tensor (last
    dimension first) using the fast-division helpers, accumulates the input
    offset over all dimensions except ``dim``, and finally adds
    ``gather_index * inp_stride[dim]``.

    Args:
        rank: Number of dimensions; shapes, strides and division constants are
            unrolled into ``rank`` scalar kernel arguments each.
        kernel_name: Name of the emitted kernel function.
        div_kinds: Per-dimension divisor kind: ``"S"`` selects the shift-only
            helper, ``"A"`` the multiply-add helper, anything else the
            multiply-no-add helper. Entry 0 is not consulted because the
            coordinate of dimension 0 is the remaining quotient.
        code: Buffer to append to.

    Returns:
        The same buffer.

    Note:
        The emitted kernel addresses ``index_ptr`` and ``out_ptr`` with the
        flat offset; the index/out stride arguments are accepted but unused,
        so both are presumably expected to be contiguous (confirm with
        callers).
    """
    code.newline()

    # Autotune lists
    code.writeline("# Autotune Configuration Lists")
    code.writeline("WARP_LIST = [8, 16, 32]")
    code.writeline("MEM_LIST = [120 * 1024, 216 * 1024]")
    code.writeline("BLOCK_SIZE_LIST = [32, 64, 128, 256, 512, 1024]")
    code.writeline("REORDER_LIST = [True, False]")
    code.newline()

    code.writeline("@triton.autotune(configs=[")
    with code.indent():
        code.writeline(
            "triton.Config("
            "kwargs={'BLOCK_SIZE': size, 'shared_mem_dynamic_size': localmem, "
            "'enable_simt_reorder_instruction': is_reorder}, num_warps=warp)"
        )
        code.writeline("for warp in WARP_LIST")
        code.writeline("for localmem in MEM_LIST")
        code.writeline("for size in BLOCK_SIZE_LIST")
        code.writeline("for is_reorder in REORDER_LIST")
        code.writeline("],")
        code.writeline("key=['num_elements'], ")
        code.writeline("warmup=25, ")
        code.writeline("rep=100) ")

    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        args = [
            "inp_ptr, ",
            "index_ptr, ",
            "out_ptr, ",
        ]
        # Unroll shapes and strides in signature to avoid metadata loads
        args += [f"inp_shape{i}, " for i in range(rank)]
        args += [f"index_shape{i}, " for i in range(rank)]

        # libdivide params per dimension
        args += [f"index_magic{i}: tl.uint32, " for i in range(rank)]
        args += [f"index_more{i}: tl.uint32, " for i in range(rank)]

        args += [f"inp_stride{i}, " for i in range(rank)]
        args += [f"index_stride{i}, " for i in range(rank)]
        args += [f"out_stride{i}, " for i in range(rank)]

        args += [
            "dim: tl.constexpr, ",
            "num_elements, ",
            "with_negative_index: tl.constexpr, ",
            "BLOCK_SIZE: tl.constexpr, ",
        ]
        code.writelines(args)
    code.writeline("):")

    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("num_programs = tl.num_programs(0)")
        code.writeline("elements_per_prog = tl.cdiv(num_elements, num_programs)")
        code.writeline("prog_start = pid * elements_per_prog")
        code.writeline(
            "prog_end = tl.minimum(prog_start + elements_per_prog, num_elements)"
        )
        code.newline()

        code.writeline(
            "# Block-Stride Loop (Processing contiguous chunks for better cache hit rate)"
        )
        code.writeline("for block_start in range(prog_start, prog_end, BLOCK_SIZE):")
        with code.indent():
            code.writeline("offsets = block_start + tl.arange(0, BLOCK_SIZE)")
            code.writeline("mask = offsets < num_elements")
            code.newline()

            code.writeline(
                "gather_index = tl.load(index_ptr + offsets, mask=mask, other=0).to(tl.int32)"
            )

            code.writeline("base_inp_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int32)")
            code.writeline("cur_offset = offsets.to(tl.int32)")
            code.newline()

            code.writeline("dim_stride = 0")
            code.writeline("dim_size = 0")
            code.newline()

            for i in range(rank - 1, -1, -1):
                if i == 0:
                    # After processing dims [rank-1 .. 1], cur_offset is already in [0, index_shape0).
                    # So: next_offset = cur_offset // index_shape0 == 0, coord_0 == cur_offset.
                    code.writeline("coord_0 = cur_offset")
                    code.writeline("cur_offset = 0")
                else:
                    # `more` still carries the libdivide add-marker bit (0x40)
                    # for "A" divisors; only its low 5 bits
                    # (LIBDIVIDE_32_SHIFT_MASK) are the shift amount, so mask
                    # before shifting, as libdivide itself does.
                    code.writeline(f"shift = index_more{i} & 0x1F")
                    kind = div_kinds[i]
                    if kind == "S":
                        code.writeline(
                            "next_offset = fast_divide_shift(cur_offset, shift)"
                        )
                    elif kind == "A":
                        code.writeline(f"magic = index_magic{i}")
                        code.writeline(
                            "next_offset = fast_divide_mul_add(cur_offset, magic, shift)"
                        )
                    else:
                        code.writeline(f"magic = index_magic{i}")
                        code.writeline(
                            "next_offset = fast_divide_mul_noadd(cur_offset, magic, shift)"
                        )

                    # coord = cur_offset mod index_shape{i}, via the quotient.
                    code.writeline(
                        f"coord_{i} = cur_offset - next_offset * index_shape{i}"
                    )
                    code.writeline("cur_offset = next_offset")

                # `dim` is a constexpr in the kernel signature; the gathered
                # dimension contributes via gather_index instead of its own
                # coordinate.
                code.writeline(f"if dim == {i}:")
                with code.indent():
                    code.writeline(f"dim_stride = inp_stride{i}")
                    code.writeline(f"dim_size = inp_shape{i}")
                code.writeline("else:")
                with code.indent():
                    code.writeline(f"base_inp_offset += coord_{i} * inp_stride{i}")
                code.newline()

            code.writeline("# Handle negative indices")
            code.writeline("if with_negative_index:")
            with code.indent():
                code.writeline(
                    "gather_index = tl.where(gather_index < 0, gather_index + dim_size, gather_index).to(tl.int32)"
                )

            code.writeline(
                "final_inp_offset = base_inp_offset + gather_index * dim_stride"
            )
            code.writeline(
                "val = tl.load(inp_ptr + final_inp_offset, mask=mask, other=0.0)"
            )
            code.writeline("tl.store(out_ptr + offsets, val, mask=mask)")

    code.newline()
    return code
def generate_gather_wrapper(
    rank: int, wrapper_name: str, kernel_name: str, code: IndentedBuffer
) -> IndentedBuffer:
    """Write the Python launcher that calls the generated gather kernel.

    The emitted function has the signature
    ``(inp, dim, index, out, grid, magic, more, with_negative_index)``. It
    reads shapes and strides from the tensors, launches
    ``kernel_name[grid](...)`` with every per-dimension value passed as a
    separate scalar argument, and returns ``out``.

    Args:
        rank: Number of dimensions to unroll into scalar arguments.
        wrapper_name: Name of the emitted launcher function.
        kernel_name: Name of the kernel the launcher calls.
        code: Buffer to append to.

    Returns:
        The same buffer.
    """
    # Statements that pull tensor metadata into locals of the launcher.
    preamble = (
        "inp_shape = inp.shape",
        "inp_stride = inp.stride()",
        "index_shape = index.shape",
        "index_stride = index.stride()",
        "out_stride = out.stride()",
        "num_elements = index.numel()",
    )
    # Sequences indexed once per dimension, in kernel-signature order.
    per_dim_sources = (
        "inp_shape",
        "index_shape",
        "magic",
        "more",
        "inp_stride",
        "index_stride",
        "out_stride",
    )

    call_args = ["inp, ", "index, ", "out, "]
    for source in per_dim_sources:
        call_args.extend(f"{source}[{axis}], " for axis in range(rank))
    call_args.extend(
        [
            "dim, ",
            "num_elements, ",
            "with_negative_index, ",
            "force_simt_only=False, ",
        ]
    )

    code.writeline(
        f"def {wrapper_name}(inp, dim, index, out, grid, magic, more, with_negative_index):"
    )
    with code.indent():
        # Extract shapes and strides
        for statement in preamble:
            code.writeline(statement)
        code.newline()

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writelines(call_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    return code
def generate_code(
    inputs, wrapper_name: str, kernel_name: str, div_kinds: List[str]
) -> str:
    """Assemble the full source text of a generated gather module.

    Args:
        inputs: Sequence whose first element exposes ``ndim``; only that rank
            is used.
        wrapper_name: Name for the emitted launcher function.
        kernel_name: Name for the emitted Triton kernel.
        div_kinds: Per-dimension divisor kinds forwarded to the kernel
            generator.

    Returns:
        The module source: imports, fast-division helpers, kernel, launcher.
    """
    buffer = IndentedBuffer()
    ndim = inputs[0].ndim
    buffer = generate_device_functions(generate_imports(buffer))
    buffer = generate_gather_kernel(ndim, kernel_name, div_kinds, buffer)
    buffer = generate_gather_wrapper(ndim, wrapper_name, kernel_name, buffer)
    return buffer.getvalue()
324class GatherFunction:
325 def __init__(self):
326 self.overloads = {}
327 self.kernels = {}
329 def __call__(
330 self, inp, dim, index, out, grid, magic_map=None, with_negative_index=False
331 ):
332 rank = inp.ndim
334 if magic_map is None:
335 magic, more = get_all_magics(tuple(index.shape))
336 else:
337 magic, more = magic_map
339 # div_kinds: 'S' (shift-only), 'M' (mul-noadd), 'A' (mul-add)
340 div_kinds = []
341 for m, mo in zip(magic, more):
342 if int(m) == 0:
343 div_kinds.append("S")
344 elif (int(mo) & 0x40) != 0:
345 div_kinds.append("A")
346 else:
347 div_kinds.append("M")
349 pattern = "".join(div_kinds)
350 key = f"gather_rank_{rank}_pat_{pattern}"
352 if key not in self.overloads:
353 kernel_name = f"_gather_kernel_{rank}"
354 wrapper_name = f"_gather_wrapper_{rank}"
356 src_code = generate_code([inp], wrapper_name, kernel_name, div_kinds)
358 file_name = f"{key}.py"
359 file_path = os.path.join(CACHE_DIR, file_name)
360 with open(file_path, "w", encoding="utf-8") as f:
361 f.write(src_code)
363 spec = importlib.util.spec_from_file_location(
364 f"dynamic_mod_{key}", file_path
365 )
366 mod = importlib.util.module_from_spec(spec)
367 assert spec.loader is not None
368 spec.loader.exec_module(mod)
370 self.overloads[key] = getattr(mod, wrapper_name)
372 return self.overloads[key](
373 inp, dim, index, out, grid, magic, more, with_negative_index
374 )
376 def get_kernel(self, rank: int):
377 return self.kernels.get(f"gather_rank_{rank}")
380_gather_func = GatherFunction()
def gather(
    inp,
    dim: int,
    index,
    out=None,
    grid_fn=None,
    magic_map=None,
    with_negative_index=False,
):
    """Gather values of ``inp`` along ``dim`` at the positions in ``index``.

    Args:
        inp: Input tensor.
        dim: Dimension to gather along.
        index: Index tensor; the result has its shape.
        out: Optional destination; when ``None`` a tensor shaped like
            ``index`` with ``inp``'s dtype and device is allocated.
        grid_fn: Launch grid forwarded to the generated kernel launcher.
        magic_map: Optional precomputed ``(magic, more)`` division constants.
        with_negative_index: Enable wrapping of negative indices in the kernel.

    Returns:
        The destination tensor.
    """
    result = (
        out
        if out is not None
        else torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    )
    _gather_func(
        inp,
        dim,
        index,
        result,
        grid_fn,
        magic_map=magic_map,
        with_negative_index=with_negative_index,
    )
    return result