Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/reflection_pad1d.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

# Logger named after this module; the public entry points use it for debug tracing.
logger = logging.getLogger(name=__name__)

11 

12 

@triton.jit
def reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    # Reflection-pad one tile of one row.
    #
    # Layout: input row r starts at in_ptr + r * W_in and output row r starts
    # at out_ptr + r * W_out (rows stored back to back).
    # Grid: axis 0 selects the row, axis 1 selects a tile of BLOCK_W output
    # columns. `B` is accepted for signature symmetry but not read here.
    #
    # Preconditions (checked by the host launcher in this file before launch):
    # W_in >= 2 and both pad_left and the right pad (W_out - W_in - pad_left)
    # are < W_in. These keep the reflection period p positive and every
    # reflected index iw inside [0, W_in).
    #
    # NOTE(review): base_in / base_out are products of a program id and a width
    # in the kernel's default integer type (presumably int32). Confirm that
    # B * W_out stays below 2**31, or move to int64 offsets if this backend
    # supports them.
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)

    # Use modulo wrap to keep all store indices in [0, W_out).
    # On KunlunXin, masked tl.store does not suppress writes for masked-out
    # threads without TRITONXPU_STORE_MASK_SIM=1, causing corruption of
    # adjacent batch row data. The modulo wrap means tail-block threads simply
    # re-write already-computed values to valid positions — harmless, because
    # iw below depends only on the (wrapped) column, so the value is identical.
    offs_w = (pid_w * BLOCK_W + tl.arange(0, BLOCK_W)) % W_out

    base_in = pid_b * W_in
    base_out = pid_b * W_out

    # Compute reflected indices.
    # x is the output column mapped into input coordinates: x < 0 inside the
    # left pad, x >= W_in inside the right pad.
    x = offs_w.to(tl.int32) - pad_left  # shift by left pad
    Wm1 = W_in - 1
    p = 2 * Wm1  # period for reflection; guaranteed > 0 when this kernel is used

    # |x| mirrors the left pad about column 0. Folding by p and then mapping
    # m >= W_in to p - m mirrors the right pad about column W_in - 1
    # (e.g. x = W_in -> W_in - 2). Under the preconditions |x| <= p, so the
    # modulo only changes the value when |x| == p, sending it to column 0.
    t = tl.abs(x)
    m = t % p
    iw = tl.where(m < W_in, m, p - m)

    # No mask needed: offs_w is in [0, W_out) and iw is in [0, W_in)
    vals = tl.load(in_ptr + base_in + iw)
    tl.store(out_ptr + base_out + offs_w, vals)

42 

43 

@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    # Copy rows of width W from in_ptr to out_ptr (row r starts at r * W in both
    # buffers). Grid axis 0 picks the row and axis 1 picks a tile of BLOCK_W
    # columns. `B` is accepted but not read here.
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # KunlunXin workaround (same as the reflection kernel): masked stores are
    # not reliably suppressed, so instead of masking the tail tile, out-of-range
    # columns are folded back into [0, W). Folded lanes copy the same element
    # another lane already copies, so the result is unchanged.
    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    cols = cols % W

    row_start = row * W
    src = in_ptr + row_start + cols
    dst = out_ptr + row_start + cols
    tl.store(dst, tl.load(src))

55 

56 

def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Reflection-pad the last dimension of ``input`` on the KunlunXin backend.

    Every leading dimension is flattened into a batch of ``B`` rows; each row of
    width ``W_in`` becomes a row of width ``W_out = W_in + pad_left + pad_right``.

    Args:
        input: Tensor with at least one dimension and a non-empty last
            dimension. Non-contiguous inputs are made contiguous first.
        padding: List or tuple ``(pad_left, pad_right)`` of non-negative ints.
            When either value is positive, ``W_in`` must be at least 2 and both
            values must be smaller than ``W_in``.
        out: Optional destination of shape ``(*input.shape[:-1], W_out)`` with
            the same dtype and device as ``input``. It may be non-contiguous;
            the result is always written into this tensor.

    Returns:
        The padded tensor: ``out`` itself when given, otherwise a new tensor.

    Raises:
        ValueError: malformed or negative padding, a 0-d input, an empty last
            dimension, padding that reflection cannot satisfy, or an ``out``
            whose shape, dtype or device does not match.
    """
    BLOCK_W = 256  # output columns handled by one program along grid axis 1

    if not isinstance(padding, (list, tuple)) or len(padding) != 2:
        raise ValueError(
            "padding must be a sequence of length 2: (pad_left, pad_right)"
        )
    pad_left, pad_right = int(padding[0]), int(padding[1])
    if pad_left < 0 or pad_right < 0:
        raise ValueError("padding values must be >= 0")
    if input.dim() < 1:
        raise ValueError("input must have at least 1 dimension")

    x = input.contiguous()
    W_in = int(x.shape[-1])
    if W_in <= 0:
        raise ValueError("last dimension (width) must be > 0")

    W_out = W_in + pad_left + pad_right
    leading_shape = tuple(x.shape[:-1])
    # math.prod(()) == 1, so a 1-D input is treated as a single row.
    B = math.prod(leading_shape)
    expected_shape = (*leading_shape, W_out)

    if out is None:
        out = torch.empty(expected_shape, device=x.device, dtype=x.dtype)
    else:
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")

    has_padding = pad_left > 0 or pad_right > 0
    if has_padding:
        # Reflection never repeats the edge element, so a row needs at least two
        # elements and each pad must be strictly smaller than the row width.
        if W_in < 2:
            raise ValueError(
                "input width must be at least 2 for reflection padding when padding > 0"
            )
        if pad_left >= W_in or pad_right >= W_in:
            raise ValueError(
                "padding values must be less than the input width for reflection padding"
            )

    # An empty batch has nothing to compute; returning here also avoids
    # launching a grid whose first axis is zero.
    if B == 0:
        return out

    # The kernels address a dense row-major buffer. For a non-contiguous `out`,
    # `out.contiguous()` would return a copy and the caller's tensor would never
    # be written, so compute into a contiguous scratch tensor and copy back.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(expected_shape, device=x.device, dtype=x.dtype)

    with torch_device_fn.device(x.device):
        if has_padding:
            grid = (B, triton.cdiv(W_out, BLOCK_W))
            reflection_pad1d_kernel[grid](
                x, dst, B, W_in, pad_left, W_out, BLOCK_W=BLOCK_W
            )
        else:
            # No padding: W_out == W_in, so this is a plain row copy.
            grid = (B, triton.cdiv(W_in, BLOCK_W))
            _copy_rows_kernel[grid](x, dst, B, W_in, BLOCK_W=BLOCK_W)

    if dst is not out:
        out.copy_(dst)
    return out

118 

119 

def reflection_pad1d(input: torch.Tensor, padding):
    """Reflection-pad ``input`` along its last dimension into a freshly allocated tensor.

    ``padding`` is a ``(pad_left, pad_right)`` pair. Validation and the kernel
    launch are delegated to ``_launch_reflection_pad1d`` with no ``out`` tensor.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD1D")
    return _launch_reflection_pad1d(input=input, padding=padding, out=None)

123 

124 

def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Out-variant of ``reflection_pad1d`` that uses ``out`` as the destination.

    ``padding`` is a ``(pad_left, pad_right)`` pair. ``out`` is validated and
    filled by ``_launch_reflection_pad1d``, whose result is returned unchanged.
    """
    logger.debug("GEMS_KUNLUNXIN REFLECTION_PAD1D_OUT")
    padded = _launch_reflection_pad1d(input=input, padding=padding, out=out)
    return padded