Coverage for src/flag_gems/runtime/backend/_cambricon/ops/pad.py: 0%
349 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import libentry
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer
14from ..utils import TOTAL_CORE_NUM
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Build the parameter declaration of the generated functional pad wrapper.

    Returns:
        The comma-separated parameter list ``"in0, pad, mode, value=0"``.
    """
    names = ("in0", "pad", "mode", "value=0")
    return ", ".join(names)
def parameter_for_wrapper_out() -> str:
    """Build the parameter declaration of the destination-passing pad wrapper.

    Returns:
        The comma-separated parameter list
        ``"in0, out0, dst_shape, pad_before, pad_after, mode, value=0"``.
    """
    names = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value=0",
    )
    return ", ".join(names)
def parameter_ref_for_wrapper() -> str:
    """Build the argument list used to call the destination-passing wrapper.

    Returns:
        The comma-separated argument list
        ``"in0, out0, dst_shape, pad_before, pad_after, mode, value"``.
    """
    names = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value",
    )
    return ", ".join(names)
def output_ref_for_wrapper() -> str:
    """Return the name of the output variable used in generated wrappers."""
    output_name = "out0"
    return output_name
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated pad module into ``code``.

    Args:
        code: buffer receiving the generated source.

    Returns:
        The same buffer, to allow chaining of generator steps.
    """
    base_imports = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    )
    flag_gems_imports = (
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as tle",
        "from flag_gems.utils.type_utils import type_promotion",
    )

    for statement in base_imports:
        code.writeline(statement)
    code.newline()

    for statement in flag_gems_imports:
        code.writeline(statement)
    # two blank lines before the first generated definition
    code.newline()
    code.newline()
    return code
def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional pad wrapper into ``code``.

    The generated function takes ``(in0, pad, mode, value=0)``, splits ``pad``
    into per-dimension before/after amounts (last dimension first), allocates
    the output with ``torch.empty`` and forwards to the destination-passing
    wrapper.

    Args:
        wrapper_name: name of the generated functional wrapper.
        destination_passing_func_name: name of the generated function that
            receives the pre-allocated output.
        code: buffer receiving the generated source.

    Returns:
        The same buffer, to allow chaining of generator steps.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        for line in (
            "ndim = in0.ndim",
            "pad_size = len(pad)",
            "assert pad_size % 2 == 0",
        ):
            code.writeline(line)
        code.newline()

        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()

        # pad pairs are listed starting from the last dimension
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")

        # output shape = input shape grown by the padding on both sides
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # delegate to the destination-passing function
        target = output_ref_for_wrapper()
        arguments = parameter_ref_for_wrapper()
        code.writeline(f"{target} = {destination_passing_func_name}({arguments})")

        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing pad wrapper into ``code``.

    The generated function takes
    ``(in0, out0, dst_shape, pad_before, pad_after, mode, value=0)``, derives
    the per-dimension valid range and padding flags, and launches the
    generated kernel on a 1-D grid with a block size of 2048.

    Args:
        rank: number of dimensions the generated code is specialized for.
        wrapper_name: name of the generated destination-passing function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer receiving the generated source.

    Returns:
        The same buffer, to allow chaining of generator steps.
    """
    dims = range(rank)

    def per_dim(template: str) -> str:
        # Instantiate ``template`` once per dimension and join with commas.
        return ", ".join(template.format(d) for d in dims)

    code.writeline(f"def {wrapper_name}({parameter_for_wrapper_out()}):")

    with code.indent():
        code.writeline("BLOCK_SIZE = 2048")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        for line in (
            "x_shape = in0.shape",
            "in_strides0 = in0.stride()",
            "out_strides = out0.stride()",
        ):
            code.writeline(line)

        # index range of the output that maps onto unpadded input data
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            for d in dims:
                code.writeline(f"valid_dim{d}_start = pad_before[{d}]")
                code.writeline(f"valid_dim{d}_end = dst_shape[{d}] - pad_after[{d}]")
        code.newline()

        code.writeline("# Check which dimensions have padding")
        for d in dims:
            code.writeline(
                f"dim{d}_has_pad = pad_before[{d}] > 0 or pad_after[{d}] > 0"
            )
        for flag, mode_name in (
            ("IS_CONSTANT", "constant"),
            ("IS_REFLECT", "reflect"),
            ("IS_REPLICATE", "replicate"),
            ("IS_CIRCULAR", "circular"),
        ):
            code.writeline(f"{flag} = mode == '{mode_name}'")
        code.newline()

        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    per_dim_arguments = (
                        ("x_shape[{}]", "shape for x"),
                        ("in_strides0[{}]", "stride for x"),
                        ("out_strides[{}]", "stride for out"),
                        ("valid_dim{}_start", "valid dim start"),
                        ("valid_dim{}_end", "valid dim end"),
                        ("bool(dim{}_has_pad)", "dim has padding flags"),
                    )
                    for template, note in per_dim_arguments:
                        code.writeline(f"{per_dim(template)}, # {note}")

                for argument in (
                    "in0.numel(), ",
                    "out0.numel(), ",
                    "value, ",
                    "IS_CONSTANT, ",
                    "IS_REFLECT, ",
                    "IS_REPLICATE, ",
                    "IS_CIRCULAR, ",
                    "BLOCK_SIZE, ",
                ):
                    code.writeline(argument)
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the rank-specialized Triton pad kernel into ``code``.

    Each lane of the generated kernel handles one linear output offset: the
    offset is decomposed into per-dimension output indices using the output
    strides, mapped to a source index according to the padding mode
    (constant / reflect / replicate / circular), loaded from the input and
    stored to the output.

    The load mask also requires ``offset < out_elem_cnt`` and non-negative
    source indices. Lanes past the end of the output (the last block is
    usually only partially filled) can otherwise compute negative source
    indices in reflect mode and read outside the input buffer.

    Args:
        rank: number of tensor dimensions the kernel is specialized for.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer receiving the generated source.

    Returns:
        The same buffer, to allow chaining of generator steps.
    """
    code.newline()

    # decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # input shape
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # input strides
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # output strides
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")

            # first output index (per dim) that maps to input data
            start_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{start_args}, # valid dim start")

            # one past the last output index (per dim) that maps to input data
            end_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{end_args}, # valid dim end")

        for i in range(rank):
            code.writeline(f"dim{i}_has_pad: tl.constexpr, ")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        # decompose the linear output offset into per-dimension indices
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        # if_pad marks lanes whose output position lies in the padded border
        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        # default mapping: shift by the leading pad, clamping negatives to 0
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"valid_dim{i}_start - dst_index_{i}, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, "
                    f"src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), 0, src_index_{i})"
                )
            for i in range(rank):
                end_cond = f"dst_index_{i} >= valid_dim{i}_end"
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & ({end_cond}), "
                    f"x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"dst_index_{i} - valid_dim{i}_end, src_index_{i})"
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        # Only lanes that belong to the output and whose source index is
        # inside the input may touch memory.
        code.writeline("load_cond = offset < out_elem_cnt")
        for i in range(rank):
            code.writeline(
                f"load_cond &= (src_index_{i} >= 0) & (src_index_{i} < x_shape{i})"
            )

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the full source of a rank-specialized pad module.

    The rank is taken from the shape of ``inputs[0]``; it is the only
    property of the arguments that influences the generated code.

    Args:
        inputs: positional call arguments; the first one must expose ``shape``.
        wrapper_name: name of the functional wrapper to generate.
        destination_passing_func_name: name of the destination-passing wrapper.
        kernel_name: name of the Triton kernel to generate.
        code: buffer receiving the generated source.

    Returns:
        The buffer holding imports, both wrappers and the kernel.
    """
    rank = len(inputs[0].shape)

    buffer = generate_imports(code)
    buffer = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, buffer
    )
    buffer = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, buffer
    )
    return generate_pad_kernel(rank, kernel_name, buffer)
406class PadFunction:
407 """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction
408 which are specialized according to the rank of the task space(the broadcasted shape of all input tensors).
409 The generated code are written out to the cache directory (defaults to ~/.flaggems).
410 """
412 def __init__(self):
413 self.pid = os.getpid()
414 self.overloads: Mapping[str, Callable] = {}
416 def __call__(self, *args, **kwargs):
417 # note: kwargs should not be used in JITFunction directly
418 key = f"{self.arg_key(*args)}"
419 if key in self.overloads:
420 overload = self.overloads[key]
421 else:
422 # generate file & import it
423 code = IndentedBuffer()
424 code = generate_code(
425 args,
426 "_pad_wrapper",
427 "_pad_wrapper_out",
428 "_pad_jit_function",
429 code,
430 )
432 file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"
434 with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
435 f.write(code.getvalue())
437 # load
438 spec = importlib.util.spec_from_file_location(
439 f"_gen_module_rank_{key}_pid_{self.pid}",
440 f.name,
441 )
443 m = importlib.util.module_from_spec(spec)
444 # do not expose it to sys.modules
445 # sys.modules["_add_module"] = m
446 spec.loader.exec_module(m)
447 overload = getattr(m, "_pad_wrapper")
448 self.overloads[key] = overload
449 return overload(*args, **kwargs)
451 def arg_key(self, *args):
452 tensors = [item for item in args if torch.is_tensor(item)]
453 max_rank = max(item.ndim for item in tensors)
454 return max_rank
457_pad_func = PadFunction()
# Autotuned over BLOCK_SIZE in {2**10, 2**12, 2**14} and num_stages in {1, 3};
# the tuning result is keyed on ``inp_elements``.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**n}, num_stages=s)
        for n in range(10, 16, 2)
        for s in [1, 3]
    ],
    key=["inp_elements"],
)
@triton.jit
def pad_1d_constant_kernel(
    inp_ptr,
    out_ptr,
    inp_elements,
    pad_value,
    pad_left,
    pad_right,
    BLOCK_SIZE: tl.constexpr,
):
    # Constant padding of a 1-D buffer.
    #   inp_ptr / out_ptr: input and output buffers, indexed linearly
    #   inp_elements:      number of input elements
    #   pad_value:         value written to the padded positions
    #   pad_left/right:    number of padded elements before / after the input
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Each program starts at its own block and advances by the total number
    # of elements covered by all programs, so any grid size covers the output.
    start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    out_elements = pad_left + inp_elements + pad_right
    for off in range(start, out_elements, step):
        # Output position -> input position; positions outside the input are
        # masked out and receive ``pad_value`` through ``other``.
        inp_offset = off + tl.arange(0, BLOCK_SIZE) - pad_left
        inp_mask = inp_offset >= 0 and inp_offset < inp_elements
        inp = tl.load(inp_ptr + inp_offset, mask=inp_mask, other=pad_value)
        out_offset = off + tl.arange(0, BLOCK_SIZE)
        # Guard the tail of the last block.
        out_mask = out_offset < out_elements
        tl.store(out_ptr + out_offset, inp, mask=out_mask)
# Autotuned over BLOCK_H in {1, 4, 8, 12, 16, 24} and num_stages in {1, 3};
# the tuning result is keyed on ``H`` and ``W``.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": n}, num_stages=s)
        for n in [1, 4, 8, 12, 16, 24]
        for s in [1, 3]
    ],
    key=["H", "W"],
)
@triton.jit
def pad_2d_constant_kernel(
    inp_ptr,
    out_ptr,
    H,
    W: tl.constexpr,
    pad_value,
    pad_left: tl.constexpr,
    pad_right: tl.constexpr,
    pad_top,
    pad_bottom,
    BLOCK_H: tl.constexpr,
):
    # Constant padding of an (H, W) buffer into an (out_H, out_W) buffer.
    # Input rows are addressed with a row pitch of W and output rows with a
    # row pitch of out_W, i.e. both buffers are indexed as dense row-major.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Each program handles BLOCK_H output rows at a time and advances by the
    # number of rows covered by all programs.
    block_start = pid * BLOCK_H
    step = num_jobs * BLOCK_H
    # W, pad_left and pad_right are constexpr, so out_W is a compile-time
    # value and can be used as the extent of tl.arange below.
    # NOTE(review): upstream Triton documents tl.arange extents as powers of
    # two; confirm the Cambricon backend accepts an arbitrary out_W here.
    out_W: tl.constexpr = pad_left + W + pad_right
    out_H = pad_top + H + pad_bottom
    for batch_idx in range(block_start, out_H, step):
        # Input coordinates corresponding to this tile of output rows.
        offset_h = tl.arange(0, BLOCK_H) + batch_idx - pad_top
        offset_w = tl.arange(0, out_W) - pad_left
        offsets = offset_h[:, None] * W + offset_w[None, :]
        # Positions outside the input are masked out and receive
        # ``pad_value`` through ``other``.
        mask = (offset_h[:, None] >= 0 and offset_h[:, None] < H) and (
            offset_w[None, :] >= 0 and offset_w[None, :] < W
        )
        inp = tl.load(inp_ptr + offsets, mask=mask, other=pad_value)

        out_offset_c = tl.arange(0, out_W)
        out_offset_n = tl.arange(0, BLOCK_H) + batch_idx
        out_offsets = out_offset_n[:, None] * out_W + out_offset_c[None, :]
        # Guard rows past the end of the output in the last tile.
        out_mask = out_offset_n[:, None] < out_H and out_offset_c[None, :] < out_W
        tl.store(out_ptr + out_offsets, inp, mask=out_mask)
def pad(self, pad, mode="constant", value=None):
    """Pad tensor ``self`` and return the result as a new tensor.

    Constant padding of 1-D, 2-D and 3-D tensors uses the dedicated
    ``pad_1d_constant_kernel`` / ``pad_2d_constant_kernel`` kernels (the 3-D
    case runs the 2-D kernel once per slice along dimension 0). Every other
    case is forwarded to the rank-specialized generated implementation
    ``_pad_func``.

    Args:
        self: input tensor.
        pad: flat sequence of ``(before, after)`` amounts, listed starting
            from the last dimension.
        mode: ``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``.
        value: fill value for constant mode; ``None`` means ``0.0``.

    Returns:
        The padded tensor, with the dtype and device of ``self``.

    Raises:
        AssertionError: if ``pad`` has odd length, or the padding is too large
            for ``reflect`` / ``circular`` mode.
        ValueError: if ``pad`` has more pairs than ``self`` has dimensions, or
            ``mode`` is not one of the supported modes.
    """
    logger.debug("GEMS_CAMBRICON CONSTANT PAD ND")

    ndim = self.ndim
    pad_size = len(pad)
    assert pad_size % 2 == 0
    pad_pairs = pad_size // 2

    # Without this check the ``ndim - i - 1`` indexing below wraps around via
    # negative indices and silently overwrites earlier entries.
    if pad_pairs > ndim:
        raise ValueError(
            f"pad has {pad_pairs} (before, after) pairs but the input only has "
            f"{ndim} dimension(s)"
        )
    # An unknown mode would otherwise reach the kernel with every mode flag
    # disabled and produce meaningless border values.
    if mode not in ("constant", "reflect", "replicate", "circular"):
        raise ValueError(f"unsupported padding mode: {mode!r}")

    if value is None:
        value = 0.0

    # Only build the output here when one of the fast paths will use it.
    if mode == "constant" and ndim in (1, 2, 3):
        pad_before = [0] * ndim
        pad_after = [0] * ndim
        for i in range(pad_pairs):
            pad_before[ndim - i - 1] = pad[2 * i]
            pad_after[ndim - i - 1] = pad[2 * i + 1]

        inp_shape = list(self.shape)
        out_shape = [
            size + before + after
            for size, before, after in zip(inp_shape, pad_before, pad_after)
        ]
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)

        if ndim == 1:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            pad_1d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                value,
                pad_before[-1],
                pad_after[-1],
            )
            return out

        if ndim == 2:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_H"]), TOTAL_CORE_NUM),
            )
            pad_2d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                inp_shape[1],
                value,
                pad_before[-1],
                pad_after[-1],
                pad_before[-2],
                pad_after[-2],
            )
            return out

        # ndim == 3: fill the slices added along dimension 0, then pad every
        # input slice with the 2-D kernel.
        front = pad_before[0]
        back_start = front + inp_shape[0]
        out[:front] = torch.full(
            out[0:front].shape,
            value,
            dtype=self.dtype,
            device=self.device,
        )
        out[back_start:] = torch.full(
            out[back_start:].shape,
            value,
            dtype=self.dtype,
            device=self.device,
        )

        # The grid does not depend on the slice, so build it once.
        grid = lambda meta: (
            min(triton.cdiv(out_shape[1], meta["BLOCK_H"]), TOTAL_CORE_NUM),
        )
        for i in range(inp_shape[0]):
            pad_2d_constant_kernel[grid](
                self[i].contiguous(),
                out[front + i],
                inp_shape[1],
                inp_shape[2],
                value,
                pad_before[-1],
                pad_after[-1],
                pad_before[-2],
                pad_after[-2],
            )
        return out

    # The assertions below are kept as assertions so that the exception type
    # seen by existing callers does not change.
    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    return _pad_func(self, pad, mode, float(value))
def constant_pad_nd(self, pad_list, value=0):
    """Pad ``self`` with a constant ``value``; thin wrapper around ``pad``.

    Args:
        self: input tensor.
        pad_list: flat sequence of ``(before, after)`` amounts, forwarded to
            ``pad`` unchanged.
        value: fill value for the padded positions.

    Returns:
        Whatever ``pad`` returns for ``mode="constant"``.
    """
    padded = pad(self, pad_list, "constant", value)
    return padded