Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gather.py: 0%
195 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter_
# Module logger, created as a child of the "flag_gems" logger so it follows
# that logger's level/handler configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated gather module into ``code``.

    The header first imports torch/triton/builtins, then the flag_gems
    helpers that the generated kernel code uses (``libentry``, ``runtime``,
    ``tle``), and ends with two blank lines before the first definition.

    Args:
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    base_imports = (
        "import torch",
        "import triton",
        "import triton.language as tl",
        "import builtins",
    )
    flag_gems_imports = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as tle",
    )

    for stmt in base_imports:
        code.writeline(stmt)
    code.newline()

    for stmt in flag_gems_imports:
        code.writeline(stmt)
    code.newline()
    code.newline()
    return code
def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton gather kernel, and its launch heuristics, for ``rank``.

    The generated kernel views the contiguous ``index``/``out`` tensors as an
    ``M x N`` grid of elements (``N`` = size of the last index dimension) and
    tiles it with ``BLOCK_M x BLOCK_N`` programs. For every element it:

    1. builds the row-major linear position (in int64),
    2. decomposes it into per-dimension coordinates (last dimension fastest),
    3. turns those coordinates into an offset into ``index``/``out`` (via the
       index strides) and into ``inp`` (via the ``inp_strided`` strides),
    4. loads the index value, adds ``index_value * stride_dim`` to the input
       offset, loads the input element and stores it to ``out``.

    NOTE(review): step 4 is only correct if ``restride_dim`` gives
    ``inp_strided`` a zero stride along ``dim``. The kernel relies on this
    behaviour, but it should be confirmed in ``shape_utils``.

    Index values are not range-checked in the kernel (``inp_dim_size`` is
    accepted but unused), so out-of-range indices read out of bounds.

    Args:
        rank: number of dimensions of ``index``. One stride/shape parameter is
            generated per dimension.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # Autotune config generator. It is currently unused because the autotune
    # decorator below is commented out in favour of the heuristics.
    code.writeline("def cfggen():")
    with code.indent():
        code.writeline("block_m = [1, 2, 4, 8]")
        code.writeline("block_n = [256, 512, 1024, 2048]")
        code.writeline("configs = [")
        with code.indent():
            code.writeline('triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)')
            code.writeline("for m in block_m")
            code.writeline("for n in block_n")
        code.writeline("]")
        code.writeline("return configs")

    code.newline()
    code.newline()

    # Row-tile heuristic: splits M into ~12 chunks.
    # NOTE(review): the 12 presumably matches the device's core/cluster count.
    # TODO confirm.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    # Column-tile heuristic: whole last dimension, capped at 4096 lanes.
    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 4096)')

    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    # code.writeline('@triton.autotune(configs=cfggen(), key=["M", "N"])')
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_M": heur_block_m,')
            code.writeline('"BLOCK_N": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # The tensor pointers are always present, because the generated
        # wrapper passes them unconditionally whatever the rank.
        code.writeline("inp,")
        code.writeline("out,")
        code.writeline("index,")
        if rank > 0:
            # Guarded because an empty join would emit a bare ", # ..." line,
            # which is a syntax error.
            stride_args = ", ".join(
                f"inp_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(
                f"index_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for index")

            shape_args = ", ".join(
                f"index_shape_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim: tl.constexpr,")
        code.writeline("stride_dim: tl.constexpr,")
        code.writeline("inp_dim_size: tl.constexpr,")
        code.writeline("M: tl.constexpr,")
        code.writeline("N: tl.constexpr,")
        code.writeline("BLOCK_M: tl.constexpr,")
        code.writeline("BLOCK_N: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid_x = tle.program_id(0)")
        code.writeline("pid_y = tle.program_id(1)")
        code.writeline(
            "rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]"
        )
        code.writeline(
            "cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]"
        )
        code.writeline("rows_mask = rows_offsets < M")
        code.writeline("cols_mask = cols_offsets < N")
        code.writeline("mask = rows_mask & cols_mask")

        # 1. Row-major linear position of each element. Widen to int64 BEFORE
        #    multiplying, because the int32 product overflows for tensors with
        #    more than 2**31 elements.
        code.writeline("cur_idx = rows_offsets.to(tl.int64) * N + cols_offsets")
        code.writeline("inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")

        # 2. Mixed-radix decomposition, last dimension first. This matches the
        #    row-major linear position above, so neighbouring lanes touch
        #    neighbouring memory (coalesced access).
        for i in reversed(range(rank)):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            code.writeline(f"idx_offsets += mod * index_stride_{i}")
            if i != 0:
                code.writeline(f"cur_idx = cur_idx // index_shape_{i}")

        # 3. Use offsets to gather. out shares index's layout, so the store
        #    reuses idx_offsets.
        code.writeline("cur_index = tl.load(index + idx_offsets, mask=mask, other=0)")
        code.writeline("inp_offsets += cur_index * stride_dim")
        code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
        code.writeline("tl.store(out + idx_offsets, cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated wrapper as one string.

    The names (and their order) must match what ``gather`` passes to the
    dispatcher: inp_strided, out, index, dim, stride_dim, inp_dim_size, M, N.
    """
    wrapper_params: Tuple[str, ...] = (
        "inp_strided",
        "out",
        "index",
        "dim",
        "stride_dim",
        "inp_dim_size",
        "M",
        "N",
    )
    return ", ".join(wrapper_params)
def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python launcher that unpacks tensor metadata and starts the kernel.

    The generated function takes the parameters listed by
    ``parameter_for_wrapper``. It expands the strides of ``inp_strided`` and
    ``index`` and the shape of ``index`` into one scalar argument per
    dimension, launches ``kernel_name`` on a 2-D grid of
    ``BLOCK_M x BLOCK_N`` tiles covering ``M x N``, and returns ``out``.

    Args:
        rank: number of dimensions of ``index``.
        wrapper_name: name of the generated launcher function.
        kernel_name: name of the ``@triton.jit`` kernel to launch.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        for stmt in (
            "inp_strides = inp_strided.stride()",
            "index_strides = index.stride()",
            "index_shapes = list(index.shape)",
        ):
            code.writeline(stmt)

        # kernel launch: one program per BLOCK_M x BLOCK_N tile
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(M, meta["BLOCK_M"]),')
            code.writeline('triton.cdiv(N, meta["BLOCK_N"])')
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("inp_strided, out, index, ")
            if rank > 0:
                # Per-dimension scalars, in the order the kernel declares them.
                for seq_name in ("inp_strides", "index_strides", "index_shapes"):
                    unpacked = ", ".join(f"{seq_name}[{d}]" for d in range(rank))
                    code.writeline(f"{unpacked},")
            for scalar in ("dim", "stride_dim", "inp_dim_size", "M", "N"):
                code.writeline(f"{scalar},")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the full gather module source (imports, kernel, wrapper).

    Args:
        inputs: the dispatcher arguments
            ``(inp_strided, out, index, dim, stride_dim, inp_dim_size, M, N)``.
            Only the rank of ``index`` (position 2) drives code generation.
        wrapper_name: name of the generated launcher function.
        kernel_name: name of the generated Triton kernel.
        code: buffer the generated source is appended to.

    Returns:
        The ``code`` buffer holding the complete module source.
    """
    index = inputs[2]
    rank = len(index.shape)

    buffer = generate_imports(code)
    buffer = generate_gather_kernel(rank, kernel_name, buffer)
    return generate_gather_wrapper(rank, wrapper_name, kernel_name, buffer)
class GatherFunction:
    """Dispatch gather calls to a Triton launcher generated per index rank.

    On a cache miss, the module source is generated, written under
    ``cache_dir()`` and imported from that file. The resulting
    ``_gather_wrapper`` is memoised in ``self.overloads`` under the key
    produced by ``arg_key``.
    """

    def __init__(self):
        # PID at construction time, kept for backward compatibility. Generated
        # file names use the *current* PID instead, so forked children do not
        # truncate and re-import the parent's file concurrently.
        self.pid = os.getpid()
        # cache key (stringified rank) -> generated wrapper callable
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write and import the wrapper for ``key``; return it."""
        code = IndentedBuffer()
        code = generate_code(
            args,
            "_gather_wrapper",
            "_gather_jit_function",
            code,
        )

        pid = os.getpid()
        file_path = cache_dir() / f"gather_rank_{key}_pid_{pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{pid}",
            str(file_path),
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_gather_wrapper")

    def arg_key(self, *args):
        """Return the cache key: the largest ``ndim`` among tensor arguments."""
        return max(arg.ndim for arg in args if torch.is_tensor(arg))
# Module-level dispatcher shared by every gather() call in this process. It
# caches one generated launcher per cache key (the max tensor rank).
_gather_func = GatherFunction()
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather elements of ``inp`` along ``dim`` at the positions in ``index``.

    Args:
        inp: source tensor.
        dim: dimension to gather along; negative values count from the end.
        index: integer tensor of positions. The result has ``index``'s shape.
        out: optional destination. If it is not contiguous, the result is
            computed into a contiguous buffer and copied back into ``out``.
        sparse_grad: accepted for signature compatibility; unused here.

    Returns:
        ``out`` (newly allocated with ``inp``'s dtype/device when not given).

    Note:
        No range check of ``index`` values is performed in this function.
    """
    logger.debug("GEMS GATHER")
    if dim < 0:
        dim += inp.ndim
    inp = inp.contiguous()
    index = index.contiguous()
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)

    if index.numel() == 0:
        # Nothing to gather. Launching would also fail: N == 0 divides by zero
        # below, and M == 0 yields BLOCK_M == 0 in the kernel heuristics.
        return out

    # The kernel writes using index's (contiguous) offsets, so it needs a
    # contiguous destination. A non-contiguous caller-supplied out is filled
    # through a temporary and copied back afterwards.
    out_buf = out.contiguous()
    stride_dim = inp.stride(dim)

    inp_strided = restride_dim(inp, dim, index.shape)
    N = index.shape[-1]
    M = index.numel() // N
    inp_dim_size = inp.size(dim)

    _gather_func(inp_strided, out_buf, index, dim, stride_dim, inp_dim_size, M, N)
    if out_buf is not out:
        out.copy_(out_buf)
    return out
def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of ``gather`` with respect to its input tensor.

    Starting from zeros shaped like ``self``, each element of ``grad`` is
    scatter-added back to the position along ``dim`` it was gathered from.
    ``sparse_grad`` is accepted for signature compatibility but unused.
    """
    logger.debug("GEMS GATHER BACKWARD")
    grad_input = grad.new_zeros(self.shape)
    return scatter_(grad_input, dim, index, grad, reduce="add")