Coverage for src/flag_gems/ops/upsample_linear1d.py: 55%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-level logger; upsample_linear1d emits a debug message through it on every call.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    scale,
    bias,
    BLOCK_SIZE: tl.constexpr,
):
    """Linearly interpolate one tile of one row of a flattened ``[NC, W]`` tensor.

    Launched on a 2-D grid: axis 0 selects the row ``pid_nc`` and axis 1
    selects a ``BLOCK_SIZE``-wide tile of output columns. For every output
    column ``w`` the source coordinate is ``src = w * scale + bias``, clamped
    to ``[0, W_in - 1]``. The result is
    ``(1 - t) * x[floor(src)] + t * x[min(floor(src) + 1, W_in - 1)]`` with
    ``t = src - floor(src)``. The result is computed in float32 and cast to
    the output element type.

    NOTE(review): all interpolation math runs in float32, so float64 inputs
    are rounded through float32.

    Args:
        input_ptr: Contiguous input laid out row-major as ``[NC, W_in]``.
        output_ptr: Contiguous output laid out row-major as ``[NC, W_out]``.
        NC: Number of rows (batch * channels).
        W_in: Input row length. Must be > 0; with ``W_in == 0`` the
            ``upper`` index becomes -1 and reads fall outside the row.
        W_out: Output row length.
        scale: Source-coordinate step per output column.
        bias: Source-coordinate offset applied after scaling.
        BLOCK_SIZE: Number of output columns handled by one program.
    """
    # program_id is int32; promote the row index to int64 before multiplying
    # by the row length so tensors with more than 2**31 - 1 elements do not
    # overflow the base offsets.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_w = tl.program_id(1)

    base_in = pid_nc * W_in
    base_out = pid_nc * W_out

    offs_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (pid_nc < NC) & (offs_w < W_out)

    src = offs_w.to(tl.float32) * scale + bias
    # Clamp into the valid source range; past the right edge both taps land
    # on the last element, so the result equals x[W_in - 1].
    src = tl.maximum(0.0, tl.minimum(src, W_in - 1.0))

    lower = tl.floor(src).to(tl.int32)
    upper = tl.minimum(lower + 1, W_in - 1)
    t = src - lower.to(tl.float32)

    row_in = input_ptr + base_in
    x0 = tl.load(row_in + lower, mask=mask).to(tl.float32)
    x1 = tl.load(row_in + upper, mask=mask).to(tl.float32)

    out = (1.0 - t) * x0 + t * x1
    tl.store(
        output_ptr + base_out + offs_w,
        out.to(output_ptr.dtype.element_ty),
        mask=mask,
    )

56 

57 

def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly upsample an ``[N, C, W]`` tensor along its last dimension.

    Args:
        self: Input tensor of shape ``[N, C, W_in]`` on ``flag_gems.device``.
        output_size: Target width, either an int or a list/tuple whose first
            element is used. When ``None``, the width is
            ``floor(W_in * scales)``.
        align_corners: If True, output sample ``i`` maps to source position
            ``i * (W_in - 1) / (W_out - 1)``. Otherwise the half-pixel mapping
            ``(i + 0.5) * step - 0.5`` is used.
        scales: Optional scale factor. Only used for the step when it is
            positive and ``align_corners`` is False (step = ``1 / scales``);
            otherwise the step is ``W_in / W_out``.

    Returns:
        A new tensor of shape ``[N, C, W_out]`` with the dtype of ``self``.

    Raises:
        AssertionError: If ``self`` is not 3-D, is not on
            ``flag_gems.device``, or neither ``output_size`` nor ``scales``
            is given.
        RuntimeError: If the input or output width is not positive.
    """
    logger.debug("GEMS UPSAMPLE LINEAR1D OPTIMIZED")
    assert self.ndim == 3, "Input must be [N, C, W]"
    assert self.device.type == flag_gems.device

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert scales is not None
        W_out = int(math.floor(W_in * scales))

    # Mirror ATen's size check. With W_in == 0 the kernel would read outside
    # each row, and with W_out == 0 the W_in / W_out step below divides by zero.
    if W_in <= 0 or W_out <= 0:
        raise RuntimeError(
            "Input and output sizes should be greater than 0, but got input "
            f"(W: {W_in}) and output (W: {W_out})"
        )

    inp = self.contiguous().view(NC, W_in)
    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)
    if NC == 0:
        # No rows to interpolate; skip launching an empty grid.
        return out.view(N, C, W_out)

    if align_corners:
        # First and last samples line up exactly; a single output sample
        # reads source position 0.
        scale_val = (W_in - 1.0) / (W_out - 1.0) if W_out > 1 else 0.0
        bias_val = 0.0
    else:
        # Like ATen's compute_scales_value: honor a user scale only when it
        # is positive, otherwise fall back to the size ratio.
        if scales is not None and scales > 0:
            scale_val = 1.0 / scales
        else:
            scale_val = W_in / W_out
        # (i + 0.5) * scale - 0.5 rewritten as i * scale + bias.
        bias_val = 0.5 * scale_val - 0.5

    BLOCK_SIZE = 256
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    upsample_linear1d_kernel[grid](
        inp,
        out,
        NC,
        W_in,
        W_out,
        scale_val,
        bias_val,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(N, C, W_out)