Coverage for src/flag_gems/ops/upsample_bicubic2d_aa_backward.py: 36%
190 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@triton.jit
def _cubic_aa_filter(x):
    """Evaluate the Keys cubic convolution kernel with a = -0.5 (PIL-compatible).

    Piecewise in ``x``, which must be non-negative (callers pass an absolute
    distance):
      * ``x < 1``:       1.5 x^3 - 2.5 x^2 + 1
      * ``1 <= x < 2``: -0.5 x^3 + 2.5 x^2 - 4 x + 2
      * ``x >= 2``:      0
    """
    inner = (1.5 * x - 2.5) * x * x + 1.0
    outer = ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0
    tail = tl.where(x < 2.0, outer, 0.0)
    return tl.where(x < 1.0, inner, tail)
@triton.jit
def _f2i(x):
    """Convert float value(s) to int32, saturating first so the cast cannot overflow.

    Inputs are clamped to ``[-2**31, 2147483520]`` before conversion;
    2147483520 is the largest float32 strictly below ``2**31``.
    """
    INT32_MIN_F: tl.constexpr = -2147483648.0
    INT32_MAX_F: tl.constexpr = 2147483520.0
    lower_bounded = tl.maximum(x, INT32_MIN_F)
    bounded = tl.minimum(lower_bounded, INT32_MAX_F)
    return bounded.to(tl.int32)
@triton.jit
def _fused_backward_kernel(
    grad_out_ptr,  # [NC, H_out, W_out] flat
    grad_in_ptr,  # [NC, H_in, W_in] flat (output)
    # H params
    H_in,
    H_out,
    h_scale,
    support_h,
    invscale_h,
    inv_h_scale,
    # W params
    W_in,
    W_out,
    w_scale,
    support_w,
    invscale_w,
    inv_w_scale,
    # Stride
    stride_go_nc,  # = H_out * W_out
    # Compile-time constants
    BLOCK_IW: tl.constexpr,
    MAX_OH: tl.constexpr,
    MAX_OW: tl.constexpr,
    MAX_KSIZE_H: tl.constexpr,
    MAX_KSIZE_W: tl.constexpr,
):
    """Compute grad_input of antialiased bicubic resampling in a single pass.

    Each program produces BLOCK_IW consecutive input columns ``iws`` of one
    input row: ``program_id(0) = nc * H_in + ih`` and ``program_id(1)`` picks
    the column tile. For every output pixel ``(oh, ow)`` whose separable
    filter footprint covers ``(ih, iw)`` it accumulates
    ``wy(oh, ih) * wx(ow, iw) * grad_out[nc, oh, ow]`` in float32. Each weight
    is the raw cubic tap divided by the sum of that output pixel's taps; the
    sums are recomputed inline here instead of being read from a table.

    Candidate outputs start at ``oh_start`` / ``ow_starts`` and span MAX_OH /
    MAX_OW indices. Candidates that do not actually cover the pixel get zero
    weight through the ``*_in_range`` masks.

    Every ``tl.static_range`` loop is unrolled at compile time, so generated
    code grows roughly with ``MAX_OW * (MAX_KSIZE_W + MAX_OH * MAX_KSIZE_H)``.
    """
    pid_row = tl.program_id(0)  # nc * H_in + ih
    pid_col = tl.program_id(1)  # iw tile

    nc = pid_row // H_in
    ih = pid_row % H_in
    ih_f = ih.to(tl.float32)

    iw_base = pid_col * BLOCK_IW
    iws = iw_base + tl.arange(0, BLOCK_IW)
    iw_mask = iws < W_in
    iw_f = iws.to(tl.float32)

    # Scalar: which oh values contribute to this ih.
    # Start of the candidate oh window, clamped at 0. Candidates that turn out
    # not to cover ih are zeroed by `ih_in_range` below.
    oh_start = tl.maximum(_f2i((ih_f + 0.5 - support_h) * inv_h_scale - 0.5), 0)

    # Vector: which ow values contribute to each iw (same construction per lane)
    ow_starts = tl.maximum(_f2i((iw_f + 0.5 - support_w) * inv_w_scale - 0.5), 0)

    # int64 base offset of this (n, c) plane in grad_out
    go_nc_base = nc.to(tl.int64) * stride_go_nc

    accum = tl.zeros([BLOCK_IW], dtype=tl.float32)

    # --- d_ow OUTER loop: wx computed once per d_ow, reused across d_oh ---
    for d_ow in tl.static_range(MAX_OW):
        ow = ow_starts + d_ow  # vector
        ow_valid_base = iw_mask & (ow >= 0) & (ow < W_out)

        # Compute wx (vector) — only once per d_ow.
        # Footprint of output column ow: inputs [xmin_w, xmin_w + xsize_w_pos)
        center_w = w_scale * (ow.to(tl.float32) + 0.5)
        xmin_w = tl.maximum(_f2i(center_w - support_w + 0.5), 0)
        xsize_w = tl.minimum(_f2i(center_w + support_w + 0.5), W_in) - xmin_w
        xsize_w_pos = tl.maximum(xsize_w, 0)
        iw_in_range = ow_valid_base & (iws >= xmin_w) & (iws < xmin_w + xsize_w_pos)

        # Inline total_wx computation (vector): sum of this ow's raw taps,
        # counting only the first xsize_w_pos of the MAX_KSIZE_W slots
        xmin_w_f = xmin_w.to(tl.float32)
        total_wx = tl.zeros([BLOCK_IW], dtype=tl.float32)
        for j_w in tl.static_range(MAX_KSIZE_W):
            arg_w = tl.abs((j_w + xmin_w_f - center_w + 0.5) * invscale_w)
            w_w = _cubic_aa_filter(arg_w)
            total_wx += tl.where(j_w < xsize_w_pos, w_w, 0.0)

        # Normalized tap for this iw; zero when out of footprint or when the
        # tap sum is zero (avoids dividing by zero)
        raw_wx = _cubic_aa_filter(tl.abs((iw_f - center_w + 0.5) * invscale_w))
        wx = tl.where(iw_in_range & (total_wx != 0.0), raw_wx / total_wx, 0.0)

        # Clamped so masked-off lanes still form an in-bounds address
        ow_safe = tl.maximum(tl.minimum(ow, W_out - 1), 0)

        # --- d_oh INNER loop: wy is scalar, cheap to recompute ---
        for d_oh in tl.static_range(MAX_OH):
            oh = oh_start + d_oh  # scalar
            oh_valid = (oh >= 0) & (oh < H_out)

            # Compute wy (scalar).
            # Footprint of output row oh: inputs [ymin_h, ymin_h + ysize_h_pos)
            center_h = h_scale * (oh + 0.5)
            ymin_h = tl.maximum(_f2i(center_h - support_h + 0.5), 0)
            ysize_h = tl.minimum(_f2i(center_h + support_h + 0.5), H_in) - ymin_h
            ysize_h_pos = tl.maximum(ysize_h, 0)
            ih_in_range = oh_valid & (ih >= ymin_h) & (ih < ymin_h + ysize_h_pos)

            # Inline total_wy computation (scalar, very cheap)
            ymin_h_f = ymin_h.to(tl.float32)
            total_wy = 0.0
            for j_h in tl.static_range(MAX_KSIZE_H):
                arg_h = tl.abs((j_h + ymin_h_f - center_h + 0.5) * invscale_h)
                w_h = _cubic_aa_filter(arg_h)
                total_wy += tl.where(j_h < ysize_h_pos, w_h, 0.0)

            raw_wy = _cubic_aa_filter(tl.abs((ih_f - center_h + 0.5) * invscale_h))
            wy = tl.where(ih_in_range & (total_wy != 0.0), raw_wy / total_wy, 0.0)

            # Load grad_out and accumulate; lanes outside either footprint
            # load 0 and contribute nothing
            valid = iw_in_range & ih_in_range
            oh_safe = tl.maximum(tl.minimum(oh, H_out - 1), 0)
            g = tl.load(
                grad_out_ptr
                + go_nc_base
                + oh_safe.to(tl.int64) * W_out
                + ow_safe.to(tl.int64),
                mask=valid,
                other=0.0,
            )
            accum += wy * wx * g

    # Row pid_row of grad_in is (nc, ih); cast the fp32 sum to its dtype
    gi_off = pid_row.to(tl.int64) * W_in + iws.to(tl.int64)
    tl.store(
        grad_in_ptr + gi_off,
        accum.to(grad_in_ptr.dtype.element_ty),
        mask=iw_mask,
    )
@triton.jit
def _precompute_weight_sums_kernel(
    total_w_ptr,
    output_size,
    input_size,
    scale,
    support,
    invscale,
    MAX_KSIZE: tl.constexpr,
):
    """Store, for one output index per program, the sum of its raw cubic taps.

    Program ``out_idx`` covers input indices
    ``[first, min(_f2i(mid + support + 0.5), input_size))`` with
    ``mid = scale * (out_idx + 0.5)``. At most MAX_KSIZE taps are counted.
    The float32 sum is written to ``total_w_ptr[out_idx]``, where the gather
    passes read it to normalize individual taps.
    """
    out_idx = tl.program_id(0)
    if out_idx >= output_size:
        return
    mid = scale * (out_idx + 0.5)
    first = tl.maximum(_f2i(mid - support + 0.5), 0)
    last_excl = tl.minimum(_f2i(mid + support + 0.5), input_size)
    # Tap count, clamped into [0, MAX_KSIZE].
    n_taps = tl.minimum(tl.maximum(last_excl - first, 0), MAX_KSIZE)
    first_f = first.to(tl.float32)
    acc = 0.0
    for k in tl.static_range(MAX_KSIZE):
        w = _cubic_aa_filter(tl.abs((k + first_f - mid + 0.5) * invscale))
        acc += tl.where(k < n_taps, w, 0.0)
    tl.store(total_w_ptr + out_idx, acc)
@triton.jit
def _pass1_w_gather_nchw_kernel(
    grad_out_ptr,  # [NC, H_out, W_out] flat
    buf_ptr,  # [NC, H_out, W_in] flat (output)
    total_wx_ptr,  # [W_out]
    W_in,
    W_out,
    w_scale,
    support_w,
    invscale_w,
    inv_w_scale,
    BLOCK_IW: tl.constexpr,
    MAX_OW: tl.constexpr,
):
    """First separable pass: gather gradients along the W axis.

    ``program_id(0)`` selects one row ``r`` of grad_out / buf (row offsets
    ``r * W_out`` and ``r * W_in``). ``program_id(1)`` selects a tile of
    BLOCK_IW input columns. The kernel writes
    ``buf[r, iw] = sum_ow wx(ow, iw) * grad_out[r, ow]``, where ``wx`` is the
    raw cubic tap divided by the precomputed ``total_wx[ow]``. The float32
    accumulator is stored without a dtype cast.

    ``tl.static_range(MAX_OW)`` is unrolled at compile time.
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    iw_base = pid_col * BLOCK_IW
    iws = iw_base + tl.arange(0, BLOCK_IW)
    iw_mask = iws < W_in
    iw_f = iws.to(tl.float32)

    go_base = pid_row.to(tl.int64) * W_out
    buf_base = pid_row.to(tl.int64) * W_in

    # Per-lane start of the candidate ow window (clamped at 0); candidates that
    # do not cover the lane's iw are masked out by `in_range` below
    ow_starts = tl.maximum(_f2i((iw_f + 0.5 - support_w) * inv_w_scale - 0.5), 0)

    accum = tl.zeros([BLOCK_IW], dtype=tl.float32)

    for d_ow in tl.static_range(MAX_OW):
        ow = ow_starts + d_ow
        ow_valid = iw_mask & (ow >= 0) & (ow < W_out)

        # Footprint of output column ow: inputs [xmin, xmin + max(xsize, 0))
        center_w = w_scale * (ow.to(tl.float32) + 0.5)
        xmin = tl.maximum(_f2i(center_w - support_w + 0.5), 0)
        xsize = tl.minimum(_f2i(center_w + support_w + 0.5), W_in) - xmin
        in_range = ow_valid & (iws >= xmin) & (iws < xmin + tl.maximum(xsize, 0))

        raw_wx = _cubic_aa_filter(tl.abs((iw_f - center_w + 0.5) * invscale_w))
        # Clamped so masked-off lanes still form an in-bounds address
        ow_safe = tl.maximum(tl.minimum(ow, W_out - 1), 0)
        # Masked lanes read 1.0 so the division below stays harmless
        tw_x = tl.load(total_wx_ptr + ow_safe, mask=in_range, other=1.0)
        wx = tl.where(in_range & (tw_x != 0.0), raw_wx / tw_x, 0.0)

        g = tl.load(
            grad_out_ptr + go_base + ow_safe.to(tl.int64), mask=in_range, other=0.0
        )
        accum += wx * g

    tl.store(buf_ptr + buf_base + iws.to(tl.int64), accum, mask=iw_mask)
@triton.jit
def _pass2_h_gather_nchw_kernel(
    buf_ptr,  # [NC, H_out, W_in] flat (input)
    grad_in_ptr,  # [NC, H_in, W_in] flat (output)
    total_wy_ptr,  # [H_out]
    H_in,
    W_in,
    H_out,
    h_scale,
    support_h,
    invscale_h,
    inv_h_scale,
    stride_buf_hw,  # = H_out * W_in (per-(n, c) plane stride of buf)
    BLOCK_IW: tl.constexpr,
    MAX_OH: tl.constexpr,
):
    """Second separable pass: gather the W-reduced buffer along the H axis.

    ``program_id(0) = nc * H_in + ih`` selects one grad_in row, and
    ``program_id(1)`` selects a tile of BLOCK_IW columns. The kernel writes
    ``grad_in[nc, ih, iw] = sum_oh wy(oh, ih) * buf[nc, oh, iw]``, where ``wy``
    is the raw cubic tap divided by the precomputed ``total_wy[oh]``. The
    float32 sum is cast to grad_in's element type on store.

    ``tl.static_range(MAX_OH)`` is unrolled at compile time.
    """
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    nc = pid_row // H_in
    ih = pid_row % H_in
    ih_f = ih.to(tl.float32)

    iw_base = pid_col * BLOCK_IW
    iws = iw_base + tl.arange(0, BLOCK_IW)
    iw_mask = iws < W_in

    # Start of the candidate oh window (clamped at 0); candidates that do not
    # cover ih are masked out by `ih_in_range` below
    oh_start = tl.maximum(_f2i((ih_f + 0.5 - support_h) * inv_h_scale - 0.5), 0)

    buf_nc_base = nc.to(tl.int64) * stride_buf_hw

    accum = tl.zeros([BLOCK_IW], dtype=tl.float32)

    for d_oh in tl.static_range(MAX_OH):
        oh = oh_start + d_oh
        oh_valid = (oh >= 0) & (oh < H_out)

        # Footprint of output row oh: inputs [ymin, ymin + max(ysize, 0))
        center_h = h_scale * (oh + 0.5)
        ymin = tl.maximum(_f2i(center_h - support_h + 0.5), 0)
        ysize = tl.minimum(_f2i(center_h + support_h + 0.5), H_in) - ymin
        ih_in_range = oh_valid & (ih >= ymin) & (ih < ymin + tl.maximum(ysize, 0))

        raw_wy = _cubic_aa_filter(tl.abs((ih_f - center_h + 0.5) * invscale_h))
        # Clamped into [0, H_out - 1], so this unmasked load stays in bounds
        oh_safe = tl.maximum(tl.minimum(oh, H_out - 1), 0)
        tw_y = tl.load(total_wy_ptr + oh_safe)
        wy = tl.where(ih_in_range & (tw_y != 0.0), raw_wy / tw_y, 0.0)

        buf_off = buf_nc_base + oh_safe.to(tl.int64) * W_in + iws.to(tl.int64)
        b = tl.load(buf_ptr + buf_off, mask=iw_mask & ih_in_range, other=0.0)

        accum += wy * b

    gi_off = pid_row.to(tl.int64) * W_in + iws.to(tl.int64)
    tl.store(
        grad_in_ptr + gi_off,
        accum.to(grad_in_ptr.dtype.element_ty),
        mask=iw_mask,
    )
def _compute_scale(input_size, output_size, align_corners, scale=None):
    """Return the number of input pixels per output pixel along one axis.

    With ``align_corners`` the corner pixels are mapped onto each other, which
    gives ``(input_size - 1) / (output_size - 1)``, or ``0.0`` when there is at
    most one output pixel. In that mode ``scale`` is ignored. Otherwise a
    positive user ``scale`` (output/input factor) yields ``1 / scale``, and
    anything else falls back to ``input_size / output_size``.
    """
    if not align_corners:
        if scale is not None and scale > 0:
            return 1.0 / scale
        return float(input_size) / output_size
    if output_size <= 1:
        return 0.0
    return float(input_size - 1) / (output_size - 1)
# Threshold: when the total element count (N * C times the larger of the input
# and output spatial sizes) is at or below this, the fused single-kernel path
# is used (1 launch instead of 4). Above it, the 2-pass separable path is more
# memory-bandwidth efficient.
_FUSE_THRESHOLD = 1 << 20  # 1M elements

# Compile-time budget for the fused path. `_fused_backward_kernel` nests
# `tl.static_range` loops that Triton fully unrolls, so its code size scales
# with `_fused_unroll_cost(...)`. Large resampling factors blow this up:
# 8 -> 1024 upsampling gives MAX_OH = MAX_OW = 514, or about 1.6M unrolled
# loop bodies. Past this budget the 2-pass path, whose unrolled loops are not
# nested, is used instead.
_FUSE_MAX_UNROLL = 1 << 14


def _fused_unroll_cost(max_oh, max_ow, max_ksize_h, max_ksize_w):
    """Count the loop bodies `_fused_backward_kernel` unrolls at compile time.

    Mirrors the kernel's loop nest: ``max_ow`` outer iterations. Each one runs
    a ``max_ksize_w``-tap weight-sum loop and one wx evaluation, plus
    ``max_oh`` inner iterations. Each inner iteration runs a
    ``max_ksize_h``-tap weight-sum loop and one wy evaluation.
    """
    return max_ow * (max_ksize_w + 1 + max_oh * (max_ksize_h + 1))


def _axis_filter_params(scale, output_size):
    """Derive the antialiased bicubic filter parameters for one spatial axis.

    Args:
        scale: input-pixels-per-output-pixel ratio for this axis.
        output_size: number of output pixels along this axis.

    Returns:
        ``(support, invscale, max_ksize, inv_scale, max_span)``:

        * ``support``: filter half-width in input pixels.
        * ``invscale``: maps input-pixel distances to filter coordinates.
        * ``max_ksize``: bound on the taps per output pixel.
        * ``inv_scale``: ``1 / scale``, guarded against ``scale == 0``.
        * ``max_span``: bound on how many output pixels can touch a single
          input pixel, capped at ``output_size``.
    """
    interp_size = 4  # bicubic footprint in input pixels at scale 1
    if scale >= 1.0:
        # Downsampling: stretch the filter by `scale` so it also low-passes.
        support = (interp_size * 0.5) * scale
        invscale = 1.0 / scale
    else:
        support = interp_size * 0.5
        invscale = 1.0
    max_ksize = math.ceil(support) * 2 + 1
    # scale can be exactly 0 (e.g. align_corners with a single output pixel).
    inv_scale = 1.0 / max(scale, 1e-10)
    max_span = min(math.ceil(2 * support * inv_scale) + 2, max(output_size, 1))
    return support, invscale, max_ksize, inv_scale, max_span


def _upsample_bicubic2d_aa_backward(
    grad_output: torch.Tensor,
    output_size,  # [H_out, W_out]
    input_size,  # [N, C, H_in, W_in]
    align_corners: bool,
    scales_h=None,
    scales_w=None,
) -> torch.Tensor:
    """Backward of antialiased bicubic 2-D resampling.

    For every input pixel, gathers the weighted gradients of all output pixels
    whose antialiased bicubic footprint covers it.

    Args:
        grad_output: gradient w.r.t. the forward output, ``(N, C, H_out, W_out)``.
        output_size: ``[H_out, W_out]`` of the forward output.
        input_size: ``[N, C, H_in, W_in]`` of the forward input.
        align_corners: forwarded to the per-axis scale computation.
        scales_h: optional explicit H scale factor.
        scales_w: optional explicit W scale factor.

    Returns:
        Gradient w.r.t. the forward input, shape ``(N, C, H_in, W_in)``, with the
        dtype and device of ``grad_output``. It is all zeros when any dimension
        is empty.

    Raises:
        AssertionError: if ``grad_output`` does not have the expected shape.
    """
    N, C, H_in, W_in = input_size
    H_out, W_out = output_size

    assert grad_output.shape == (N, C, H_out, W_out), (
        f"grad_output shape {grad_output.shape} != "
        f"expected ({N}, {C}, {H_out}, {W_out})"
    )

    NC = N * C
    if NC == 0 or H_in == 0 or W_in == 0 or H_out == 0 or W_out == 0:
        return grad_output.new_zeros(input_size)

    # ---- Work in NCHW: once contiguous, the reshape to [NC, H, W] is a view ----
    grad_out_flat = grad_output.contiguous().reshape(NC, H_out, W_out)
    device = grad_output.device

    # ---- Scales & filter parameters ----
    h_scale = _compute_scale(H_in, H_out, align_corners, scales_h)
    w_scale = _compute_scale(W_in, W_out, align_corners, scales_w)
    support_h, invscale_h, MAX_KSIZE_H, inv_h_scale, MAX_OH = _axis_filter_params(
        h_scale, H_out
    )
    support_w, invscale_w, MAX_KSIZE_W, inv_w_scale, MAX_OW = _axis_filter_params(
        w_scale, W_out
    )

    # ---- BLOCK_IW & num_warps (W_in >= 1 here) ----
    BLOCK_IW = max(32, min(triton.next_power_of_2(W_in), 256))
    nw = 1 if BLOCK_IW <= 32 else (2 if BLOCK_IW <= 64 else 4)

    # ---- Choose fused vs 2-pass ----
    total_elems = NC * max(H_in * W_in, H_out * W_out)
    unroll_cost = _fused_unroll_cost(MAX_OH, MAX_OW, MAX_KSIZE_H, MAX_KSIZE_W)
    use_fused = total_elems <= _FUSE_THRESHOLD and unroll_cost <= _FUSE_MAX_UNROLL

    grad_in_flat = torch.empty(NC, H_in, W_in, dtype=grad_output.dtype, device=device)

    if use_fused:
        # FUSED PATH: single kernel launch, no intermediate buffer.
        grid = (NC * H_in, triton.cdiv(W_in, BLOCK_IW))
        _fused_backward_kernel[grid](
            grad_out_flat,
            grad_in_flat,
            H_in,
            H_out,
            h_scale,
            support_h,
            invscale_h,
            inv_h_scale,
            W_in,
            W_out,
            w_scale,
            support_w,
            invscale_w,
            inv_w_scale,
            H_out * W_out,  # stride_go_nc
            BLOCK_IW=BLOCK_IW,
            MAX_OH=MAX_OH,
            MAX_OW=MAX_OW,
            MAX_KSIZE_H=MAX_KSIZE_H,
            MAX_KSIZE_W=MAX_KSIZE_W,
            num_warps=nw,
        )
        return grad_in_flat.reshape(N, C, H_in, W_in)

    # 2-PASS PATH: separable, memory-bandwidth efficient for big tensors.
    # Phase 0: per-output-index tap sums used to normalize weights.
    # H_out and W_out are >= 1 here (empty sizes returned above).
    total_wy = torch.empty(H_out, dtype=torch.float32, device=device)
    total_wx = torch.empty(W_out, dtype=torch.float32, device=device)
    _precompute_weight_sums_kernel[(H_out,)](
        total_wy,
        H_out,
        H_in,
        h_scale,
        support_h,
        invscale_h,
        MAX_KSIZE=MAX_KSIZE_H,
    )
    _precompute_weight_sums_kernel[(W_out,)](
        total_wx,
        W_out,
        W_in,
        w_scale,
        support_w,
        invscale_w,
        MAX_KSIZE=MAX_KSIZE_W,
    )

    # Phase 1: W-gather -> buf [NC, H_out, W_in] (float32 intermediate)
    buf = torch.empty(NC, H_out, W_in, dtype=torch.float32, device=device)
    grid1 = (NC * H_out, triton.cdiv(W_in, BLOCK_IW))
    _pass1_w_gather_nchw_kernel[grid1](
        grad_out_flat,
        buf,
        total_wx,
        W_in,
        W_out,
        w_scale,
        support_w,
        invscale_w,
        inv_w_scale,
        BLOCK_IW=BLOCK_IW,
        MAX_OW=MAX_OW,
        num_warps=nw,
    )

    # Phase 2: H-gather -> grad_in [NC, H_in, W_in]
    grid2 = (NC * H_in, triton.cdiv(W_in, BLOCK_IW))
    _pass2_h_gather_nchw_kernel[grid2](
        buf,
        grad_in_flat,
        total_wy,
        H_in,
        W_in,
        H_out,
        h_scale,
        support_h,
        invscale_h,
        inv_h_scale,
        H_out * W_in,  # stride_buf_hw
        BLOCK_IW=BLOCK_IW,
        MAX_OH=MAX_OH,
        num_warps=nw,
    )

    return grad_in_flat.reshape(N, C, H_in, W_in)