Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/reflection_pad1d.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    # Reflection-pad rows of width W_in into rows of width W_out.
    # Grid axis 0 selects the row, grid axis 1 selects a BLOCK_W-wide tile of
    # output columns. `B` is part of the signature but is not read here.
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Wrap the column indices into [0, W_out) instead of masking the store.
    # Per the original author's note, masked tl.store on KunlunXin does not
    # suppress writes for masked-out lanes unless TRITONXPU_STORE_MASK_SIM=1,
    # which corrupted the neighbouring row. With the wrap, lanes past the end
    # of the last tile recompute and rewrite valid columns of the same row,
    # which is harmless.
    lanes = tl.arange(0, BLOCK_W)
    out_col = (tile * BLOCK_W + lanes) % W_out

    # NOTE(review): the row offsets below use whatever integer width Triton
    # assigns to `row`, `W_in` and `W_out`; confirm they cannot overflow for
    # very large tensors.
    in_row_start = row * W_in
    out_row_start = row * W_out

    # Map each output column to its mirrored source column:
    #   1. shift so that 0 is the first unpadded element,
    #   2. take |.| so that left-pad position -k reads element k,
    #   3. fold with period 2 * (W_in - 1) and mirror the upper half so that
    #      right-pad positions reflect off the last element.
    # The host launcher rejects W_in < 2 before using this kernel, so the
    # period is strictly positive.
    shifted = out_col.to(tl.int32) - pad_left
    last = W_in - 1
    period = 2 * last
    folded = tl.abs(shifted) % period
    src_col = tl.where(folded < W_in, folded, period - folded)

    # Unmasked on purpose: out_col is in [0, W_out) and src_col in [0, W_in).
    row_vals = tl.load(in_ptr + in_row_start + src_col)
    tl.store(out_ptr + out_row_start + out_col, row_vals)
@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    # Copy rows of width W from in_ptr to out_ptr, one BLOCK_W-wide tile per
    # program. Grid axis 0 selects the row, grid axis 1 selects the tile.
    # `B` is part of the signature but is not read here.
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Column indices are wrapped into [0, W) rather than masked (the same
    # KunlunXin masked-store workaround used by the reflection kernel): lanes
    # past the end of the last tile re-copy valid columns of the same row.
    lanes = tl.arange(0, BLOCK_W)
    cols = (tile * BLOCK_W + lanes) % W

    row_start = row * W
    row_vals = tl.load(in_ptr + row_start + cols)
    tl.store(out_ptr + row_start + cols, row_vals)
def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Validate the arguments and run 1-D reflection padding on the last dim.

    Args:
        input: Tensor with at least one dimension and a non-empty last
            dimension. A contiguous version is used for the kernel launch.
        padding: List or tuple ``(pad_left, pad_right)``. Both values are
            converted with ``int`` and must be non-negative; when either is
            non-zero, both must be smaller than the input width.
        out: Optional destination tensor. It must have shape
            ``(*input.shape[:-1], W_in + pad_left + pad_right)`` and the same
            dtype and device as ``input``. It may be non-contiguous.

    Returns:
        ``out`` itself when it was provided, otherwise a newly allocated
        tensor holding the padded result.

    Raises:
        ValueError: If ``padding`` is malformed or negative, ``input`` has no
            dimensions or an empty last dimension, ``out`` does not match the
            expected shape/dtype/device, or the reflection constraints
            (width >= 2, pads < width) are violated.
    """
    block_w = 256  # tile width along the padded dimension

    if not isinstance(padding, (list, tuple)) or len(padding) != 2:
        raise ValueError(
            "padding must be a sequence of length 2: (pad_left, pad_right)"
        )
    pad_left, pad_right = int(padding[0]), int(padding[1])
    if pad_left < 0 or pad_right < 0:
        raise ValueError("padding values must be >= 0")
    if input.dim() < 1:
        raise ValueError("input must have at least 1 dimension")

    x = input.contiguous()
    W_in = int(x.shape[-1])
    if W_in <= 0:
        raise ValueError("last dimension (width) must be > 0")

    W_out = W_in + pad_left + pad_right
    leading_shape = tuple(x.shape[:-1])
    expected_shape = (*leading_shape, W_out)
    # math.prod of an empty tuple is 1, which covers 1-D inputs.
    B = int(math.prod(leading_shape))

    if out is None:
        out = torch.empty(expected_shape, device=x.device, dtype=x.dtype)
    else:
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")

    has_padding = pad_left > 0 or pad_right > 0
    if has_padding:
        # Reflection needs at least one interior element to mirror around.
        if W_in < 2:
            raise ValueError(
                "input width must be at least 2 for reflection padding when padding > 0"
            )
        if pad_left >= W_in or pad_right >= W_in:
            raise ValueError(
                "padding values must be less than the input width for reflection padding"
            )

    # The kernels address the destination as dense rows of W_out elements.
    # For a non-contiguous `out`, write into a dense scratch buffer and copy
    # it back afterwards. Rebinding `out` to `out.contiguous()` would leave
    # the caller's tensor untouched. The scratch buffer is fully overwritten,
    # so it does not need the old contents of `out`.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    # With an empty leading batch there is nothing to write; skip the launch
    # instead of building a grid with a zero-sized axis.
    if B > 0:
        grid = (B, triton.cdiv(W_out, block_w))
        with torch_device_fn.device(x.device):
            if has_padding:
                reflection_pad1d_kernel[grid](
                    x, target, B, W_in, pad_left, W_out, BLOCK_W=block_w
                )
            else:
                # No padding means W_out == W_in, so this is a plain row copy.
                _copy_rows_kernel[grid](x, target, B, W_in, BLOCK_W=block_w)

    if target is not out:
        out.copy_(target)
    return out
def reflection_pad1d(input: torch.Tensor, padding):
    """Run 1-D reflection padding without a caller-supplied destination.

    Emits a debug log marker, then delegates to ``_launch_reflection_pad1d``
    with ``out=None`` and returns whatever that call returns.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD1D")
    padded = _launch_reflection_pad1d(input, padding, out=None)
    return padded
def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Run 1-D reflection padding with a caller-supplied destination tensor.

    Emits a debug log marker, then delegates to ``_launch_reflection_pad1d``
    passing ``out`` through, and returns whatever that call returns.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD1D_OUT")
    padded = _launch_reflection_pad1d(input, padding, out=out)
    return padded