Coverage for src/flag_gems/runtime/backend/_sunrise/ops/pad.py: 0%
286 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
# Module-level logger; `pad` reports each invocation through it at DEBUG level.
logger = logging.getLogger(name=__name__)
# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional pad wrapper.

    The string is spliced verbatim between the parentheses of
    ``def <wrapper_name>(...):``. The parameters are untyped; ``value``
    defaults to 0.

    Example: in0, pad, mode, value=0
    """
    return ", ".join(("in0", "pad", "mode", "value=0"))
def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    The string is spliced verbatim between the parentheses of
    ``def <wrapper_name>(...):``. The caller supplies the preallocated output
    ``out0``, its shape and the per-dimension padding lists; ``value``
    defaults to 0.

    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    names = ("in0", "out0", "dst_shape", "pad_before", "pad_after", "mode")
    return ", ".join(names + ("value=0",))
def parameter_ref_for_wrapper() -> str:
    """Return the argument list used to call the destination-passing wrapper.

    The list matches the parameters of the destination-passing wrapper
    position-for-position, with the ``=0`` default of ``value`` dropped,
    because this is a call site rather than a declaration.

    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    return ", ".join(
        ("in0", "out0", "dst_shape", "pad_before", "pad_after", "mode", "value")
    )
def output_ref_for_wrapper() -> str:
    """Name of the generated variable that receives the padded output."""
    output_name = "out0"
    return output_name
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of a generated pad module into ``code``.

    Only names that the generated wrappers and kernel actually reference are
    imported:
    - ``torch``: output allocation with ``torch.empty``.
    - ``triton`` / ``tl``: grid size and kernel body.
    - ``libentry``: kernel decorator.
    - ``torch_device_fn``: device context around the launch.
    - ``ext``: ``program_id`` inside the kernel.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    for line in (
        "import torch",
        "import triton",
        "from triton import language as tl",
    ):
        code.writeline(line)
    code.newline()

    for line in (
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as ext",
    ):
        code.writeline(line)
    code.newline()
    code.newline()
    return code
def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional wrapper ``def <wrapper_name>(in0, pad, mode, value=0)``.

    The generated function:
    - splits the flat ``pad`` sequence into per-dimension ``pad_before`` /
      ``pad_after`` lists, where pair ``i`` of ``pad`` lands on dimension
      ``ndim - i - 1``;
    - grows the input shape by that padding;
    - allocates ``out0`` with ``torch.empty``;
    - delegates the fill to ``destination_passing_func_name``;
    - returns ``out0``.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        # validate the flat pad list and spread it over per-dimension lists
        for line in ("ndim = in0.ndim", "pad_size = len(pad)", "assert pad_size % 2 == 0"):
            code.writeline(line)
        code.newline()
        for side in ("before", "after"):
            code.writeline(f"pad_{side} = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")

        # padded output shape: every dim grows by its before + after amount
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")
        code.writeline("out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)")

        # hand the actual fill over to the destination-passing wrapper
        target = output_ref_for_wrapper()
        call_args = parameter_ref_for_wrapper()
        code.writeline(f"{target} = {destination_passing_func_name}({call_args})")
        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper that launches the pad kernel.

    The generated ``def <wrapper_name>(in0, out0, dst_shape, pad_before,
    pad_after, mode, value=0)`` computes the per-dimension launch arguments:
    - input shape and strides;
    - output strides;
    - the ``[start, end)`` range of each output dimension that is copied
      from ``in0``;
    - a has-padding flag per dimension.

    It turns ``mode`` into the four ``IS_*`` flags and launches
    ``kernel_name`` on ``cdiv(out0.numel(), BLOCK_SIZE)`` programs inside
    ``in0``'s device context. It returns ``out0``; an empty output is
    returned immediately without any launch.

    Args:
        rank: number of input dimensions; one argument per dimension is emitted
            for every per-dimension kernel parameter.
        wrapper_name: name of the generated function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    dims = range(rank)
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper_out()}):")

    with code.indent():
        code.writeline("BLOCK_SIZE = 256")
        # nothing to fill for an empty output; also avoids a zero-sized grid
        code.writeline("if out0.numel() == 0:")
        with code.indent():
            code.writeline("return out0")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        if rank > 0:
            code.writeline("# [start, end) range of each output dim that is copied from in0")
            for i in dims:
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")
            code.newline()

        code.writeline("# Check which dimensions have padding")
        for i in dims:
            code.writeline(f"dim{i}_has_pad = pad_before[{i}] > 0 or pad_after[{i}] > 0")
        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")
        code.newline()

        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")
            with code.indent():
                code.writeline("in0, out0, ")
                if rank > 0:
                    # one comma-separated run of per-dimension arguments per kernel parameter group
                    for template, comment in (
                        ("x_shape[{}]", "shape for x"),
                        ("in_strides0[{}]", "stride for x"),
                        ("out_strides[{}]", "stride for out"),
                        ("valid_dim{}_start", "valid dim start"),
                        ("valid_dim{}_end", "valid dim end"),
                        ("bool(dim{}_has_pad)", "dim has padding flags"),
                    ):
                        args = ", ".join(template.format(j) for j in dims)
                        code.writeline(f"{args}, # {comment}")
                for arg in (
                    "in0.numel()",
                    "out0.numel()",
                    "value",
                    "IS_CONSTANT",
                    "IS_REFLECT",
                    "IS_REPLICATE",
                    "IS_CIRCULAR",
                    "BLOCK_SIZE",
                ):
                    code.writeline(f"{arg}, ")
                code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton kernel that fills one BLOCK_SIZE tile of the padded output.

    Each program:
    - splits its flat output offsets into per-dimension destination indices
      using the output strides;
    - maps every destination index to a source index according to the
      padding mode (constant / reflect / replicate / circular);
    - stores either the gathered input element or, in constant mode, ``value``
      for lanes outside the copied region.

    Loads are masked so that:
    - lanes past the end of the output never touch memory;
    - constant mode only reads lanes inside the copied region.

    Args:
        rank: number of input dimensions. One shape/stride/valid-range argument
            per dimension is emitted. The body references dimension 0
            unconditionally, so rank must be >= 1.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.newline()

    # `value` changes from call to call; keep it out of Triton's specialization
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
        if rank > 0:
            dims = range(rank)
            for template, comment in (
                ("x_shape{}: int", "shape for x"),
                ("in_strides{}: int", "stride for x"),
                ("out_strides{}: int", "stride for out"),
                ("valid_dim{}_start: int", "valid dim start"),
                ("valid_dim{}_end: int", "valid dim end"),
            ):
                params = ", ".join(template.format(j) for j in dims)
                code.writeline(f"{params}, # {comment}")
        for i in range(rank):
            code.writeline(f"dim{i}_has_pad: tl.constexpr, ")
        # Element counts are plain runtime ints, not constexpr, so that a new
        # tensor size does not force a fresh kernel compilation.
        code.writeline("in_elem_cnt: int, ")
        code.writeline("out_elem_cnt: int, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")
    code.writeline("):")

    with code.indent():
        code.writeline("pid = ext.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        # tail lanes of the last block lie past the end of the output
        code.writeline("in_bounds = offset < out_elem_cnt")
        code.newline()

        # Split the flat output offset into per-dimension indices. This assumes
        # out0 has contiguous strides (the functional wrapper allocates it with
        # torch.empty).
        code.writeline("remaining = offset")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        # cond: the destination element lies inside the region copied from in0
        code.writeline(
            "cond = (dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end)"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= (dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end)"
            )
        code.newline()

        # default mapping (interior region); the leading pad is clamped to index 0
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start")
        for i in range(rank):
            code.writeline(f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})")
        code.newline()

        # reflect: mirror around the first/last element without repeating it
        code.writeline("if IS_REFLECT:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"valid_dim{i}_start - dst_index_{i}, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, "
                    f"src_index_{i})"
                )
        code.newline()

        # replicate: repeat the first/last element
        code.writeline("if IS_REPLICATE:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"x_shape{i} - 1, src_index_{i})"
                )
        code.newline()

        # circular: wrap around to the opposite end of the input
        code.writeline("if IS_CIRCULAR:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"dst_index_{i} - valid_dim{i}_end, src_index_{i})"
                )
        code.newline()

        # upper-clamped indices used by the constant-mode load
        for i in range(rank):
            code.writeline(
                f"safe_src_index_{i} = tl.where(src_index_{i} < x_shape{i}, src_index_{i}, x_shape{i} - 1)"
            )
        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")
        code.writeline("safe_src_offset = safe_src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"safe_src_offset += safe_src_index_{i} * in_strides{i}")
        code.newline()

        # Only read lanes that are inside the output AND map to a valid input
        # element. Tail lanes past the output can otherwise yield negative
        # source indices (e.g. in reflect mode) and read out of bounds.
        code.writeline(
            "load_cond = in_bounds & (src_index_0 >= 0) & (src_index_0 < x_shape0)"
        )
        for i in range(1, rank):
            code.writeline(
                f"load_cond &= (src_index_{i} >= 0) & (src_index_{i} < x_shape{i})"
            )
        code.newline()

        code.writeline("if IS_CONSTANT:")
        with code.indent():
            # padded lanes take `value`, so only lanes inside the copied region are loaded
            code.writeline(
                "x_loaded = tl.load(in0_ptr + safe_src_offset, mask=in_bounds & cond, other=0)"
            )
            code.writeline("x_val = tl.where(cond, x_loaded, value)")
        code.writeline("else:")
        with code.indent():
            code.writeline("x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)")
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=in_bounds)")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the complete source of a generated pad module into ``code``.

    The module contains, in order:
    - the imports;
    - the functional wrapper;
    - the destination-passing wrapper;
    - the Triton kernel.

    The only property of ``inputs`` that shapes the generated text is the rank
    of ``inputs[0]``.

    Returns:
        The buffer holding the full module source.
    """
    rank = len(inputs[0].shape)

    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    return generate_pad_kernel(rank, kernel_name, code)
class PadFunction:
    """Generate, cache and dispatch rank-specialized pad implementations.

    Calls are keyed on the maximum rank among the tensor arguments. On the
    first call for a key:
    - a module (functional wrapper, destination-passing wrapper and Triton
      kernel) is generated;
    - it is written to the code cache directory and imported;
    - its ``_pad_wrapper`` is memoized in ``self.overloads``.

    Later calls with the same key reuse the memoized wrapper.
    """

    def __init__(self):
        self.pid = os.getpid()
        # str(rank key) -> generated `_pad_wrapper` callable
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key: str, args) -> Callable:
        """Generate, write and import the module for ``key``; return its wrapper."""
        # `import importlib` alone does not guarantee that the `importlib.util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        code = generate_code(
            args,
            "_pad_wrapper",
            "_pad_wrapper_out",
            "_pad_jit_function",
            IndentedBuffer(),
        )

        file_name = f"constant_pad_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        # deliberately not registered in sys.modules
        spec.loader.exec_module(module)
        return getattr(module, "_pad_wrapper")

    def arg_key(self, *args):
        """Return the highest ``ndim`` among the tensor arguments in ``args``."""
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
# Module-level dispatcher shared by every `pad` call; it memoizes one generated
# wrapper per key returned by `PadFunction.arg_key` (the max tensor rank).
_pad_func = PadFunction()
def pad(self, pad, mode="constant", value=None):
    """Pad ``self`` using a rank-specialized generated Triton kernel.

    Args:
        self: input tensor.
        pad: flat sequence of padding amounts. Pair ``i`` (``pad[2 * i]``,
            ``pad[2 * i + 1]``) is the (before, after) padding of dimension
            ``ndim - 1 - i``, i.e. pairs start from the last dimension.
        mode: ``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``.
        value: fill value for ``"constant"`` mode; ``None`` is treated as 0.0.

    Returns:
        The tensor produced by the generated wrapper for ``self``'s rank.

    Raises:
        RuntimeError: ``mode`` is not one of the four supported modes, or
            ``pad`` holds more pairs than ``self`` has dimensions.
        AssertionError: a reflect pad is not smaller than its dimension, or a
            circular pad is larger than its dimension.
    """
    logger.debug("GEMS CONSTANT PAD ND")

    if mode not in ("constant", "reflect", "replicate", "circular"):
        raise RuntimeError(f"Unrecognised padding mode {mode}")

    ndim = self.ndim
    if value is None:
        value = 0.0

    pad_pairs = len(pad) // 2
    # Pairs beyond ndim would be written to negative list/shape indices
    # (ndim - 1 - i < 0), silently overwriting another dimension's padding.
    if pad_pairs > ndim:
        raise RuntimeError(
            "Padding length should be less than or equal to two times the input "
            f"dimension but got padding length {len(pad)} and input of dimension {ndim}"
        )

    if mode in ("reflect", "circular"):
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            if mode == "reflect":
                assert pad_l < input_size and pad_r < input_size, (
                    "padding size should be less than the corresponding input dimension, "
                    f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
                )
            else:
                assert (
                    pad_l <= input_size and pad_r <= input_size
                ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out
def constant_pad_nd(self, pad_list, value=0):
    """Constant-mode padding: forward to ``pad`` with ``mode="constant"``."""
    return pad(self, pad_list, value=value, mode="constant")