Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/reflection_pad2d.py: 0%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
# Module-level logger; the public entry points in this module emit a DEBUG
# record on each call.
logger = logging.getLogger(__name__)
@triton.jit
def reflection_pad2d_kernel(
    in_ptr,
    out_ptr,
    B,
    H_in,
    W_in,
    pad_left,
    pad_top,
    H_out,
    W_out,
    BLOCK_HW: tl.constexpr,
):
    # Reflection-pad one (H_in, W_in) plane into one (H_out, W_out) plane.
    #
    # Both planes are addressed as dense row-major buffers: batch row `pid_b`
    # starts at pid_b * H_in * W_in in the input and pid_b * H_out * W_out in
    # the output. The launcher in this file uses grid
    # (B, cdiv(H_out * W_out, BLOCK_HW)): axis 0 picks the batch row and
    # axis 1 picks a BLOCK_HW-wide chunk of the flattened output plane.
    #
    # `B` is not read by the kernel. pad_right / pad_bottom are not passed
    # explicitly; they are implied by H_out / W_out.
    #
    # Preconditions (enforced by the caller, not checked here):
    #   * H_in >= 2 and W_in >= 2, otherwise the reflection period
    #     2 * (size - 1) below is 0 and the `%` divides by zero;
    #   * every pad is < the corresponding input size, so one reflection
    #     is enough to land inside the input.
    pid_b = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Use modulo wrap to keep all store indices in [0, H_out * W_out).
    # On KunlunXin, masked tl.store does not suppress writes for masked-out
    # threads without TRITONXPU_STORE_MASK_SIM=1, causing corruption of
    # adjacent batch row data. The modulo wrap means tail-block threads simply
    # re-write already-computed values to valid positions — harmless.
    HW_out = H_out * W_out
    offs_n = (pid_n * BLOCK_HW + tl.arange(0, BLOCK_HW)) % HW_out

    # Decode to (h, w) coordinates
    h_idx = offs_n // W_out
    w_idx = offs_n % W_out

    # NOTE(review): these products are formed from 32-bit program ids / int
    # arguments; for planes or batches large enough that B * H * W exceeds
    # 2**31 they may overflow — confirm the supported size limits.
    base_in = pid_b * (H_in * W_in)
    base_out = pid_b * HW_out

    # Compute reflected indices for height.
    # y ranges over [-pad_top, H_in - 1 + pad_bottom]. |y| folded with period
    # pH = 2 * (H_in - 1) lands in [0, pH); values below H_in are already
    # valid rows, larger ones are mirrored about row H_in - 1 as pH - m_h.
    y = h_idx.to(tl.int32) - pad_top
    Hm1 = H_in - 1
    pH = 2 * Hm1
    t_h = tl.abs(y)
    m_h = t_h % pH
    ih = tl.where(m_h < H_in, m_h, pH - m_h)

    # Compute reflected indices for width (same folding as for height, with
    # period pW = 2 * (W_in - 1)).
    x = w_idx.to(tl.int32) - pad_left
    Wm1 = W_in - 1
    pW = 2 * Wm1
    t_w = tl.abs(x)
    m_w = t_w % pW
    iw = tl.where(m_w < W_in, m_w, pW - m_w)

    # No mask needed: offs_n is in [0, HW_out) and in_offs is in [0, H_in*W_in)
    # (the latter relies on the pad < size precondition above).
    in_offs = ih * W_in + iw
    vals = tl.load(in_ptr + base_in + in_offs)
    tl.store(out_ptr + base_out + offs_n, vals)
@triton.jit
def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr):
    # Copy one dense (H, W) plane per batch row from in_ptr to out_ptr.
    # Grid axis 0 selects the batch row; axis 1 selects a BLOCK_HW-wide chunk
    # of the flattened plane. `B` is not read by the kernel.
    row = tl.program_id(axis=0)
    chunk = tl.program_id(axis=1)

    plane = H * W
    # KunlunXin workaround: rather than masking the tail chunk (masked stores
    # may still write on this backend), wrap overflow lanes back into
    # [0, plane) so they just re-copy elements that are already valid.
    lanes = chunk * BLOCK_HW + tl.arange(0, BLOCK_HW)
    idx = lanes % plane

    row_start = row * plane
    tl.store(out_ptr + row_start + idx, tl.load(in_ptr + row_start + idx))
def _parse_padding(padding):
    """Validate ``padding`` and return it as four ints (left, right, top, bottom).

    Raises:
        ValueError: if ``padding`` is not a list/tuple of length 4, or if any
            entry is negative.
    """
    if not isinstance(padding, (list, tuple)):
        raise ValueError("padding must be a sequence")
    if len(padding) != 4:
        raise ValueError(
            "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)"
        )
    pads = tuple(int(p) for p in padding)
    if any(p < 0 for p in pads):
        raise ValueError("padding values must be >= 0")
    return pads


def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Reflection-pad the last two dimensions of ``input``.

    Every dimension before the last two is flattened into a single batch axis
    and each (H, W) plane is padded independently.

    Args:
        input: tensor with at least 3 dimensions; the last two are (H, W).
        padding: ``(pad_left, pad_right, pad_top, pad_bottom)``; every value
            must be >= 0 and strictly less than the matching input size.
        out: optional destination of shape
            ``(*input.shape[:-2], H + pad_top + pad_bottom,
            W + pad_left + pad_right)`` with the same dtype and device as
            ``input``. It may be non-contiguous; the result is written into
            it in place.

    Returns:
        ``out`` when it is given, otherwise a newly allocated tensor.

    Raises:
        ValueError: on malformed padding, an input with fewer than 3 dims,
            empty spatial dims, padding too large for reflection (or any
            padding on a spatial size of 1), or an ``out`` tensor with the
            wrong shape, dtype or device.
    """
    pad_left, pad_right, pad_top, pad_bottom = _parse_padding(padding)

    if input.dim() < 3:
        raise ValueError("input must have at least 3 dimensions")

    x = input.contiguous()
    H_in = int(x.shape[-2])
    W_in = int(x.shape[-1])
    # Checked first: it used to follow the ">= 2" check and was unreachable,
    # so empty spatial dims were reported with the wrong message.
    if H_in <= 0 or W_in <= 0:
        raise ValueError("spatial dimensions must be > 0")

    has_padding = (pad_left, pad_right, pad_top, pad_bottom) != (0, 0, 0, 0)
    # The reflection kernel folds indices modulo 2 * (size - 1), which is 0
    # for a size-1 axis, so it needs both sizes >= 2. With no padding at all
    # the plain copy kernel is used instead and has no such restriction.
    if has_padding and (H_in < 2 or W_in < 2):
        raise ValueError(
            "input spatial dimensions must be at least 2 for reflection padding when padding > 0"
        )
    if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in:
        raise ValueError(
            "padding values must be less than the input spatial dimensions for reflection padding"
        )

    H_out = H_in + pad_top + pad_bottom
    W_out = W_in + pad_left + pad_right

    leading_shape = tuple(x.shape[:-2])
    B = int(math.prod(leading_shape))  # math.prod(()) == 1
    expected_shape = (*leading_shape, H_out, W_out)

    if out is None:
        out = torch.empty(expected_shape, device=x.device, dtype=x.dtype)
    else:
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")

    # The kernels write a dense row-major buffer. A non-contiguous `out` is
    # filled through a temporary and copied back, so the caller's tensor is
    # really updated (previously `out.contiguous()` returned a detached copy
    # and the caller's `out` was left untouched).
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    # Nothing to compute for an empty batch; skip launching an empty grid.
    if dst.numel() > 0:
        BLOCK_HW = 256
        with torch_device_fn.device(x.device):
            if has_padding:
                grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW))
                reflection_pad2d_kernel[grid](
                    x,
                    dst,
                    B,
                    H_in,
                    W_in,
                    pad_left,
                    pad_top,
                    H_out,
                    W_out,
                    BLOCK_HW=BLOCK_HW,
                )
            else:
                # No padding: a straight copy of every plane.
                grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW))
                copy_tensor_kernel[grid](x, dst, B, H_in, W_in, BLOCK_HW=BLOCK_HW)

    if dst is not out:
        out.copy_(dst)
    return out
def reflection_pad2d(input: torch.Tensor, padding):
    """Return ``input`` reflection-padded by ``padding`` in a freshly allocated tensor.

    ``padding`` is ``(pad_left, pad_right, pad_top, pad_bottom)``; validation
    is delegated to ``launch_reflection_pad2d``.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD2D")
    padded = launch_reflection_pad2d(input, padding, out=None)
    return padded
def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Reflection-pad ``input`` by ``padding`` using ``out`` as the destination.

    Delegates to ``launch_reflection_pad2d``, which checks ``out``'s shape,
    dtype and device and returns the tensor holding the padded result.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD2D_OUT")
    result = launch_reflection_pad2d(input, padding, out=out)
    return result