Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/reflection_pad2d.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def reflection_pad2d_kernel(
    in_ptr,
    out_ptr,
    B,
    H_in,
    W_in,
    pad_left,
    pad_top,
    H_out,
    W_out,
    BLOCK_HW: tl.constexpr,
):
    # Reflection-pad one (H_in, W_in) plane into one (H_out, W_out) plane.
    #
    # Grid axis 0 picks the plane (its index scales both base offsets below);
    # grid axis 1 picks a run of BLOCK_HW flattened output positions.
    # B is accepted but not read inside the kernel.
    plane = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Flattened output positions are wrapped with a modulo instead of being
    # masked, so every store index stays inside [0, H_out * W_out).
    # On KunlunXin, masked tl.store does not suppress writes for masked-out
    # threads without TRITONXPU_STORE_MASK_SIM=1, which corrupts the data of
    # the adjacent batch row. With the wrap, tail-block threads only re-write
    # already-computed values at valid positions, which is harmless.
    out_plane = H_out * W_out
    flat_out = (tile * BLOCK_HW + tl.arange(0, BLOCK_HW)) % out_plane

    # Split the flattened position into output (row, column).
    out_row = flat_out // W_out
    out_col = flat_out % W_out

    src_plane_base = plane * (H_in * W_in)
    dst_plane_base = plane * out_plane

    # Row reflection: fold |row - pad_top| with period 2 * (H_in - 1) and
    # mirror the upper half of the period back into [0, H_in).
    # The period is 0 when H_in == 1, so the caller must rule that out.
    row_period = 2 * (H_in - 1)
    row_fold = tl.abs(out_row.to(tl.int32) - pad_top) % row_period
    src_row = tl.where(row_fold < H_in, row_fold, row_period - row_fold)

    # Column reflection: same folding with period 2 * (W_in - 1).
    col_period = 2 * (W_in - 1)
    col_fold = tl.abs(out_col.to(tl.int32) - pad_left) % col_period
    src_col = tl.where(col_fold < W_in, col_fold, col_period - col_fold)

    # Unmasked on purpose: flat_out lies in [0, H_out * W_out) and the source
    # offset lies in [0, H_in * W_in).
    src_offset = src_row * W_in + src_col
    gathered = tl.load(in_ptr + src_plane_base + src_offset)
    tl.store(out_ptr + dst_plane_base + flat_out, gathered)

64 

65 

@triton.jit
def copy_tensor_kernel(in_ptr, out_ptr, B, H, W, BLOCK_HW: tl.constexpr):
    # Copy one H * W plane element-for-element from in_ptr to out_ptr.
    #
    # Grid axis 0 picks the plane, grid axis 1 a run of BLOCK_HW flattened
    # positions inside it. B is accepted but not read inside the kernel.
    plane = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Positions are wrapped with a modulo rather than masked (the same
    # KunlunXin masked-store workaround used by the padding kernel), so
    # tail-block lanes simply repeat copies of valid elements.
    plane_size = H * W
    flat = (tile * BLOCK_HW + tl.arange(0, BLOCK_HW)) % plane_size

    plane_base = plane * plane_size
    tl.store(out_ptr + plane_base + flat, tl.load(in_ptr + plane_base + flat))

78 

79 

def launch_reflection_pad2d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Reflection-pad the last two dimensions of ``input``.

    All leading dimensions are flattened into one batch axis and one kernel
    program grid of shape ``(batch, ceil(plane_size / BLOCK_HW))`` is launched.
    When every padding value is zero the input is copied instead.

    Args:
        input: Tensor with at least 3 dimensions; the last two are treated as
            (height, width). A contiguous version of it is read by the kernel.
        padding: List or tuple ``(pad_left, pad_right, pad_top, pad_bottom)``.
            Values are converted with ``int`` and must be non-negative and
            strictly smaller than the matching input dimension.
        out: Optional destination tensor. It must have the padded shape and
            the dtype and device of ``input``. It may be non-contiguous; the
            result is then computed in a contiguous scratch tensor and copied
            into ``out``.

    Returns:
        ``out`` itself when it was supplied, otherwise a newly allocated
        tensor of shape ``(*leading_dims, H_out, W_out)``.

    Raises:
        ValueError: If ``padding`` is malformed or out of range, ``input`` has
            fewer than 3 dimensions or an empty spatial dimension, or ``out``
            does not match the expected shape, dtype, or device.
    """
    # Validate padding format
    if not isinstance(padding, (list, tuple)):
        raise ValueError("padding must be a sequence")
    if len(padding) != 4:
        raise ValueError(
            "padding must be a sequence of length 4: (pad_left, pad_right, pad_top, pad_bottom)"
        )
    pad_left, pad_right, pad_top, pad_bottom = (int(p) for p in padding)

    # Validate padding values
    if min(pad_left, pad_right, pad_top, pad_bottom) < 0:
        raise ValueError("padding values must be >= 0")

    # Validate input
    if input.dim() < 3:
        raise ValueError("input must have at least 3 dimensions")

    x = input.contiguous()
    H_in = int(x.shape[-2])
    W_in = int(x.shape[-1])
    has_padding = bool(pad_left or pad_right or pad_top or pad_bottom)

    # The emptiness check has to come first; placed after the "< 2" check it
    # could never fire.
    if H_in <= 0 or W_in <= 0:
        raise ValueError("spatial dimensions must be > 0")
    # The padding kernel reduces indices modulo 2 * (dim - 1), which is zero
    # for a dimension of size 1, so both dimensions must be >= 2 whenever the
    # padding kernel runs. A pure copy (no padding) has no such restriction.
    if has_padding and (H_in < 2 or W_in < 2):
        raise ValueError(
            "input spatial dimensions must be at least 2 for reflection padding when padding > 0"
        )
    if pad_left >= W_in or pad_right >= W_in or pad_top >= H_in or pad_bottom >= H_in:
        raise ValueError(
            "padding values must be less than the input spatial dimensions for reflection padding"
        )

    H_out = H_in + pad_top + pad_bottom
    W_out = W_in + pad_left + pad_right

    leading_shape = tuple(x.shape[:-2])
    expected_shape = (*leading_shape, H_out, W_out)
    B = int(math.prod(leading_shape))  # math.prod(()) == 1

    # Handle output tensor
    if out is None:
        out = torch.empty(expected_shape, device=x.device, dtype=x.dtype)
    else:
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")

    # Empty batch: nothing to compute, and it avoids a zero-sized grid axis.
    if B == 0:
        return out

    # The kernels address the destination as a dense row-major buffer. For a
    # non-contiguous ``out``, write into a dense scratch tensor and copy back
    # afterwards; calling ``out.contiguous()`` would silently redirect the
    # result into a private copy and leave the caller's tensor untouched.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    BLOCK_HW = 256
    with torch_device_fn.device(x.device):
        if has_padding:
            grid = (B, triton.cdiv(H_out * W_out, BLOCK_HW))
            reflection_pad2d_kernel[grid](
                x, dst, B, H_in, W_in, pad_left, pad_top, H_out, W_out, BLOCK_HW=BLOCK_HW
            )
        else:
            # No padding: just copy
            grid = (B, triton.cdiv(H_in * W_in, BLOCK_HW))
            copy_tensor_kernel[grid](x, dst, B, H_in, W_in, BLOCK_HW=BLOCK_HW)

    if dst is not out:
        out.copy_(dst)
    return out

153 

154 

def reflection_pad2d(input: torch.Tensor, padding):
    """Return a new tensor holding ``input`` reflection-padded by ``padding``.

    Logs a debug marker, then delegates to ``launch_reflection_pad2d`` without
    a destination tensor so that the result is freshly allocated.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD2D")
    result = launch_reflection_pad2d(input, padding)
    return result

158 

159 

def reflection_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Reflection-pad ``input`` by ``padding``, passing ``out`` as destination.

    Logs a debug marker, then delegates to ``launch_reflection_pad2d`` and
    returns whatever that call returns.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD2D_OUT")
    result = launch_reflection_pad2d(input, padding, out)
    return result

162 return launch_reflection_pad2d(input, padding, out=out)