Coverage for src/flag_gems/ops/upsample_bicubic2d_aa_backward.py: 36%

190 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

9 

10 

@triton.jit
def _cubic_aa_filter(x):
    """Evaluate the Keys cubic convolution kernel (a = -0.5, PIL-compatible).

    ``x`` is the distance to the sample and must be non-negative (callers pass
    ``tl.abs(...)``). The kernel is zero for ``x >= 2``.
    """
    # 0 <= x < 1:  (a + 2) x^3 - (a + 3) x^2 + 1
    near = (1.5 * x - 2.5) * x * x + 1.0
    # 1 <= x < 2:  a x^3 - 5a x^2 + 8a x - 4a
    far = ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0
    return tl.where(x < 1.0, near, tl.where(x < 2.0, far, 0.0))

23 

24 

@triton.jit
def _f2i(x):
    """Convert float to int32, saturating instead of overflowing.

    The value is clamped to [-2**31, 2147483520.0] before the cast.
    2147483520.0 is the largest float32 strictly below 2**31, so the cast
    never sees an out-of-range value.
    """
    lo: tl.constexpr = -2147483648.0
    hi: tl.constexpr = 2147483520.0
    clamped = tl.maximum(x, lo)
    clamped = tl.minimum(clamped, hi)
    return clamped.to(tl.int32)

31 

32 

@triton.jit
def _fused_backward_kernel(
    grad_out_ptr,  # [NC, H_out, W_out], flat
    grad_in_ptr,  # [NC, H_in, W_in], flat (written)
    # height-axis parameters
    H_in,
    H_out,
    h_scale,
    support_h,
    invscale_h,
    inv_h_scale,
    # width-axis parameters
    W_in,
    W_out,
    w_scale,
    support_w,
    invscale_w,
    inv_w_scale,
    # elements per (n, c) plane of grad_out, i.e. H_out * W_out
    stride_go_nc,
    # compile-time constants
    BLOCK_IW: tl.constexpr,
    MAX_OH: tl.constexpr,
    MAX_OW: tl.constexpr,
    MAX_KSIZE_H: tl.constexpr,
    MAX_KSIZE_W: tl.constexpr,
):
    """Single-kernel backward that gathers grad_input for one row tile.

    Program axis 0 is the flattened (nc, ih) input row. Axis 1 selects a tile
    of BLOCK_IW input columns. For every candidate output pixel (oh, ow), the
    normalised weights wy(oh, ih) and wx(ow, iw) are recomputed on the fly.
    ``wy * wx * grad_out[nc, oh, ow]`` is accumulated in float32, so no
    intermediate buffer is needed.

    NOTE(review): both candidate loops use tl.static_range and are unrolled at
    compile time (MAX_OW * MAX_OH bodies). Large upsampling factors may make
    this kernel slow to compile.
    """
    row = tl.program_id(0)  # nc * H_in + ih
    tile = tl.program_id(1)  # input-column tile

    plane = row // H_in
    ih = row % H_in
    ih_f = ih.to(tl.float32)

    cols = tile * BLOCK_IW + tl.arange(0, BLOCK_IW)
    col_ok = cols < W_in
    cols_f = cols.to(tl.float32)

    # First output row (scalar) / output columns (vector) whose filter window
    # may cover this input row / these input columns. Truncation plus the
    # clamp at 0 make this a conservative starting point.
    first_oh = tl.maximum(_f2i((ih_f + 0.5 - support_h) * inv_h_scale - 0.5), 0)
    first_ow = tl.maximum(_f2i((cols_f + 0.5 - support_w) * inv_w_scale - 0.5), 0)

    plane_off = plane.to(tl.int64) * stride_go_nc

    acc = tl.zeros([BLOCK_IW], dtype=tl.float32)

    # Columns outside: wx is a vector, computed once per candidate ow.
    for k_w in tl.static_range(MAX_OW):
        ow = first_ow + k_w
        ow_ok = col_ok & (ow >= 0) & (ow < W_out)

        # Input window [x_lo, x_lo + x_len) of output column ow.
        cw = w_scale * (ow.to(tl.float32) + 0.5)
        x_lo = tl.maximum(_f2i(cw - support_w + 0.5), 0)
        x_len = tl.maximum(tl.minimum(_f2i(cw + support_w + 0.5), W_in) - x_lo, 0)
        in_win_w = ow_ok & (cols >= x_lo) & (cols < x_lo + x_len)

        # Normaliser: the sum of every tap inside that window.
        x_lo_f = x_lo.to(tl.float32)
        sum_w = tl.zeros([BLOCK_IW], dtype=tl.float32)
        for t in tl.static_range(MAX_KSIZE_W):
            tap = _cubic_aa_filter(tl.abs((t + x_lo_f - cw + 0.5) * invscale_w))
            sum_w += tl.where(t < x_len, tap, 0.0)

        w_raw = _cubic_aa_filter(tl.abs((cols_f - cw + 0.5) * invscale_w))
        wx = tl.where(in_win_w & (sum_w != 0.0), w_raw / sum_w, 0.0)

        ow_c = tl.maximum(tl.minimum(ow, W_out - 1), 0)

        # Rows inside: wy is a scalar shared by the tile, so it is cheap to redo.
        for k_h in tl.static_range(MAX_OH):
            oh = first_oh + k_h
            oh_ok = (oh >= 0) & (oh < H_out)

            # Input window [y_lo, y_lo + y_len) of output row oh.
            ch = h_scale * (oh + 0.5)
            y_lo = tl.maximum(_f2i(ch - support_h + 0.5), 0)
            y_len = tl.maximum(tl.minimum(_f2i(ch + support_h + 0.5), H_in) - y_lo, 0)
            in_win_h = oh_ok & (ih >= y_lo) & (ih < y_lo + y_len)

            y_lo_f = y_lo.to(tl.float32)
            sum_h = 0.0
            for t in tl.static_range(MAX_KSIZE_H):
                tap = _cubic_aa_filter(tl.abs((t + y_lo_f - ch + 0.5) * invscale_h))
                sum_h += tl.where(t < y_len, tap, 0.0)

            w_raw_h = _cubic_aa_filter(tl.abs((ih_f - ch + 0.5) * invscale_h))
            wy = tl.where(in_win_h & (sum_h != 0.0), w_raw_h / sum_h, 0.0)

            # Clamped indices keep masked-off lanes' addresses in bounds.
            oh_c = tl.maximum(tl.minimum(oh, H_out - 1), 0)
            g = tl.load(
                grad_out_ptr
                + plane_off
                + oh_c.to(tl.int64) * W_out
                + ow_c.to(tl.int64),
                mask=in_win_w & in_win_h,
                other=0.0,
            )
            acc += wy * wx * g

    out_off = row.to(tl.int64) * W_in + cols.to(tl.int64)
    tl.store(
        grad_in_ptr + out_off,
        acc.to(grad_in_ptr.dtype.element_ty),
        mask=col_ok,
    )

149 

150 

@triton.jit
def _precompute_weight_sums_kernel(
    total_w_ptr,  # [output_size] float32 (written)
    output_size,
    input_size,
    scale,
    support,
    invscale,
    MAX_KSIZE: tl.constexpr,
):
    """Store, for one output index, the sum of its bicubic-AA filter taps.

    One program per output index. The window is [lo, lo + n_taps) in input
    coordinates, clipped to the input and to MAX_KSIZE taps.
    """
    out_idx = tl.program_id(0)
    if out_idx >= output_size:
        return

    c = scale * (out_idx + 0.5)
    lo = tl.maximum(_f2i(c - support + 0.5), 0)
    hi = tl.minimum(_f2i(c + support + 0.5), input_size)
    n_taps = tl.minimum(tl.maximum(hi - lo, 0), MAX_KSIZE)

    lo_f = lo.to(tl.float32)
    acc = 0.0
    for t in tl.static_range(MAX_KSIZE):
        tap = _cubic_aa_filter(tl.abs((t + lo_f - c + 0.5) * invscale))
        acc += tl.where(t < n_taps, tap, 0.0)

    tl.store(total_w_ptr + out_idx, acc)

175 

176 

@triton.jit
def _pass1_w_gather_nchw_kernel(
    grad_out_ptr,  # [NC, H_out, W_out], flat
    buf_ptr,  # [NC, H_out, W_in], flat (written)
    total_wx_ptr,  # [W_out] per-output-column weight sums
    W_in,
    W_out,
    w_scale,
    support_w,
    invscale_w,
    inv_w_scale,
    BLOCK_IW: tl.constexpr,
    MAX_OW: tl.constexpr,
):
    """Horizontal pass: compute buf[r, iw] = sum over ow of wx(ow, iw) * grad_out[r, ow].

    Program axis 0 is a flattened (nc, oh) row r of grad_out. Axis 1 selects a
    tile of BLOCK_IW input columns. The wx normalisers are read from
    ``total_wx_ptr`` instead of being recomputed.
    """
    row = tl.program_id(0)
    cols = tl.program_id(1) * BLOCK_IW + tl.arange(0, BLOCK_IW)
    col_ok = cols < W_in
    cols_f = cols.to(tl.float32)

    src_row = row.to(tl.int64) * W_out
    dst_row = row.to(tl.int64) * W_in

    # Conservative first output column whose window may contain each iw.
    first_ow = tl.maximum(_f2i((cols_f + 0.5 - support_w) * inv_w_scale - 0.5), 0)

    acc = tl.zeros([BLOCK_IW], dtype=tl.float32)

    for k in tl.static_range(MAX_OW):
        ow = first_ow + k
        ow_ok = col_ok & (ow >= 0) & (ow < W_out)

        cw = w_scale * (ow.to(tl.float32) + 0.5)
        x_lo = tl.maximum(_f2i(cw - support_w + 0.5), 0)
        x_len = tl.minimum(_f2i(cw + support_w + 0.5), W_in) - x_lo
        hit = ow_ok & (cols >= x_lo) & (cols < x_lo + tl.maximum(x_len, 0))

        w_raw = _cubic_aa_filter(tl.abs((cols_f - cw + 0.5) * invscale_w))
        ow_c = tl.maximum(tl.minimum(ow, W_out - 1), 0)
        norm = tl.load(total_wx_ptr + ow_c, mask=hit, other=1.0)
        wx = tl.where(hit & (norm != 0.0), w_raw / norm, 0.0)

        g = tl.load(grad_out_ptr + src_row + ow_c.to(tl.int64), mask=hit, other=0.0)
        acc += wx * g

    tl.store(buf_ptr + dst_row + cols.to(tl.int64), acc, mask=col_ok)

226 

227 

@triton.jit
def _pass2_h_gather_nchw_kernel(
    buf_ptr,  # [NC, H_out, W_in], flat (result of the horizontal pass)
    grad_in_ptr,  # [NC, H_in, W_in], flat (written)
    total_wy_ptr,  # [H_out] per-output-row weight sums
    H_in,
    W_in,
    H_out,
    h_scale,
    support_h,
    invscale_h,
    inv_h_scale,
    stride_buf_hw,  # elements per buf plane, i.e. H_out * W_in
    BLOCK_IW: tl.constexpr,
    MAX_OH: tl.constexpr,
):
    """Vertical pass: compute grad_in[nc, ih, iw] = sum over oh of wy(oh, ih) * buf[nc, oh, iw].

    Program axis 0 is the flattened (nc, ih) input row. Axis 1 selects a tile
    of BLOCK_IW input columns. wy depends only on (oh, ih), so it is a scalar
    per candidate oh.
    """
    row = tl.program_id(0)
    plane = row // H_in
    ih = row % H_in
    ih_f = ih.to(tl.float32)

    cols = tl.program_id(1) * BLOCK_IW + tl.arange(0, BLOCK_IW)
    col_ok = cols < W_in

    # Conservative first output row whose window may contain ih.
    first_oh = tl.maximum(_f2i((ih_f + 0.5 - support_h) * inv_h_scale - 0.5), 0)

    plane_off = plane.to(tl.int64) * stride_buf_hw

    acc = tl.zeros([BLOCK_IW], dtype=tl.float32)

    for k in tl.static_range(MAX_OH):
        oh = first_oh + k
        oh_ok = (oh >= 0) & (oh < H_out)

        ch = h_scale * (oh + 0.5)
        y_lo = tl.maximum(_f2i(ch - support_h + 0.5), 0)
        y_len = tl.minimum(_f2i(ch + support_h + 0.5), H_in) - y_lo
        hit = oh_ok & (ih >= y_lo) & (ih < y_lo + tl.maximum(y_len, 0))

        w_raw = _cubic_aa_filter(tl.abs((ih_f - ch + 0.5) * invscale_h))
        # Clamped so the unmasked normaliser load stays in bounds.
        oh_c = tl.maximum(tl.minimum(oh, H_out - 1), 0)
        norm = tl.load(total_wy_ptr + oh_c)
        wy = tl.where(hit & (norm != 0.0), w_raw / norm, 0.0)

        src = plane_off + oh_c.to(tl.int64) * W_in + cols.to(tl.int64)
        b = tl.load(buf_ptr + src, mask=col_ok & hit, other=0.0)

        acc += wy * b

    dst = row.to(tl.int64) * W_in + cols.to(tl.int64)
    tl.store(
        grad_in_ptr + dst,
        acc.to(grad_in_ptr.dtype.element_ty),
        mask=col_ok,
    )

286 

287 

def _compute_scale(input_size, output_size, align_corners, scale=None):
    """Return the input-per-output sampling ratio for one spatial axis.

    * ``align_corners`` true: ``(input_size - 1) / (output_size - 1)``, or
      ``0.0`` when ``output_size <= 1``. ``scale`` is ignored.
    * otherwise: ``1 / scale`` if a positive ``scale`` is given, else
      ``input_size / output_size``.
    """
    if not align_corners:
        if scale is not None and scale > 0:
            return 1.0 / scale
        return float(input_size) / output_size
    if output_size > 1:
        return float(input_size - 1) / (output_size - 1)
    return 0.0

297 

298 

# Path selection: when NC * max(H_in * W_in, H_out * W_out) is at most this
# many elements, the fused kernel is used (one launch, no intermediate
# buffer). Above it, the separable 2-pass path is used (two weight-sum
# launches plus two gather launches), which is intended to be more
# memory-bandwidth efficient for large tensors.
_FUSE_THRESHOLD = 2**20  # 1,048,576 elements

303 

304 

def _aa_filter_params(scale, out_size):
    """Derive the bicubic antialias filter constants for one spatial axis.

    Args:
        scale: input pixels per output pixel (as returned by ``_compute_scale``).
        out_size: output extent along this axis.

    Returns:
        A tuple ``(support, invscale, inv_scale, max_ksize, max_out)``:

        * ``support`` is the filter half-width in input pixels, widened by
          ``scale`` when ``scale >= 1``.
        * ``invscale`` is the factor applied to filter arguments.
        * ``inv_scale`` is ``1 / scale``, guarded against ``scale == 0``.
        * ``max_ksize`` bounds the input taps per output index.
        * ``max_out`` bounds the output indices that can touch one input
          index.
    """
    interp_size = 4  # the bicubic kernel is non-zero on |x| < 2
    if scale >= 1.0:
        support = (interp_size * 0.5) * scale
        invscale = 1.0 / scale
    else:
        support = interp_size * 0.5
        invscale = 1.0
    max_ksize = math.ceil(support) * 2 + 1
    # scale is 0 for align_corners with a single output pixel; avoid 1 / 0.
    inv_scale = 1.0 / max(scale, 1e-10)
    # +2 gives slack for the truncated starting index used by the kernels.
    max_out = min(math.ceil(2 * support * inv_scale) + 2, max(out_size, 1))
    return support, invscale, inv_scale, max_ksize, max_out


def _upsample_bicubic2d_aa_backward(
    grad_output: torch.Tensor,
    output_size,  # [H_out, W_out]
    input_size,  # [N, C, H_in, W_in]
    align_corners: bool,
    scales_h=None,
    scales_w=None,
) -> torch.Tensor:
    """Backward of antialiased bicubic 2-D resampling.

    Each input-gradient element gathers ``wy * wx * grad_output`` over every
    output pixel whose normalised bicubic-AA window covers it.

    * Small problems (see ``_FUSE_THRESHOLD``) use one fused kernel.
    * Larger problems use a separable path: the weight sums are precomputed,
      a horizontal pass writes a float32 buffer ``[NC, H_out, W_in]``, and a
      vertical pass reduces that buffer into the result.

    Args:
        grad_output: gradient of the forward output, shape (N, C, H_out, W_out).
        output_size: ``(H_out, W_out)``.
        input_size: ``(N, C, H_in, W_in)`` of the forward input.
        align_corners: passed to ``_compute_scale`` for both axes.
        scales_h, scales_w: optional scale factors passed to ``_compute_scale``.

    Returns:
        A tensor of shape ``input_size`` with ``grad_output``'s dtype and device.

    Raises:
        AssertionError: if ``grad_output`` does not have shape
            (N, C, H_out, W_out).
    """
    N, C, H_in, W_in = input_size
    H_out, W_out = output_size

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # Without it, a wrong-shaped tensor with equal numel would be silently
    # reinterpreted by the reshape below. The AssertionError type is kept for
    # compatibility with the previous ``assert``.
    if grad_output.shape != (N, C, H_out, W_out):
        raise AssertionError(
            f"grad_output shape {grad_output.shape} != "
            f"expected ({N}, {C}, {H_out}, {W_out})"
        )

    NC = N * C
    if NC == 0 or H_in == 0 or W_in == 0 or H_out == 0 or W_out == 0:
        return grad_output.new_zeros(input_size)
    # From here on, every extent is >= 1.

    # Contiguous NCHW viewed as [NC, H_out, W_out].
    grad_out_flat = grad_output.contiguous().reshape(NC, H_out, W_out)
    device = grad_output.device

    # ---- Scales & filter parameters ----
    h_scale = _compute_scale(H_in, H_out, align_corners, scales_h)
    w_scale = _compute_scale(W_in, W_out, align_corners, scales_w)
    support_h, invscale_h, inv_h_scale, MAX_KSIZE_H, MAX_OH = _aa_filter_params(
        h_scale, H_out
    )
    support_w, invscale_w, inv_w_scale, MAX_KSIZE_W, MAX_OW = _aa_filter_params(
        w_scale, W_out
    )

    # ---- Tile width (power of two in [32, 256]) & warps ----
    BLOCK_IW = max(min(triton.next_power_of_2(W_in), 256), 32)
    num_warps = 1 if BLOCK_IW <= 32 else (2 if BLOCK_IW <= 64 else 4)

    grad_in_flat = torch.empty(NC, H_in, W_in, dtype=grad_output.dtype, device=device)
    grid_in = (NC * H_in, triton.cdiv(W_in, BLOCK_IW))

    if NC * max(H_in * W_in, H_out * W_out) <= _FUSE_THRESHOLD:
        # Fused path: one launch, no intermediate buffer.
        _fused_backward_kernel[grid_in](
            grad_out_flat,
            grad_in_flat,
            H_in,
            H_out,
            h_scale,
            support_h,
            invscale_h,
            inv_h_scale,
            W_in,
            W_out,
            w_scale,
            support_w,
            invscale_w,
            inv_w_scale,
            H_out * W_out,  # stride_go_nc
            BLOCK_IW=BLOCK_IW,
            MAX_OH=MAX_OH,
            MAX_OW=MAX_OW,
            MAX_KSIZE_H=MAX_KSIZE_H,
            MAX_KSIZE_W=MAX_KSIZE_W,
            num_warps=num_warps,
        )
        return grad_in_flat.reshape(N, C, H_in, W_in)

    # Separable 2-pass path.
    # Phase 0: per-output-index weight sums (the normalisers).
    total_wy = torch.empty(H_out, dtype=torch.float32, device=device)
    total_wx = torch.empty(W_out, dtype=torch.float32, device=device)
    _precompute_weight_sums_kernel[(H_out,)](
        total_wy,
        H_out,
        H_in,
        h_scale,
        support_h,
        invscale_h,
        MAX_KSIZE=MAX_KSIZE_H,
    )
    _precompute_weight_sums_kernel[(W_out,)](
        total_wx,
        W_out,
        W_in,
        w_scale,
        support_w,
        invscale_w,
        MAX_KSIZE=MAX_KSIZE_W,
    )

    # Phase 1: W-gather into buf [NC, H_out, W_in].
    buf = torch.empty(NC, H_out, W_in, dtype=torch.float32, device=device)
    _pass1_w_gather_nchw_kernel[(NC * H_out, triton.cdiv(W_in, BLOCK_IW))](
        grad_out_flat,
        buf,
        total_wx,
        W_in,
        W_out,
        w_scale,
        support_w,
        invscale_w,
        inv_w_scale,
        BLOCK_IW=BLOCK_IW,
        MAX_OW=MAX_OW,
        num_warps=num_warps,
    )

    # Phase 2: H-gather into grad_in [NC, H_in, W_in].
    _pass2_h_gather_nchw_kernel[grid_in](
        buf,
        grad_in_flat,
        total_wy,
        H_in,
        W_in,
        H_out,
        h_scale,
        support_h,
        invscale_h,
        inv_h_scale,
        H_out * W_in,  # stride_buf_hw
        BLOCK_IW=BLOCK_IW,
        MAX_OH=MAX_OH,
        num_warps=num_warps,
    )
    return grad_in_flat.reshape(N, C, H_in, W_in)