Coverage for src/flag_gems/ops/_upsample_nearest_exact1d.py: 53%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9import flag_gems 

10 

# Module-level logger, named after this module; each public entry point
# below emits a debug message through it when called.
logger = logging.getLogger(__name__)

12 

13 

@triton.jit
def _upsample_nearest_exact1d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    IW,
    OW,
    sN_in,
    sC_in,
    sW_in,
    sN_out,
    sC_out,
    sW_out,
    use_scales: tl.constexpr,
    scale_w,
    BLOCK_W: tl.constexpr,
):
    """Gather one ``BLOCK_W``-wide tile of one (n, c) output row.

    Launch grid: axis 0 tiles the output width in steps of ``BLOCK_W``; axis 1
    enumerates the (n, c) planes, flattened as ``n * C + c``.

    Output column ``ow`` reads input column
    ``min(floor((ow + 0.5) * ratio), IW - 1)``, the "nearest-exact" rule used
    by ATen (``nearest_neighbor_exact_compute_source_index``), evaluated in
    fp32.  ``ratio`` is ``1 / scale_w`` when ``use_scales`` is set, otherwise
    ``IW / OW``.  ``N`` is accepted for signature symmetry but is not read.
    """
    pid_w = tl.program_id(0)
    # Promote the plane index to int64 so ``n * sN_*`` cannot overflow int32
    # on tensors with more than 2**31 elements.
    pid_nc = tl.program_id(1).to(tl.int64)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < OW

    # Recover (n, c) from the flattened plane index.
    n = pid_nc // C
    c = pid_nc - n * C
    base_in = n * sN_in + c * sC_in
    base_out = n * sN_out + c * sC_out

    # Source/destination width ratio (use_scales is constexpr, so only one
    # branch is compiled).
    if use_scales:
        ratio = 1.0 / scale_w
    else:
        ratio = IW / OW
    # Sampling at output pixel centres (+0.5) is what distinguishes
    # nearest-exact from plain nearest (floor(ow * ratio)).
    src = tl.floor((offs_w.to(tl.float32) + 0.5) * ratio)
    # Clamp in both modes so the last columns never read past the input row.
    iw = tl.minimum(src.to(tl.int32), IW - 1)

    x = tl.load(in_ptr + base_in + iw * sW_in, mask=mask)
    tl.store(out_ptr + base_out + offs_w * sW_out, x, mask=mask)

61 

62 

def _parse_size_1d(val):
    """Reduce an ``output_size`` argument to a single width, or ``None``.

    ``None`` and empty sequences map to ``None``; for a ``list``/``tuple``
    (``torch.Size`` is a tuple subclass) the last element is used; any other
    value is converted with ``int``.
    """
    is_sequence = isinstance(val, (list, tuple))
    if val is None or (is_sequence and len(val) == 0):
        return None
    return int(val[-1] if is_sequence else val)

73 

74 

def _parse_scale_1d(val):
    """Reduce a scale argument to a single ``float``, or ``None``.

    ``None`` and empty ``list``/``tuple`` values map to ``None``; a non-empty
    sequence contributes its last element; anything else goes through
    ``float``.
    """
    is_sequence = isinstance(val, (list, tuple))
    if val is None or (is_sequence and not val):
        return None
    last = val[-1] if is_sequence else val
    return float(last)

83 

84 

def _compute_out_w(iw, output_size, scale):
    """Resolve the output width from an explicit size or a scale factor.

    An explicit ``output_size`` always wins.  Otherwise the width is
    ``floor(iw * scale)``.

    Raises:
        ValueError: if neither ``output_size`` nor ``scale`` is given.
    """
    if output_size is None:
        if scale is None:
            raise ValueError(
                "Either output_size or scale must be provided for _upsample_nearest_exact1d."
            )
        scaled = iw * scale
        return int(math.floor(scaled))
    return int(output_size)

94 

95 

def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None):
    """Write the nearest-exact 1-D upsampling of ``input`` into ``out``.

    Args:
        input: 3-D tensor of shape ``(N, C, IW)``.
        out: 3-D tensor of shape ``(N, C, OW)``; ``OW`` is read from it.
        output_size: accepted for API compatibility; the output width is
            always taken from ``out`` and this value does not affect the result.
        scale: optional width scale factor.  When given and positive, the
            source index is derived from ``1 / scale`` (as ATen does, even when
            an explicit output size was also supplied); otherwise ``IW / OW``
            is used.

    Returns:
        ``out``, or the native operator's result copied into ``out`` on
        non-target devices.

    Raises:
        ValueError: if the tensors are not 3-D, their leading ``(N, C)`` dims
            differ, or ``input`` has zero width while ``out`` is non-empty.
    """
    if input.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}"
        )
    # The kernel writes through out's strides using input's (N, C), so a
    # mismatch here would write out of bounds.
    if out.ndim != 3 or tuple(out.shape[:2]) != tuple(input.shape[:2]):
        raise ValueError(
            f"_upsample_nearest_exact1d: out must have shape (N, C, OW) with (N, C) = "
            f"{tuple(input.shape[:2])}; got {tuple(out.shape)}"
        )

    if input.device.type != flag_gems.device or out.device.type != flag_gems.device:
        # Fallback to the native operator for non-target devices.  Its
        # ``scales`` argument is a scalar ``float?`` (not a list), and it
        # returns a new tensor, so copy the result into ``out``.
        result = torch.ops.aten._upsample_nearest_exact1d(
            input, [out.shape[-1]], scale
        )
        out.copy_(result)
        return out

    if out.numel() == 0:
        return out

    N, C, IW = input.shape
    OW = out.shape[-1]
    if IW == 0:
        # Every source index would clamp to -1 and read out of bounds.
        raise ValueError(
            "_upsample_nearest_exact1d: input width must be positive for a non-empty output."
        )

    sN_in, sC_in, sW_in = input.stride()
    sN_out, sC_out, sW_out = out.stride()

    # A non-positive scale is ignored and the size ratio is used instead,
    # matching ATen's compute_scales_value.
    use_scales = scale is not None and float(scale) > 0.0
    scale_w = float(scale) if use_scales else 1.0

    BLOCK_W = 256
    grid_w = triton.cdiv(OW, BLOCK_W)
    # CUDA limits grid axis 1 to 65535 programs, so launch over (n, c)
    # sub-blocks that each fit.  Slicing along N and C keeps the strides
    # unchanged and only shifts the base pointer; the common case is a
    # single launch over the whole tensor.
    max_planes = 65535
    c_step = min(C, max_planes)
    n_step = max(1, max_planes // c_step)
    for n0 in range(0, N, n_step):
        for c0 in range(0, C, c_step):
            in_view = input[n0 : n0 + n_step, c0 : c0 + c_step]
            out_view = out[n0 : n0 + n_step, c0 : c0 + c_step]
            n_sz, c_sz = in_view.shape[0], in_view.shape[1]
            _upsample_nearest_exact1d_kernel[(grid_w, n_sz * c_sz)](
                in_view,
                out_view,
                n_sz,
                c_sz,
                IW,
                OW,
                sN_in,
                sC_in,
                sW_in,
                sN_out,
                sC_out,
                sW_out,
                use_scales=use_scales,
                scale_w=scale_w,
                BLOCK_W=BLOCK_W,
            )
    return out

137 

138 

def _extract_io_and_params(args, kwargs, expect_out=False):
    """Pull ``(input, out, width, scale)`` out of an ATen-style call.

    The input tensor is taken from ``kwargs["input"]``, then
    ``kwargs["self"]``, then the first positional argument if it is a tensor
    (in which case it is consumed).  ``output_size`` and the scale are looked
    up under several keyword aliases; whichever is still missing is filled,
    ``output_size`` first, from the leading non-tensor positional arguments.
    With ``expect_out`` the output tensor comes from ``kwargs["out"]`` or,
    failing that, the last tensor among the remaining positional arguments.

    Returns:
        ``(in_t, out_t, out_w, scale_w)``.  ``out_t`` is ``None`` unless
        ``expect_out`` is true; ``out_w`` / ``scale_w`` are the normalised
        single width / scale values (or ``None``).

    Raises:
        ValueError: if no input tensor, or (with ``expect_out``) no output
            tensor, can be located.
    """

    def lookup(*names):
        # The first alias present as a key wins, even if its value is None.
        for name in names:
            if name in kwargs:
                return kwargs[name]
        return None

    rest = args
    in_t = kwargs.get("input")
    if in_t is None:
        in_t = kwargs.get("self")
    if in_t is None and rest and isinstance(rest[0], torch.Tensor):
        in_t, rest = rest[0], rest[1:]
    if not isinstance(in_t, torch.Tensor):
        raise ValueError("Input tensor not found for _upsample_nearest_exact1d.")

    output_size = lookup("output_size", "size", "output_size_list")
    scales = lookup("scale_factor", "scales", "scale_factors", "scale")

    def take(current, idx):
        # Consume rest[idx] only if the value is still unset and it is not a tensor.
        if current is None and idx < len(rest) and not isinstance(rest[idx], torch.Tensor):
            return rest[idx], idx + 1
        return current, idx

    output_size, idx = take(output_size, 0)
    scales, _ = take(scales, idx)

    out_t = None
    if expect_out:
        out_t = kwargs.get("out")
        if out_t is None:
            out_t = next(
                (a for a in reversed(rest) if isinstance(a, torch.Tensor)), None
            )
        if out_t is None:
            raise ValueError(
                "Output tensor 'out' not found for _upsample_nearest_exact1d_out."
            )

    return in_t, out_t, _parse_size_1d(output_size), _parse_scale_1d(scales)

192 

193 

def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None):
    """Allocate an uninitialised ``(N, C, OW)`` output for ``in_t``.

    ``OW`` is resolved by ``_compute_out_w``.  ``dtype`` and ``device``
    default to those of ``in_t``.

    Raises:
        ValueError: if the resolved width is negative.
    """
    batch, channels, in_width = in_t.shape
    out_width = _compute_out_w(in_width, out_w, scale_w)
    if out_width < 0:
        raise ValueError("Output width must be non-negative.")
    return torch.empty(
        (batch, channels, out_width),
        dtype=in_t.dtype if dtype is None else dtype,
        device=in_t.device if device is None else device,
    )

204 

205 

def _upsample_nearest_exact1d(*args, **kwargs):
    """Functional overload: return a freshly allocated upsampled tensor.

    Arguments are parsed by ``_extract_io_and_params``; the output is
    allocated by ``_prepare_out_tensor`` and filled by the launcher unless it
    is empty.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D")
    src, _unused, width, factor = _extract_io_and_params(
        args, kwargs, expect_out=False
    )
    result = _prepare_out_tensor(src, width, factor)
    if result.numel() != 0:
        result = _launch_upsample_nearest_exact1d_kernel(
            src, result, output_size=width, scale=factor
        )
    return result

215 

216 

def _upsample_nearest_exact1d_out(*args, **kwargs):
    """Out overload: write the upsampled result into a caller-provided tensor.

    Returns:
        The (filled) ``out`` tensor, or the launcher's result.

    Raises:
        ValueError: if ``out`` is missing, not 3-D, or its shape differs from
            ``(N, C, OW)``, where ``N`` and ``C`` come from the input and ``OW``
            from ``output_size`` / the scale.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_OUT")
    in_t, out_t, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=True)
    if out_t.ndim != 3:
        raise ValueError(
            f"Out tensor must be 3D (N, C, W); got shape {tuple(out_t.shape)}"
        )
    # Validate that out_t has the correct computed width if parameters are provided
    expected_w = _compute_out_w(in_t.shape[-1], out_w, scale_w)
    if out_t.shape[-1] != expected_w:
        raise ValueError(
            f"Provided out tensor has width {out_t.shape[-1]} but expected {expected_w}."
        )
    # The kernel takes N and C from the input but writes through out's
    # strides, so a batch/channel mismatch would write past out's storage.
    expected_shape = (*in_t.shape[:-1], expected_w)
    if tuple(out_t.shape) != expected_shape:
        raise ValueError(
            f"Provided out tensor has shape {tuple(out_t.shape)} but expected {expected_shape}."
        )
    if out_t.numel() == 0:
        return out_t
    return _launch_upsample_nearest_exact1d_kernel(
        in_t, out_t, output_size=out_w, scale=scale_w
    )

235 

236 

def _upsample_nearest_exact1d_vec(*args, **kwargs):
    """``.vec`` overload: ``output_size`` / ``scale_factors`` may be list-like or None.

    Argument parsing already reduces list-like values to their last element,
    so this shares the base overload's allocate-then-launch logic.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_VEC")
    parsed = _extract_io_and_params(args, kwargs, expect_out=False)
    in_t, out_w, scale_w = parsed[0], parsed[2], parsed[3]
    out_t = _prepare_out_tensor(in_t, out_w, scale_w)
    if not out_t.numel():
        return out_t
    return _launch_upsample_nearest_exact1d_kernel(
        in_t, out_t, output_size=out_w, scale=scale_w
    )