Coverage for src/flag_gems/runtime/backend/_sunrise/ops/repeat.py: 0%
255 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import importlib
2import logging
3import os
4from typing import Callable, List, Mapping
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
# Module logger: a child of the "flag_gems" logger named after this module,
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# --------------------------- repeat wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional wrapper.

    The functional wrapper takes the input tensor and the repeat sizes,
    so the result is always the string ``"in0, sizes"``.
    """
    names = ("in0", "sizes")
    return ", ".join(names)
def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    That wrapper receives the input tensor and a pre-allocated output
    tensor, so the result is always the string ``"in0, out0"``.
    """
    names = ("in0", "out0")
    return ", ".join(names)
def parameter_ref_for_wrapper() -> str:
    """Return the argument list used when calling the destination-passing wrapper.

    The functional wrapper forwards its (reshaped) input and the freshly
    allocated output, so the result is always the string ``"in0, out0"``.
    """
    arguments = ("in0", "out0")
    return ", ".join(arguments)
def output_ref_for_wrapper() -> str:
    """Return the name the generated wrapper binds its output tensor to (``"out0"``)."""
    output_names = ("out0",)
    return ", ".join(output_names)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated repeat module into ``code``.

    Two groups of lines are emitted: the framework imports (including a
    ``torch_ptpu`` import that falls back to ``torch.cuda`` on ImportError),
    then the flag_gems helpers the generated wrappers and kernel refer to.
    The groups are separated by one blank line and followed by two.

    Args:
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining.
    """
    framework_lines = (
        "import math",
        "import torch",
        "try:",
        " import torch_ptpu",
        "except ImportError:",
        " import torch.cuda as torch_ptpu",
        "import triton",
        "from triton import language as tl",
    )
    flag_gems_lines = (
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.utils.type_utils import type_promotion",
        "from flag_gems.utils import triton_lang_extension as ext",
    )

    for line in framework_lines:
        code.writeline(line)
    code.newline()

    for line in flag_gems_lines:
        code.writeline(line)
    code.newline()
    code.newline()
    return code
def generate_functional_repeat_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional ``repeat`` wrapper (allocates the output itself).

    The generated function takes ``in0`` and ``sizes`` and:

    * asserts that ``sizes`` has at least as many entries as ``in0`` has
      dimensions and that no entry is negative;
    * left-pads the input shape with ones up to ``len(sizes)``;
    * computes the output shape as the element-wise product of the padded
      input shape and ``sizes`` and allocates ``out0`` with ``torch.empty``;
    * reshapes ``in0`` to the padded shape and, unless the output has a
      zero-sized dimension, calls the destination-passing function;
    * returns ``out0``.

    Args:
        wrapper_name: Name of the generated functional wrapper.
        destination_passing_func_name: Name of the generated function that
            fills a pre-allocated output.
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining.
    """
    # wrapper signature
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        code.writeline("in0_rank = in0.dim()")
        code.writeline("sizes_rank = len(sizes)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("sizes_shape = list(sizes)")
        code.newline()

        # The emitted assert lines are assembled from adjacent string literals.
        # A backslash continuation inside the literal would copy this file's
        # own indentation into the generated code and into the message text.
        code.writeline(
            "assert(sizes_rank >= in0_rank), "
            "'Number of dimensions of repeat dims can not be smaller than "
            "number of dimensions of tensor'"
        )
        code.writeline("if (sizes_rank > in0_rank): ")
        with code.indent():
            # left-pad the input shape with ones so both shapes have equal rank
            code.writeline("diff = sizes_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()

        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            code.writeline(
                "assert(sizes_shape[i] >= 0), "
                "'the number of repetitions per dimension out of range "
                "(expected to >= 0) but got {}'.format(sizes_shape[i])"
            )
            code.writeline("if in0_shape[i] * sizes_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * sizes_shape[i])")
        code.newline()

        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )
        code.writeline("in0 = in0.reshape(in0_shape)")

        # an output with a zero-sized dimension needs no kernel launch
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            code.writeline(
                f"{output_ref_for_wrapper()} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )

        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
def generate_destination_passing_repeat_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the wrapper that launches the repeat kernel on a given output.

    The generated function takes ``in0`` and ``out0``, derives the launch
    configuration from ``out0.shape`` (tile size capped at 512, at most
    65535 programs, remaining work covered by ``tiles_per_cta``), launches
    ``kernel_name`` under ``torch_device_fn.device(in0.device.index)`` and
    returns ``out0``.

    Args:
        rank: Rank of the task space. It decides how many stride and shape
            arguments are passed to the kernel.
        wrapper_name: Name of the generated destination-passing function.
        kernel_name: Name of the Triton kernel to launch.
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining.
    """
    has_dims = rank > 0
    dims = range(rank)

    # wrapper signature
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper_out()}):")

    with code.indent():
        # launch configuration
        if has_dims:
            for line in (
                "shape = out0.shape",
                "num_tasks = volume(shape)",
                "tile_size = min(512, triton.next_power_of_2(num_tasks))",
                "num_warps = 4",
                "num_ctas = min(65535, triton.cdiv(num_tasks, tile_size))",
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)",
            ):
                code.writeline(line)
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # strides and input shape are only needed when there is an index space
        if has_dims:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
        code.newline()

        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")

            with code.indent():
                code.writeline("in0, out0, ")

                if has_dims:
                    in_stride_args = ", ".join(f"in0_strides[{j}]" for j in dims)
                    code.writeline(f"{in_stride_args}, # stride for in0")

                    out_stride_args = ", ".join(f"out0_strides[{j}]" for j in dims)
                    code.writeline(f"{out_stride_args}, # stride for out0")

                    out_shape_args = ", ".join(f"shape[{i}]" for i in dims)
                    code.writeline(f"{out_shape_args}, # task indexing space")

                    in_shape_args = ", ".join(f"in0_shape[{i}]" for i in dims)
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape"
                    )

                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
def _emit_repeat_tile_body(rank: int, code: IndentedBuffer) -> None:
    """Emit the per-tile work: mask, multi-index, load, copy and store.

    The emitted lines expect ``tid`` to be defined by the caller and use the
    kernel parameters ``num_tasks``, ``s{i}``, ``in_s{i}``, ``in0_stride{i}``
    and ``out0_stride{i}``. Only meaningful for ``rank > 0``.
    """
    dims = range(rank)

    code.writeline("mask = tid < num_tasks")
    code.newline()

    # reconstruct multi index, innermost dimension first
    code.writeline("# multi index recontruction")
    for i in reversed(dims):
        if i > 0:
            code.writeline(f"i{i} = tid % s{i}")
            code.writeline(f"tid //= s{i}")
        else:
            code.writeline(f"i{i} = tid")
    code.newline()

    # loads: the output index is wrapped by the input extent, which is what
    # makes the input repeat along every dimension
    code.writeline("# loads")
    in_offsets = " + ".join(f"(i{j} % in_s{j}) * in0_stride{j}" for j in dims)
    code.writeline(f"in0 = tl.load(in0_ptr + {in_offsets}, mask=mask)")
    code.newline()

    # compute
    code.writeline("# compute")
    code.writeline("out0 = in0")
    code.newline()

    # stores
    code.writeline("# stores")
    out_offsets = " + ".join(f"i{j} * out0_stride{j}" for j in dims)
    code.writeline(f"tl.store(out0_ptr + {out_offsets}, out0, mask=mask)")


def generate_repeat_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the ``@libentry() @triton.jit`` kernel that performs the repeat.

    For ``rank > 0`` the kernel receives both pointers, the per-dimension
    strides of input and output, the output shape (``s{i}``), the input
    shape (``in_s{i}``), ``num_tasks`` and the tiling parameters. Its body
    has two variants selected by the ``one_tile_per_cta`` constexpr: a
    single tile per program, or a loop over ``tiles_per_cta`` tiles spaced
    ``tile_size * num_ctas`` apart.

    For ``rank == 0`` the kernel receives only the two pointers and copies
    one element. The tiled body cannot be used there: the pointer offset
    expression would be empty and the tile parameters are not part of the
    rank-0 signature.

    Args:
        rank: Rank of the task space.
        kernel_name: Name of the generated kernel function.
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining.
    """
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # strides, shapes and tiling parameters only exist when rank > 0
        if rank > 0:
            dims = range(rank)

            stride_args = ", ".join(f"in0_stride{j}: int" for j in dims)
            code.writeline(f"{stride_args}, # strides for in0")

            stride_args = ", ".join(f"out0_stride{j}: int" for j in dims)
            code.writeline(f"{stride_args}, # strides for out0")

            # task space, used to reconstruct multi index
            task_space_args = ", ".join(f"s{i}: int" for i in dims)
            code.writeline(f"{task_space_args}, # task_space")

            task_space_args2 = ", ".join(f"in_s{i}: int" for i in dims)
            code.writeline(
                f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
            )

            # number of tasks, used to compute mask
            code.writeline("num_tasks: int,")

            # tile size & tiles_per_cta, gsl style
            code.writeline("tiles_per_cta,")
            code.writeline("tile_size: tl.constexpr,")
            code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        if rank > 0:
            code.writeline("# task id & masking")
            code.writeline("pid = ext.program_id(0)")
            code.writeline("num_ctas = ext.num_programs(0)")

            # get tid (a.k.a task id)
            code.writeline("init_tid = pid * tile_size + tl.arange(0, tile_size)")

            # one tile per program
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tid = init_tid")
                _emit_repeat_tile_body(rank, code)

            # several tiles per program
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                code.writeline("for j in range(0, tiles_per_cta):")
                with code.indent():
                    code.writeline("tid = init_tid + j * tile_size * num_ctas")
                    _emit_repeat_tile_body(rank, code)
        else:
            # Scalar copy: a single unmasked load and store.
            # NOTE(review): assumes the launcher passes only in0/out0 (plus
            # launch options) for rank 0 - confirm against
            # generate_destination_passing_repeat_wrapper.
            code.writeline("# rank-0 task space: copy the single element")
            code.writeline("in0 = tl.load(in0_ptr)")
            code.writeline("out0 = in0")
            code.writeline("tl.store(out0_ptr, out0)")
    code.newline()
    return code
def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a complete generated module for a repeat of the given rank.

    The sections are written in order: imports, the functional wrapper,
    the destination-passing wrapper, and the Triton kernel. The rank of the
    task space is the only factor that varies between generated modules.

    Args:
        rank: Rank of the task space.
        wrapper_name: Name of the functional wrapper.
        destination_passing_func_name: Name of the destination-passing wrapper.
        kernel_name: Name of the Triton kernel.
        code: Buffer that receives the generated source lines.

    Returns:
        The buffer returned by the last generator stage.
    """
    with_imports = generate_imports(code)
    with_functional = generate_functional_repeat_wrapper(
        wrapper_name, destination_passing_func_name, with_imports
    )
    with_launcher = generate_destination_passing_repeat_wrapper(
        rank, destination_passing_func_name, kernel_name, with_functional
    )
    return generate_repeat_kernel(rank, kernel_name, with_launcher)
class RepeatFunction:
    """Callable that dispatches ``repeat`` to a per-rank generated module.

    On first use of a given rank, the source of a specialised module is
    generated, written into the code cache directory, loaded from that
    file, and its ``_wrapper`` function is cached. Later calls with the
    same rank reuse the cached function.
    """

    def __init__(self):
        self.pid = os.getpid()
        # instantiated & cached overloads, keyed by the rank as a string
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, x, sizes):
        """Repeat ``x`` according to ``sizes`` using the overload for its rank."""
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, sizes)
        key = str(ndim)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            overload = self._instantiate(ndim, key)
            self.overloads[key] = overload
        return overload(x, sizes)

    def _instantiate(self, ndim, key):
        """Generate, write and load the module for ``ndim``; return its ``_wrapper``."""
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        # generate file & import it
        code = IndentedBuffer()
        code = generate_code(
            ndim,
            "_wrapper",
            "_wrapper_out",
            "_repeat_flaggems_jit_function",
            code,
        )

        file_name = f"repeat_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # the module is deliberately not registered in sys.modules
        spec.loader.exec_module(m)
        return getattr(m, "_wrapper")

    def arg_key(self, x, sizes):
        """Return the rank of the task space: the larger of ``x.ndim`` and ``len(sizes)``."""
        return max(x.ndim, len(sizes))
_repeat_func = RepeatFunction()


def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    """Repeat ``inp`` along each dimension as given by ``sizes``.

    Logs a debug message and delegates to the module-level
    ``RepeatFunction`` instance.

    Args:
        inp: Tensor to repeat.
        sizes: Number of repetitions per dimension.

    Returns:
        The tensor returned by the generated wrapper.
    """
    logger.debug("GEMS REPEAT")
    return _repeat_func(inp, sizes)