Coverage for src/flag_gems/ops/reflection_pad1d_backward.py: 47%
89 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
# Module-level logger; used for per-call debug tracing of this op.
logger = logging.getLogger(__name__)
@triton.jit
def reflection_pad1d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    B,
    W_in,
    pad_left,
    W_out,
    BLOCK_W: tl.constexpr,
):
    """Scatter-add one tile of a grad_output row into grad_input.

    Program (pid_b, pid_w) handles output columns
    [pid_w * BLOCK_W, (pid_w + 1) * BLOCK_W) of row ``pid_b``. Both buffers
    are addressed as contiguous rows (length W_out for grad_output, W_in for
    grad_input). Each output column is mapped back to the input column it
    was reflected from, and its gradient is added there with
    ``tl.atomic_add``, because several output columns can land on the same
    input column. Values are accumulated in grad_input's element type.
    """
    # Promote the row index to int64: pid_b * W_out overflows int32 on
    # tensors with >= 2**31 elements.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask_out = offs_w < W_out

    base_out = pid_b * W_out
    base_in = pid_b * W_in

    grad = tl.load(grad_output_ptr + base_out + offs_w, mask=mask_out, other=0.0)
    grad_acc = grad.to(grad_input_ptr.dtype.element_ty)

    # Position relative to the unpadded row; negative = left pad region,
    # >= W_in = right pad region.
    x = offs_w - pad_left
    # Fold |x| into one reflection period of length 2 * (W_in - 1); the
    # second half of the period mirrors back towards index 0.
    period = 2 * (W_in - 1)
    m = tl.abs(x) % period
    iw = tl.where(m < W_in, m, period - m)

    tl.atomic_add(grad_input_ptr + base_in + iw, grad_acc, mask=mask_out)
def _parse_reflection_pad1d_args(grad_output, input, padding):
    """Validate the backward op's arguments and derive its geometry.

    Returns:
        ``(pad_left, pad_right, W_in, W_out, B)`` where ``B`` is the number
        of rows, i.e. the product of all but the last dimension of ``input``.

    Raises:
        ValueError: on malformed padding, an empty width, non-zero padding
            with a width below 2 or padding not smaller than the width, or a
            ``grad_output`` whose shape is not ``(*input.shape[:-1], W_out)``.
    """
    if not isinstance(padding, (list, tuple)) or len(padding) != 2:
        raise ValueError(
            "padding must be a sequence of length 2: (pad_left, pad_right)"
        )
    pad_left, pad_right = int(padding[0]), int(padding[1])
    if pad_left < 0 or pad_right < 0:
        raise ValueError("padding values must be >= 0")
    if input.dim() < 1:
        raise ValueError("input must have at least 1 dimension")

    W_in = int(input.shape[-1])
    if W_in <= 0:
        raise ValueError("last dimension (width) must be > 0")

    if pad_left > 0 or pad_right > 0:
        # A reflection needs at least two columns and must not reach past the
        # opposite edge of the row.
        if W_in < 2:
            raise ValueError(
                "input width must be at least 2 for reflection padding when padding > 0"
            )
        if pad_left >= W_in or pad_right >= W_in:
            raise ValueError(
                "padding values must be less than the input width for reflection padding"
            )

    W_out = W_in + pad_left + pad_right
    leading_shape = tuple(input.shape[:-1])
    expected_shape = leading_shape + (W_out,)
    # The kernels address grad_output as B rows of W_out elements; any other
    # shape would make them read out of bounds.
    if tuple(grad_output.shape) != expected_shape:
        raise ValueError(
            f"grad_output must have shape {expected_shape}, "
            f"got {tuple(grad_output.shape)}"
        )
    # math.prod(()) == 1, so a 1-D input is a single row.
    B = int(math.prod(leading_shape))
    return pad_left, pad_right, W_in, W_out, B


def _copy_rows_into(src, dst, B, W, block_w):
    """Copy B contiguous rows of W elements from ``src`` into ``dst``.

    ``tl.store`` converts values to ``dst``'s element type, so this also
    performs any dtype cast. Returns ``dst``.
    """
    grid = (B, triton.cdiv(W, block_w))
    with torch_device_fn.device(dst.device):
        _copy_rows_kernel[grid](src, dst, B, W, BLOCK_W=block_w)
    return dst


def _launch_reflection_pad1d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    padding,
):
    """Compute the gradient of 1-D reflection padding w.r.t. its input.

    Args:
        grad_output: Gradient of the padded tensor, shaped
            ``(*input.shape[:-1], input.shape[-1] + pad_left + pad_right)``.
        input: The tensor that was padded; only its shape, dtype and device
            are used.
        padding: Length-2 list/tuple ``(pad_left, pad_right)`` of
            non-negative ints.

    Returns:
        A new contiguous tensor with ``input``'s shape, dtype and device.

    Raises:
        ValueError: see ``_parse_reflection_pad1d_args``.
    """
    pad_left, pad_right, W_in, W_out, B = _parse_reflection_pad1d_args(
        grad_output, input, padding
    )
    block_w = 256

    if B == 0:
        # Empty batch: nothing to launch, the empty result is already correct.
        return torch.empty_like(input, memory_format=torch.contiguous_format)

    grad_output = grad_output.contiguous()

    if pad_left == 0 and pad_right == 0:
        # Identity padding: the gradient is grad_output cast to input's dtype.
        # One casting copy replaces the float32 round trip, which also
        # truncated float64 gradients.
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
        return _copy_rows_into(grad_output, result, B, W_in, block_w)

    # Several output columns fold onto one input column, so accumulate in a
    # wide buffer: float32 by default, float64 for float64 inputs so no
    # precision is lost.
    acc_dtype = torch.float64 if input.dtype == torch.float64 else torch.float32
    grad_input = torch.zeros_like(
        input, dtype=acc_dtype, memory_format=torch.contiguous_format
    )
    grid = (B, triton.cdiv(W_out, block_w))
    with torch_device_fn.device(input.device):
        reflection_pad1d_backward_kernel[grid](
            grad_output, grad_input, B, W_in, pad_left, W_out, BLOCK_W=block_w
        )
    if grad_input.dtype == input.dtype:
        return grad_input
    result = torch.empty_like(input, memory_format=torch.contiguous_format)
    return _copy_rows_into(grad_input, result, B, W_in, block_w)
@triton.jit
def _copy_rows_kernel_f32(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Copy ``B`` contiguous rows of length ``W``, converting values to float32.

    Program (pid_b, pid_w) copies columns [pid_w * BLOCK_W, (pid_w + 1) *
    BLOCK_W) of row ``pid_b``; lanes past ``W`` or rows >= ``B`` are masked.
    """
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = (offs_w < W) & (pid_b < B)

    # int64 row base: pid_b * W overflows int32 on tensors with >= 2**31
    # elements.
    base = pid_b.to(tl.int64) * W
    vals = tl.load(in_ptr + base + offs_w, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + base + offs_w, vals, mask=mask)
@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Copy ``B`` contiguous rows of length ``W`` from ``in_ptr`` to ``out_ptr``.

    No explicit cast is done here; ``tl.store`` converts the loaded values to
    ``out_ptr``'s element type, so this kernel also serves as a dtype cast.
    Program (pid_b, pid_w) copies columns [pid_w * BLOCK_W, (pid_w + 1) *
    BLOCK_W) of row ``pid_b``; lanes past ``W`` or rows >= ``B`` are masked.
    """
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = (offs_w < W) & (pid_b < B)

    # int64 row base: pid_b * W overflows int32 on tensors with >= 2**31
    # elements.
    base = pid_b.to(tl.int64) * W
    vals = tl.load(in_ptr + base + offs_w, mask=mask, other=0)
    tl.store(out_ptr + base + offs_w, vals, mask=mask)
def reflection_pad1d_backward(grad_output: torch.Tensor, input: torch.Tensor, padding):
    """Public entry point for the backward pass of 1-D reflection padding.

    Emits a debug log line, then delegates to
    ``_launch_reflection_pad1d_backward``. That function validates
    ``padding``/``input`` and runs the Triton kernels that fold
    ``grad_output`` back onto the input width.
    """
    logger.debug("GEMS REFLECTION_PAD1D_BACKWARD")
    grad_input = _launch_reflection_pad1d_backward(grad_output, input, padding)
    return grad_input