Coverage for src/flag_gems/ops/reflection_pad1d_backward.py: 47%

89 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@triton.jit
def reflection_pad1d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    B,
    W_in,
    pad_left,
    W_out,
    BLOCK_W: tl.constexpr,
):
    """Scatter-add one tile of a ``grad_output`` row into ``grad_input``.

    Program axis 0 selects the row and axis 1 selects a ``BLOCK_W``-wide tile
    of that row. Both tensors are addressed as dense row-major 2-D buffers:
    ``grad_output`` with row length ``W_out`` and ``grad_input`` with row
    length ``W_in``.

    Each padded position ``w`` is mapped back to the input column it mirrors
    and its gradient is accumulated there with an atomic add, because several
    padded positions can map to the same input column.

    Args:
        grad_output_ptr: Pointer to the padded-gradient buffer.
        grad_input_ptr: Pointer to the accumulation buffer. The launcher in
            this file passes a zero-initialised float32 tensor.
        B: Number of rows. Unused here; the grid's axis 0 defines the rows.
        W_in: Row length of the unpadded input. Must be >= 2, otherwise the
            reflection period below is zero and is used as a modulus.
        pad_left: Number of padded columns on the left of every row.
        W_out: Row length of the padded output.
        BLOCK_W: Compile-time tile width along the row.
    """
    # Widen the row index to 64 bits so that the row base offsets below do not
    # wrap around in 32-bit arithmetic for tensors with >= 2**31 elements.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask_out = offs_w < W_out

    base_out = pid_b * W_out
    base_in = pid_b * W_in

    # Load the gradient tile and promote to float32 for accumulation.
    grad = tl.load(grad_output_ptr + base_out + offs_w, mask=mask_out, other=0.0)
    grad_f32 = grad.to(tl.float32)

    # Position relative to the first unpadded column (negative in left pad).
    x = offs_w.to(tl.int32) - pad_left

    # Reflection without repeating the edge has period 2 * (W_in - 1):
    # fold |x| into one period, then mirror the second half back.
    Wm1 = W_in - 1
    p = 2 * Wm1

    t = tl.abs(x)
    m = t % p
    iw = tl.where(m < W_in, m, p - m)

    # Atomic add: multiple output columns may target the same input column.
    grad_input_offset = base_in + iw
    tl.atomic_add(grad_input_ptr + grad_input_offset, grad_f32, mask=mask_out)

49 

50 

51def _launch_reflection_pad1d_backward( 

52 grad_output: torch.Tensor, 

53 input: torch.Tensor, 

54 padding, 

55): 

56 if not isinstance(padding, (list, tuple)) or len(padding) != 2: 

57 raise ValueError( 

58 "padding must be a sequence of length 2: (pad_left, pad_right)" 

59 ) 

60 pad_left, pad_right = int(padding[0]), int(padding[1]) 

61 if pad_left < 0 or pad_right < 0: 

62 raise ValueError("padding values must be >= 0") 

63 if input.dim() < 1: 

64 raise ValueError("input must have at least 1 dimension") 

65 

66 grad_output = grad_output.contiguous() 

67 x = input.contiguous() 

68 

69 W_in = int(x.shape[-1]) 

70 if W_in <= 0: 

71 raise ValueError("last dimension (width) must be > 0") 

72 

73 W_out = W_in + pad_left + pad_right 

74 leading_shape = x.shape[:-1] 

75 B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1 

76 

77 # Initialize grad_input as float32 for safe accumulation 

78 grad_input = torch.zeros_like(x, dtype=torch.float32) 

79 

80 # No padding case - just copy 

81 if pad_left == 0 and pad_right == 0: 

82 if W_out != W_in: 

83 raise RuntimeError( 

84 "Internal error: W_out should equal W_in when no padding" 

85 ) 

86 grid = (B, triton.cdiv(W_in, 256)) 

87 with torch_device_fn.device(x.device): 

88 _copy_rows_kernel_f32[grid](grad_output, grad_input, B, W_in, BLOCK_W=256) 

89 if grad_input.dtype == x.dtype: 

90 return grad_input 

91 result = torch.empty_like(x) 

92 with torch_device_fn.device(x.device): 

93 _copy_rows_kernel[grid](grad_input, result, B, W_in, BLOCK_W=256) 

94 return result 

95 

96 # Validate input dimensions 

97 if W_in < 2: 

98 raise ValueError( 

99 "input width must be at least 2 for reflection padding when padding > 0" 

100 ) 

101 if pad_left >= W_in or pad_right >= W_in: 

102 raise ValueError( 

103 "padding values must be less than the input width for reflection padding" 

104 ) 

105 

106 grid = (B, triton.cdiv(W_out, 256)) 

107 with torch_device_fn.device(x.device): 

108 reflection_pad1d_backward_kernel[grid]( 

109 grad_output, grad_input, B, W_in, pad_left, W_out, BLOCK_W=256 

110 ) 

111 if grad_input.dtype == x.dtype: 

112 return grad_input 

113 result = torch.empty_like(x) 

114 cast_grid = (B, triton.cdiv(W_in, 256)) 

115 with torch_device_fn.device(x.device): 

116 _copy_rows_kernel[cast_grid](grad_input, result, B, W_in, BLOCK_W=256) 

117 return result 

118 

119 

@triton.jit
def _copy_rows_kernel_f32(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Copy one tile of a row from ``in_ptr`` to ``out_ptr`` as float32 values.

    Both buffers are addressed as dense row-major 2-D arrays with ``B`` rows
    of ``W`` elements. Program axis 0 selects the row and axis 1 selects a
    ``BLOCK_W``-wide tile of it. Loaded values are converted to float32 before
    being stored.

    Args:
        in_ptr: Pointer to the source buffer.
        out_ptr: Pointer to the destination buffer.
        B: Number of rows; rows at or beyond ``B`` are masked out.
        W: Row length; columns at or beyond ``W`` are masked out.
        BLOCK_W: Compile-time tile width along the row.
    """
    # Widen the row index to 64 bits so that ``pid_b * W`` cannot wrap around
    # in 32-bit arithmetic for buffers with >= 2**31 elements.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = (offs_w < W) & (pid_b < B)

    base = pid_b * W
    vals = tl.load(in_ptr + base + offs_w, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + base + offs_w, vals, mask=mask)

131 

132 

@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Copy one tile of a row from ``in_ptr`` to ``out_ptr``.

    Both buffers are addressed as dense row-major 2-D arrays with ``B`` rows
    of ``W`` elements. Program axis 0 selects the row and axis 1 selects a
    ``BLOCK_W``-wide tile of it. No explicit conversion is done here; the
    launcher in this file uses this kernel to move a float32 buffer into a
    tensor of another dtype, relying on ``tl.store`` casting values to the
    destination pointer's element type.

    Args:
        in_ptr: Pointer to the source buffer.
        out_ptr: Pointer to the destination buffer.
        B: Number of rows; rows at or beyond ``B`` are masked out.
        W: Row length; columns at or beyond ``W`` are masked out.
        BLOCK_W: Compile-time tile width along the row.
    """
    # Widen the row index to 64 bits so that ``pid_b * W`` cannot wrap around
    # in 32-bit arithmetic for buffers with >= 2**31 elements.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = (offs_w < W) & (pid_b < B)

    base = pid_b * W
    vals = tl.load(in_ptr + base + offs_w, mask=mask, other=0)
    tl.store(out_ptr + base + offs_w, vals, mask=mask)

144 

145 

def reflection_pad1d_backward(grad_output: torch.Tensor, input: torch.Tensor, padding):
    """Public entry point for the reflection_pad1d backward pass.

    Emits a debug log record and forwards all arguments unchanged to
    ``_launch_reflection_pad1d_backward``, returning whatever it returns.

    Args:
        grad_output: Gradient with respect to the padded output.
        input: The tensor that was padded in the forward pass.
        padding: The padding specification, passed through as given.
    """
    logger.debug("GEMS REFLECTION_PAD1D_BACKWARD")
    grad_input = _launch_reflection_pad1d_backward(
        grad_output=grad_output,
        input=input,
        padding=padding,
    )
    return grad_input

148 return _launch_reflection_pad1d_backward(grad_output, input, padding)