Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_linear1d.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    scale,
    bias,
    BLOCK_SIZE: tl.constexpr,
):
    # 1-D linear interpolation kernel.
    #
    # Program axis 0 picks a row (input rows are W_in elements long, output
    # rows are W_out elements long); program axis 1 picks a BLOCK_SIZE-wide
    # tile of output columns inside that row. Output column j samples the
    # input at clamp(j * scale + bias, 0, W_in - 1) and blends the two
    # neighbouring input elements. NC is accepted but never read here.
    row = tl.program_id(0)
    tile = tl.program_id(1)

    in_row = input_ptr + row * W_in
    out_row = output_ptr + row * W_out

    # Wrap the column indices with a modulo so they always land in [0, W_out)
    # instead of masking the tail tile. On KunlunXin, a masked tl.store does
    # not suppress writes for masked-out threads unless
    # TRITONXPU_STORE_MASK_SIM=1, which would corrupt the neighbouring
    # channel's data. With the wrap, surplus threads of the last tile merely
    # rewrite already-computed values at valid positions, which is harmless.
    cols = (tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) % W_out

    # Source coordinate of every output column, clamped to [0, W_in - 1].
    pos = cols.to(tl.float32) * scale + bias
    pos = tl.minimum(pos, W_in - 1.0)
    pos = tl.maximum(0.0, pos)

    # pos is non-negative at this point, so the truncating cast acts as floor.
    left = pos.to(tl.int32)
    right = tl.minimum(left + 1, W_in - 1)

    # Fractional distance from the left neighbour; it is the right weight.
    frac = pos - left.to(tl.float32)

    # Unmasked accesses: indices are clamped to [0, W_in - 1] on the input
    # side and wrapped to [0, W_out - 1] on the output side.
    v_left = tl.load(in_row + left)
    v_right = tl.load(in_row + right)

    # Blend in float32, then cast back to the element type that was loaded.
    blended = (1.0 - frac) * v_left.to(tl.float32) + frac * v_right.to(tl.float32)

    tl.store(out_row + cols, blended.to(v_left.dtype))

63 

64 

def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly upsample a 3-D ``[N, C, W]`` tensor along its last dimension.

    Args:
        self: Input tensor; must have exactly three dimensions.
        output_size: Target width, either a scalar or a list/tuple whose
            first element is the width. When ``None``, the width is derived
            as ``floor(W_in * scales)``.
        align_corners: When true, output positions are mapped with step
            ``(W_in - 1) / (W_out - 1)`` and no offset. When false, a
            half-pixel offset (``0.5 * step - 0.5``) is applied.
        scales: Optional scale factor. Required when ``output_size`` is
            ``None``. With ``align_corners=False`` a positive value sets the
            source step to ``1 / scales``; a missing or non-positive value
            falls back to ``W_in / W_out``.

    Returns:
        A new tensor of shape ``[N, C, W_out]`` with the input's dtype and
        device. If that shape has no elements, the empty tensor is returned
        without launching the kernel.

    Raises:
        AssertionError: If the input is not 3-D, or if neither
            ``output_size`` nor ``scales`` is given.
        ValueError: If the input width is 0 while the output is non-empty.
    """
    logger.debug("GEMS_KUNLUNXIN UPSAMPL_LINEAR1D")
    assert self.ndim == 3, "Input must be [N, C, W]"

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert (
            scales is not None
        ), "scales must be specified if output_size is not provided."
        W_out = int(math.floor(W_in * scales))

    # Collapse batch and channel so the kernel sees one row per (n, c) pair.
    inp = self.contiguous().view(NC, W_in)
    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)

    # Nothing to compute for an empty output. Returning here also avoids the
    # W_in / W_out division below when W_out == 0 and an empty-grid launch.
    if out.numel() == 0:
        return out.view(N, C, W_out)

    # The kernel clamps indices to [0, W_in - 1] and loads without a mask, so
    # an empty input row would be read out of bounds.
    if W_in == 0:
        raise ValueError(
            "Input width must be greater than 0 when the output is non-empty."
        )

    if align_corners:
        # End points map onto end points; a single output sample reads index 0.
        scale_val = (W_in - 1.0) / (W_out - 1.0) if W_out > 1 else 0.0
        bias_val = 0.0
    else:
        # Only a positive user scale is honoured (this mirrors ATen's
        # compute_scales_value); otherwise use the size ratio. This also
        # avoids dividing by a zero scale.
        if scales is not None and scales > 0:
            scale_val = 1.0 / scales
        else:
            scale_val = W_in / W_out
        # Half-pixel offset: src = (dst + 0.5) * scale - 0.5.
        bias_val = 0.5 * scale_val - 0.5

    BLOCK_SIZE = 256
    # One program per (row, tile of BLOCK_SIZE output columns).
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    with torch_device_fn.device(self.device):
        upsample_linear1d_kernel[grid](
            inp,
            out,
            NC,
            W_in,
            W_out,
            scale_val,
            bias_val,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out.view(N, C, W_out)