Coverage for src/flag_gems/ops/_upsample_nearest_exact1d.py: 53%
134 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9import flag_gems
11logger = logging.getLogger(__name__)
@triton.jit
def _upsample_nearest_exact1d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    IW,
    OW,
    sN_in,
    sC_in,
    sW_in,
    sN_out,
    sC_out,
    sW_out,
    use_scales: tl.constexpr,
    scale_w,
    BLOCK_W: tl.constexpr,
):
    """Nearest-exact 1D upsampling of one (n, c) row per program along axis 1.

    Grid axis 0 tiles the output width in chunks of ``BLOCK_W``; grid axis 1
    enumerates the flattened ``n * C + c`` plane index.

    ``IW`` / ``OW`` are the input / output widths and the ``s*_in`` /
    ``s*_out`` arguments are element strides used to address ``in_ptr`` and
    ``out_ptr``. ``N`` is accepted for signature symmetry but is not read.

    When ``use_scales`` is true the source coordinate is derived from the
    user scale factor ``scale_w`` (as ``1 / scale_w``); otherwise it is
    derived from the ratio ``IW / OW``.
    """
    pid_w = tl.program_id(0)
    pid_nc = tl.program_id(1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = offs_w < OW

    # Recover (n, c) from the flattened plane index.
    n = pid_nc // C
    c = pid_nc - n * C

    base_in = n * sN_in + c * sC_in
    base_out = n * sN_out + c * sC_out

    # Nearest-exact samples at the *centre* of each output pixel:
    #     iw = min(floor((ow + 0.5) * scale), IW - 1)
    # The former code omitted the 0.5 offset, which is the plain "nearest"
    # rule and only agrees with nearest-exact for integer upscale factors.
    ow_center = offs_w.to(tl.float32) + 0.5
    iw = tl.zeros([BLOCK_W], dtype=tl.int32)
    if use_scales:
        iw = tl.floor(ow_center * (1.0 / scale_w)).to(tl.int32)
    else:
        iw = tl.floor(ow_center * (IW / OW)).to(tl.int32)
    # Clamp so rounding (and masked-out lanes past OW) never index past the row.
    iw = tl.minimum(iw, IW - 1)

    in_ptrs = in_ptr + base_in + iw * sW_in
    x = tl.load(in_ptrs, mask=mask)

    out_ptrs = out_ptr + base_out + offs_w * sW_out
    tl.store(out_ptrs, x, mask=mask)
def _parse_size_1d(val):
    """Reduce an output-size argument to a single integer width.

    Returns ``None`` for ``None`` or an empty sequence. For a ``torch.Size``,
    list or tuple the last element is converted with ``int``; any other
    value is converted with ``int`` directly.
    """
    if val is None:
        return None
    if isinstance(val, (torch.Size, list, tuple)):
        # Only the trailing (width) entry matters for the 1D operator.
        return int(val[-1]) if val else None
    return int(val)
def _parse_scale_1d(val):
    """Reduce a scale argument to a single float.

    Returns ``None`` for ``None`` or an empty list/tuple. For a non-empty
    list or tuple the last element is converted with ``float``; any other
    value is converted with ``float`` directly.
    """
    if val is None:
        return None
    if not isinstance(val, (list, tuple)):
        return float(val)
    return float(val[-1]) if val else None
def _compute_out_w(iw, output_size, scale):
    """Return the output width for an input of width ``iw``.

    An explicit ``output_size`` takes precedence; otherwise the width is
    ``floor(iw * scale)``.

    Raises:
        ValueError: if both ``output_size`` and ``scale`` are ``None``.
    """
    if output_size is None:
        if scale is None:
            raise ValueError(
                "Either output_size or scale must be provided for _upsample_nearest_exact1d."
            )
        # Common convention: OW = floor(IW * scale).
        return int(math.floor(iw * scale))
    return int(output_size)
def _launch_upsample_nearest_exact1d_kernel(input, out, output_size=None, scale=None):
    """Fill ``out`` with the nearest-exact 1D upsampling of ``input``.

    Args:
        input: 3D tensor; its shape is unpacked as ``(N, C, IW)``.
        out: destination tensor; its last dimension is the output width.
        output_size: explicit output width, or ``None``. Only used here to
            decide whether the kernel derives indices from ``scale``.
        scale: scale factor, or ``None``.

    Returns:
        ``out``.

    Raises:
        ValueError: if ``input`` is not 3D.
    """
    if input.ndim != 3:
        raise ValueError(
            f"_upsample_nearest_exact1d expects a 3D tensor (N, C, W); got shape {tuple(input.shape)}"
        )
    if input.device.type != flag_gems.device or out.device.type != flag_gems.device:
        # Fallback to the native operator for non-target devices.
        # ``scale`` is passed as a plain float: a list would not match the
        # default overload's ``float? scales`` parameter and would select
        # the ``.vec`` overload, which rejects size and scale given together.
        result = torch.ops.aten._upsample_nearest_exact1d(
            input, [out.shape[-1]], scale
        )
        # Write into the caller's buffer so the out-variant contract holds
        # on the fallback path as well.
        out.copy_(result)
        return out

    N, C, IW = input.shape
    OW = out.shape[-1]

    sN_in, sC_in, sW_in = input.stride()
    sN_out, sC_out, sW_out = out.stride()

    BLOCK_W = 256
    # Axis 0 tiles the output width; axis 1 covers every (n, c) row.
    grid = (triton.cdiv(OW, BLOCK_W), N * C)

    # The kernel only consults the scale factor when no explicit size was given.
    use_scales = scale is not None and output_size is None
    scale_w = float(scale) if use_scales else 1.0

    _upsample_nearest_exact1d_kernel[grid](
        input,
        out,
        N,
        C,
        IW,
        OW,
        sN_in,
        sC_in,
        sW_in,
        sN_out,
        sC_out,
        sW_out,
        use_scales=use_scales,
        scale_w=scale_w,
        BLOCK_W=BLOCK_W,
    )
    return out
def _extract_io_and_params(args, kwargs, expect_out=False):
    """Pull the input tensor, optional out tensor, width and scale from a call.

    The input tensor is looked up under the ``input`` then ``self`` keywords,
    then as the first positional argument. Size and scale are read from
    several keyword aliases and, failing that, from the next non-tensor
    positional arguments (size first, then scale).

    Returns:
        ``(in_t, out_t, out_w, scale_w)`` where ``out_t`` is ``None`` unless
        ``expect_out`` is true, and ``out_w`` / ``scale_w`` are the values
        normalised by ``_parse_size_1d`` / ``_parse_scale_1d``.

    Raises:
        ValueError: if no input tensor is found, or if ``expect_out`` is true
            and no out tensor is found.
    """
    # --- input tensor -----------------------------------------------------
    tensor_in = kwargs.get("input", None)
    if tensor_in is None:
        tensor_in = kwargs.get("self", None)
    if tensor_in is None and len(args) > 0 and isinstance(args[0], torch.Tensor):
        tensor_in, args = args[0], args[1:]
    if not isinstance(tensor_in, torch.Tensor):
        raise ValueError("Input tensor not found for _upsample_nearest_exact1d.")

    # --- keyword aliases: the first key that is *present* wins --------------
    output_size = None
    for size_key in ("output_size", "size", "output_size_list"):
        if size_key in kwargs:
            output_size = kwargs[size_key]
            break

    scales = None
    for scale_key in ("scale_factor", "scales", "scale_factors", "scale"):
        if scale_key in kwargs:
            scales = kwargs[scale_key]
            break

    # --- positional fallbacks: size first, then scale ----------------------
    cursor = 0
    if output_size is None and cursor < len(args):
        if not isinstance(args[cursor], torch.Tensor):
            output_size = args[cursor]
            cursor += 1
    if scales is None and cursor < len(args):
        if not isinstance(args[cursor], torch.Tensor):
            scales = args[cursor]
            cursor += 1

    # --- out tensor --------------------------------------------------------
    tensor_out = None
    if expect_out:
        tensor_out = kwargs.get("out", None)
        if tensor_out is None:
            # Take the last tensor among the remaining positional arguments.
            tensor_out = next(
                (a for a in reversed(args) if isinstance(a, torch.Tensor)), None
            )
        if tensor_out is None:
            raise ValueError(
                "Output tensor 'out' not found for _upsample_nearest_exact1d_out."
            )

    # Normalize single-dim size and scale.
    width = _parse_size_1d(output_size)
    factor = _parse_scale_1d(scales)
    return tensor_in, tensor_out, width, factor
def _prepare_out_tensor(in_t, out_w, scale_w, dtype=None, device=None):
    """Allocate an uninitialised ``(N, C, OW)`` tensor for the result.

    ``OW`` comes from ``_compute_out_w``; ``dtype`` and ``device`` default to
    those of ``in_t``.

    Raises:
        ValueError: if the computed output width is negative.
    """
    batch, channels, in_width = in_t.shape
    out_width = _compute_out_w(in_width, out_w, scale_w)
    if out_width < 0:
        raise ValueError("Output width must be non-negative.")
    return torch.empty(
        (batch, channels, out_width),
        dtype=in_t.dtype if dtype is None else dtype,
        device=in_t.device if device is None else device,
    )
def _upsample_nearest_exact1d(*args, **kwargs):
    """Functional entry point: allocate the result and run the kernel launcher.

    Arguments are interpreted by ``_extract_io_and_params``. An empty result
    is returned as-is without launching anything.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D")
    src, _, width, factor = _extract_io_and_params(args, kwargs, expect_out=False)
    dst = _prepare_out_tensor(src, width, factor)
    if dst.numel() != 0:
        return _launch_upsample_nearest_exact1d_kernel(
            src, dst, output_size=width, scale=factor
        )
    return dst
def _upsample_nearest_exact1d_out(*args, **kwargs):
    """Out-variant entry point: validate the caller's ``out`` and fill it.

    Arguments are interpreted by ``_extract_io_and_params`` with
    ``expect_out=True``.

    Returns:
        The supplied out tensor.

    Raises:
        ValueError: if ``out`` is not 3D, if its width differs from the width
            computed from the size/scale arguments, or if its leading two
            dimensions differ from those of a 3D input.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_OUT")
    in_t, out_t, out_w, scale_w = _extract_io_and_params(args, kwargs, expect_out=True)
    if out_t.ndim != 3:
        raise ValueError(
            f"Out tensor must be 3D (N, C, W); got shape {tuple(out_t.shape)}"
        )
    # Validate that out_t has the correct computed width if parameters are provided
    expected_w = _compute_out_w(in_t.shape[-1], out_w, scale_w)
    if out_t.shape[-1] != expected_w:
        raise ValueError(
            f"Provided out tensor has width {out_t.shape[-1]} but expected {expected_w}."
        )
    # The launch grid is sized from the input's N * C while the kernel
    # addresses ``out`` through its own strides, so mismatched leading dims
    # would read or write outside ``out``. Non-3D inputs are left to the
    # launcher, which reports them with its own message.
    if in_t.ndim == 3 and tuple(out_t.shape[:2]) != tuple(in_t.shape[:2]):
        raise ValueError(
            f"Provided out tensor has leading dims {tuple(out_t.shape[:2])} "
            f"but expected {tuple(in_t.shape[:2])}."
        )
    if out_t.numel() == 0:
        return out_t
    return _launch_upsample_nearest_exact1d_kernel(
        in_t, out_t, output_size=out_w, scale=scale_w
    )
def _upsample_nearest_exact1d_vec(*args, **kwargs):
    """Vec entry point; handled the same way as the base variant.

    List-like size/scale arguments are reduced to scalars by
    ``_extract_io_and_params``. An empty result is returned as-is without
    launching anything.
    """
    logger.debug("GEMS _UPSAMPLE_NEAREST_EXACT1D_VEC")
    src, _, width, factor = _extract_io_and_params(args, kwargs, expect_out=False)
    dst = _prepare_out_tensor(src, width, factor)
    if dst.numel() != 0:
        return _launch_upsample_nearest_exact1d_kernel(
            src, dst, output_size=width, scale=factor
        )
    return dst