Coverage for src/flag_gems/ops/grid_sample.py: 13%

2179 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1""" 

2Grid sample operator implementation for FlagGems. 

3 

4This module provides the grid sampling operation with various interpolation modes. 

5Grid sample computes the output using input values and pixel locations from grid. 

6""" 

7 

8import logging 

9 

10import torch 

11import triton 

12import triton.language as tl 

13 

14from flag_gems import runtime 

15from flag_gems.utils import libentry 

16 

17logger = logging.getLogger(__name__) 

18 

# ============================================================================
# Grid Sample Constants
# ============================================================================
# NOTE(review): the code that consumes these constants is not visible from
# this section of the file; the descriptions below restate the authors'
# intent and should be confirmed against the launch/heuristic code.

# Maximum tiled voxel count for tiled kernel usage
MAX_TILED_VOXELS = 128 * 128 * 128  # 2,097,152 (~2M) voxels

# Voxel thresholds for adaptive block targeting
# Reference cube sizes: 16³=4096, 20³=8000, 32³=32768, 50³=125000, 64³=262144
VOXEL_THRESHOLD_SMALL = 8192  # 2**13; small outputs (roughly 16³ - 20³)
VOXEL_THRESHOLD_MEDIUM = 32768  # 2**15 = 32³; medium outputs (roughly 20³ - 32³)
VOXEL_THRESHOLD_LARGE = 131072  # 2**17; large outputs (roughly 32³ - 50³)
VOXEL_THRESHOLD_VERY_LARGE = 262144  # 2**18 = 64³; very large outputs (roughly 50³ - 64³)

# Block target configuration for different output sizes.
# Each tier provides a target block count plus a min/max bound
# (the *_NC_* names presumably refer to blocks per (N, C) pair - TODO confirm).

# Small outputs (16³ - 20³): Higher block count for better utilization
TARGET_BLOCKS_SMALL = 512
MIN_BLOCKS_NC_SMALL = 64
MAX_BLOCKS_NC_SMALL = 1024

# Medium outputs (20³ - 32³): Even higher block count
TARGET_BLOCKS_MEDIUM = 768
MIN_BLOCKS_NC_MEDIUM = 128
MAX_BLOCKS_NC_MEDIUM = 2048

# Large outputs (32³ - 50³): Maximum block targeting
TARGET_BLOCKS_LARGE = 1024
MIN_BLOCKS_NC_LARGE = 128
MAX_BLOCKS_NC_LARGE = 2048

# Very large outputs (50³ - 64³): Reduced block count
TARGET_BLOCKS_VERY_LARGE = 512
MIN_BLOCKS_NC_VERY_LARGE = 64
MAX_BLOCKS_NC_VERY_LARGE = 1024

# Extra large outputs (>= 64³): Conservative block targeting
TARGET_BLOCKS_EXTRA_LARGE = 300
MIN_BLOCKS_NC_EXTRA_LARGE = 50
MAX_BLOCKS_NC_EXTRA_LARGE = 1000

# Channel scaling constants
CHANNEL_COUNT_THRESHOLD = 32  # Channel count above which to scale down block targets
CHANNEL_SCALING_EXPONENT = 0.7  # Exponent for channel scaling factor
MIN_TARGET_TOTAL_BLOCKS = 128  # Minimum target total blocks when scaling for channels
MIN_BLOCKS_PER_NC = 16  # Minimum blocks per (N, C) pair when scaling for channels

# Tile size constants
MIN_TILE_SIDE = 4  # Minimum tile side length for 3D outputs
MAX_TILE_SIDE = 64  # Maximum tile side length for 3D outputs
LARGE_TILE_THRESHOLD = 32  # Threshold for using 32 or 64 sized tiles
VERY_LARGE_TILE_THRESHOLD = 48  # Threshold for using 64 instead of 32
MEDIUM_TILE_THRESHOLD = 16  # Threshold for using 16 sized tiles
SMALL_TILE_THRESHOLD = 8  # Threshold for using 8 sized tiles

# Trilinear reduction constants
MIN_BLOCK_DIMENSION = 2  # Minimum block dimension after halving for trilinear

75 

76 

77def _validate_grid_sample_input(input, grid, mode, padding_mode): 

78 """ 

79 Validate input tensors and parameters for grid_sample. 

80 

81 Args: 

82 input: Input tensor 

83 grid: Grid tensor 

84 mode: Interpolation mode 

85 padding_mode: Padding mode 

86 

87 Raises: 

88 ValueError: If inputs or parameters are invalid 

89 """ 

90 if input.dim() not in [4, 5]: 

91 raise ValueError("Input must be 4D or 5D") 

92 

93 if input.dim() == 4 and grid.dim() != 4: 

94 raise ValueError( 

95 "For 4D input, grid must be 4D (N, H_out, W_out, 2), " 

96 f"but got {grid.dim()}D tensor" 

97 ) 

98 

99 if input.dim() == 5 and grid.dim() != 5: 

100 raise ValueError( 

101 f"For 5D input, grid must be 5D (N, D_out, H_out, W_out, 3), " 

102 f"but got {grid.dim()}D tensor" 

103 ) 

104 

105 if input.dim() == 4 and grid.shape[-1] != 2: 

106 raise ValueError( 

107 f"For 4D input, grid must have 2 coordinates in last dimension, " 

108 f"but got {grid.shape[-1]}" 

109 ) 

110 

111 if input.dim() == 5 and grid.shape[-1] != 3: 

112 raise ValueError( 

113 f"For 5D input, grid must have 3 coordinates in last dimension, " 

114 f"but got {grid.shape[-1]}" 

115 ) 

116 

117 if input.shape[0] != grid.shape[0]: 

118 raise ValueError( 

119 f"Input and grid must have same batch size, " 

120 f"but got {input.shape[0]} and {grid.shape[0]}" 

121 ) 

122 

123 valid_modes = ["bilinear", "nearest", "bicubic"] 

124 if mode not in valid_modes: 

125 raise ValueError( 

126 f"Invalid mode '{mode}'. Expected one of {valid_modes}, " 

127 f"but note: bicubic only supports 4D input" 

128 ) 

129 

130 if mode == "bicubic" and input.dim() == 5: 

131 raise ValueError("Bicubic interpolation only supports 4D input") 

132 

133 valid_padding_modes = ["zeros", "border", "reflection"] 

134 if padding_mode not in valid_padding_modes: 

135 raise ValueError( 

136 f"Invalid padding_mode '{padding_mode}'. Expected one of {valid_padding_modes}" 

137 ) 

138 

139 

140# ============================================================================ 

141# 2D Nearest Neighbor Kernels 

142# ============================================================================ 

143 

144 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D nearest-neighbor grid sample with zeros padding.

    Each program instance produces exactly one output element: the program
    id is decoded into an (n, c, h_out, w_out) index. For that element the
    kernel reads the (x, y) grid coordinate, maps it from normalized space
    to pixel space, rounds it half-to-even, and copies the addressed input
    value. Out-of-range or NaN coordinates yield 0.

    Offsets are computed assuming densely packed layouts
    (input: N, C, H_in, W_in; grid: N, H_out, W_out, 2;
    output: N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (not read in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Compile-time flag selecting the denormalization formula
        BLOCK_SIZE: Autotune parameter (not read in the body)
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # The x and y coordinates are stored next to each other in the grid.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # NaN is the only value that differs from itself. NaN coordinates are
    # replaced by the sentinel -2.0 and excluded through the load mask below.
    gx_is_nan = gx != gx
    gy_is_nan = gy != gy
    gx = tl.where(gx_is_nan, -2.0, gx)
    gy = tl.where(gy_is_nan, -2.0, gy)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        # -1 and 1 refer to the centers of the corner pixels
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        # -1 and 1 refer to the outer edges of the corner pixels
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Round half to even (banker's rounding); identical for both
    # align_corners settings.
    px_lo = tl.floor(px)
    py_lo = tl.floor(py)
    px_rem = px - px_lo
    py_rem = py - py_lo
    px_lo_even = tl.cast(px_lo, tl.int32) % 2 == 0
    py_lo_even = tl.cast(py_lo, tl.int32) % 2 == 0
    px_tie = tl.where(px_lo_even, px_lo, px_lo + 1)
    py_tie = tl.where(py_lo_even, py_lo, py_lo + 1)
    px_plain = tl.where(px_rem < 0.5, px_lo, px_lo + 1)
    py_plain = tl.where(py_rem < 0.5, py_lo, py_lo + 1)
    x_idx = tl.cast(tl.where(px_rem == 0.5, px_tie, px_plain), tl.int32)
    y_idx = tl.cast(tl.where(py_rem == 0.5, py_tie, py_plain), tl.int32)

    # Zeros padding: only read when the pixel is inside the input and the
    # coordinate was not NaN; otherwise the load returns 0.
    valid = (
        (x_idx >= 0)
        & (x_idx < W_in)
        & (y_idx >= 0)
        & (y_idx < H_in)
        & ~gx_is_nan
        & ~gy_is_nan
    )

    src_off = n * C * H_in * W_in + c * H_in * W_in + y_idx * W_in + x_idx
    sample = tl.load(ptr_input + src_off, mask=valid, other=0.0).to(tl.float32)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, sample)

291 

292 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D nearest-neighbor grid sample with border padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The sampled pixel index is rounded half-to-even and then clamped to
    [0, W_in - 1] x [0, H_in - 1], so every load is in bounds.
    NaN grid coordinates are treated as -1.0.

    N and BLOCK_SIZE are not read in the body; they exist for the
    autotuner key / configuration.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) grid coordinate of this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Map NaN (the only value unequal to itself) to -1.0.
    gx = tl.where(gx != gx, -1.0, gx)
    gy = tl.where(gy != gy, -1.0, gy)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Round half to even (banker's rounding); identical for both
    # align_corners settings.
    px_lo = tl.floor(px)
    py_lo = tl.floor(py)
    px_rem = px - px_lo
    py_rem = py - py_lo
    px_lo_even = tl.cast(px_lo, tl.int32) % 2 == 0
    py_lo_even = tl.cast(py_lo, tl.int32) % 2 == 0
    px_tie = tl.where(px_lo_even, px_lo, px_lo + 1)
    py_tie = tl.where(py_lo_even, py_lo, py_lo + 1)
    px_plain = tl.where(px_rem < 0.5, px_lo, px_lo + 1)
    py_plain = tl.where(py_rem < 0.5, py_lo, py_lo + 1)
    x_raw = tl.cast(tl.where(px_rem == 0.5, px_tie, px_plain), tl.int32)
    y_raw = tl.cast(tl.where(py_rem == 0.5, py_tie, py_plain), tl.int32)

    # Border padding: clamp the index to the valid pixel range.
    x_idx = tl.maximum(0, tl.minimum(x_raw, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_raw, H_in - 1))

    # No mask needed: the clamped index always addresses a real pixel.
    src_off = n * C * H_in * W_in + c * H_in * W_in + y_idx * W_in + x_idx
    sample = tl.load(ptr_input + src_off).to(tl.float32)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, sample)

398 

399 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D nearest-neighbor grid sample with reflection padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The grid coordinate is first folded back into [-1, 1] with a triangle
    wave of period 4 (reflection at -1 and 1), then denormalized, rounded
    half-to-even and finally clamped to the valid pixel range.
    NaN grid coordinates are treated as -1.0.

    N and BLOCK_SIZE are not read in the body; they exist for the
    autotuner key / configuration.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) grid coordinate of this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Map NaN (the only value unequal to itself) to -1.0.
    gx = tl.where(gx != gx, -1.0, gx)
    gy = tl.where(gy != gy, -1.0, gy)

    # Reflection is done in grid space, before denormalizing.
    # Step 1: shift by +1 and reduce modulo the period 4. The remainder may
    # come out negative (the original authors note that Triton's float %
    # keeps the sign of the dividend), so negative results are moved up by
    # one period to land in [0, 4).
    ux = (gx + 1.0) % 4.0
    uy = (gy + 1.0) % 4.0
    ux = tl.where(ux < 0, ux + 4.0, ux)
    uy = tl.where(uy < 0, uy + 4.0, uy)

    # Step 2: triangle wave - rise on [0, 2], fall on (2, 4) - then shift
    # back so the result lies in [-1, 1].
    gx_ref = tl.where(ux <= 2.0, ux, 4.0 - ux) - 1.0
    gy_ref = tl.where(uy <= 2.0, uy, 4.0 - uy) - 1.0

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        px = (gx_ref + 1.0) * (W_in - 1) / 2.0
        py = (gy_ref + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx_ref + 1.0) * W_in / 2.0 - 0.5
        py = (gy_ref + 1.0) * H_in / 2.0 - 0.5

    # Round half to even (banker's rounding).
    px_lo = tl.floor(px)
    py_lo = tl.floor(py)
    px_rem = px - px_lo
    py_rem = py - py_lo
    px_lo_even = tl.cast(px_lo, tl.int32) % 2 == 0
    py_lo_even = tl.cast(py_lo, tl.int32) % 2 == 0
    px_tie = tl.where(px_lo_even, px_lo, px_lo + 1)
    py_tie = tl.where(py_lo_even, py_lo, py_lo + 1)
    px_plain = tl.where(px_rem < 0.5, px_lo, px_lo + 1)
    py_plain = tl.where(py_rem < 0.5, py_lo, py_lo + 1)
    x_raw = tl.cast(tl.where(px_rem == 0.5, px_tie, px_plain), tl.int32)
    y_raw = tl.cast(tl.where(py_rem == 0.5, py_tie, py_plain), tl.int32)

    # Safety clamp so the unmasked load below can never leave the input.
    x_idx = tl.maximum(0, tl.minimum(x_raw, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_raw, H_in - 1))

    src_off = n * C * H_in * W_in + c * H_in * W_in + y_idx * W_in + x_idx
    sample = tl.load(ptr_input + src_off).to(tl.float32)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, sample)

505 

506 

507# ============================================================================ 

508# Bilinear Interpolation Kernels (4D) 

509# ============================================================================ 

510 

511 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with zeros padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The four pixels surrounding the sampling position are loaded; corners
    that fall outside the input (or any corner when the grid coordinate is
    NaN) contribute 0. The result is the bilinear blend of the four values.

    N and BLOCK_SIZE are not read in the body; they exist for the
    autotuner key / configuration.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) grid coordinate; grid layout is (N, H_out, W_out, 2).
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # NaN coordinates are replaced by the sentinel -2.0 and all four corner
    # loads are masked off for them, which makes the output 0.
    gx_is_nan = gx != gx
    gy_is_nan = gy != gy
    gx = tl.where(gx_is_nan, -2.0, gx)
    gy = tl.where(gy_is_nan, -2.0, gy)
    not_nan = ~gx_is_nan & ~gy_is_nan

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        # -1 and 1 refer to the centers of the corner pixels
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        # -1 and 1 refer to the outer edges of the corner pixels
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner and fractional position inside the 2x2 cell.
    fx = tl.floor(px)
    fy = tl.floor(py)
    tx = px - fx
    ty = py - fy

    ix0 = tl.cast(fx, tl.int32)
    iy0 = tl.cast(fy, tl.int32)
    ix1 = tl.cast(fx + 1, tl.int32)
    iy1 = tl.cast(fy + 1, tl.int32)

    # Per-axis validity of each corner index (zeros padding).
    x0_ok = (ix0 >= 0) & (ix0 < W_in)
    x1_ok = (ix1 >= 0) & (ix1 < W_in)
    y0_ok = (iy0 >= 0) & (iy0 < H_in)
    y1_ok = (iy1 >= 0) & (iy1 < H_in)

    # Start of the (n, c) plane, then of the two rows involved.
    plane_off = n * C * H_in * W_in + c * H_in * W_in
    row0_off = plane_off + iy0 * W_in
    row1_off = plane_off + iy1 * W_in

    p00 = tl.load(
        ptr_input + (row0_off + ix0), mask=x0_ok & y0_ok & not_nan, other=0.0
    ).to(tl.float32)
    p01 = tl.load(
        ptr_input + (row0_off + ix1), mask=x1_ok & y0_ok & not_nan, other=0.0
    ).to(tl.float32)
    p10 = tl.load(
        ptr_input + (row1_off + ix0), mask=x0_ok & y1_ok & not_nan, other=0.0
    ).to(tl.float32)
    p11 = tl.load(
        ptr_input + (row1_off + ix1), mask=x1_ok & y1_ok & not_nan, other=0.0
    ).to(tl.float32)

    # Blend along x on both rows, then along y.
    upper = p00 * (1.0 - tx) + p01 * tx
    lower = p10 * (1.0 - tx) + p11 * tx
    result = upper * (1.0 - ty) + lower * ty

    # Output layout is (N, C, H_out, W_out).
    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, result)

634 

635 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with border padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The four corner indices are clamped to [0, size - 1], so every load is
    in bounds; the interpolation weights are taken from the unclamped
    position. A NaN grid coordinate produces 0.

    N and BLOCK_SIZE are not read in the body; they exist for the
    autotuner key / configuration.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) grid coordinate of this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Remember NaN coordinates and substitute the sentinel -2.0 for them;
    # the final result is overridden with 0 for those elements.
    gx_is_nan = gx != gx
    gy_is_nan = gy != gy
    gx = tl.where(gx_is_nan, -2.0, gx)
    gy = tl.where(gy_is_nan, -2.0, gy)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 2x2 cell (unclamped, still floating point).
    fx = tl.floor(px)
    fy = tl.floor(py)

    ix0 = tl.cast(fx, tl.int32)
    iy0 = tl.cast(fy, tl.int32)
    ix1 = tl.cast(fx + 1, tl.int32)
    iy1 = tl.cast(fy + 1, tl.int32)

    # Border padding: clamp every corner index into the input.
    ix0 = tl.maximum(0, tl.minimum(ix0, W_in - 1))
    ix1 = tl.maximum(0, tl.minimum(ix1, W_in - 1))
    iy0 = tl.maximum(0, tl.minimum(iy0, H_in - 1))
    iy1 = tl.maximum(0, tl.minimum(iy1, H_in - 1))

    # Weights use the unclamped floor values.
    tx = px - fx
    ty = py - fy

    # Start of the (n, c) plane, then of the two rows involved.
    plane_off = n * C * H_in * W_in + c * H_in * W_in
    row0_off = plane_off + iy0 * W_in
    row1_off = plane_off + iy1 * W_in

    # Unmasked loads are safe because of the clamping above.
    p00 = tl.load(ptr_input + (row0_off + ix0))
    p01 = tl.load(ptr_input + (row0_off + ix1))
    p10 = tl.load(ptr_input + (row1_off + ix0))
    p11 = tl.load(ptr_input + (row1_off + ix1))

    # Blend along x on both rows, then along y; NaN coordinates give 0.
    upper = p00 * (1.0 - tx) + p01 * tx
    lower = p10 * (1.0 - tx) + p11 * tx
    result = tl.where(gx_is_nan | gy_is_nan, 0.0, upper * (1.0 - ty) + lower * ty)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, result)

733 

734 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with reflection padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The grid coordinate is folded back into [-1, 1] with a triangle wave of
    period 4 (reflection at -1 and 1) before it is denormalized. The four
    corner indices are then clamped into the input and blended bilinearly.
    A NaN grid coordinate produces 0.

    N and BLOCK_SIZE are not read in the body; they exist for the
    autotuner key / configuration.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) grid coordinate of this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Remember NaN coordinates and substitute the sentinel -2.0 for them;
    # the final result is overridden with 0 for those elements.
    gx_is_nan = gx != gx
    gy_is_nan = gy != gy
    gx = tl.where(gx_is_nan, -2.0, gx)
    gy = tl.where(gy_is_nan, -2.0, gy)

    # Reflection in grid space: shift by +1, reduce modulo the period 4
    # (lifting negative remainders by one period), ...
    ux = (gx + 1.0) % 4.0
    uy = (gy + 1.0) % 4.0
    ux = tl.where(ux < 0, ux + 4.0, ux)
    uy = tl.where(uy < 0, uy + 4.0, uy)

    # ... apply the triangle wave (rise on [0, 2], fall on (2, 4)) and
    # shift back so the result lies in [-1, 1].
    gx_ref = tl.where(ux <= 2.0, ux, 4.0 - ux) - 1.0
    gy_ref = tl.where(uy <= 2.0, uy, 4.0 - uy) - 1.0

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        px = (gx_ref + 1.0) * (W_in - 1) / 2.0
        py = (gy_ref + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx_ref + 1.0) * W_in / 2.0 - 0.5
        py = (gy_ref + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 2x2 cell (unclamped, still floating point).
    fx = tl.floor(px)
    fy = tl.floor(py)

    ix0 = tl.cast(fx, tl.int32)
    iy0 = tl.cast(fy, tl.int32)
    ix1 = tl.cast(fx + 1, tl.int32)
    iy1 = tl.cast(fy + 1, tl.int32)

    # Safety clamp so the unmasked loads below can never leave the input.
    ix0 = tl.maximum(0, tl.minimum(ix0, W_in - 1))
    ix1 = tl.maximum(0, tl.minimum(ix1, W_in - 1))
    iy0 = tl.maximum(0, tl.minimum(iy0, H_in - 1))
    iy1 = tl.maximum(0, tl.minimum(iy1, H_in - 1))

    # Weights use the unclamped floor values.
    tx = px - fx
    ty = py - fy

    # Start of the (n, c) plane, then of the two rows involved.
    plane_off = n * C * H_in * W_in + c * H_in * W_in
    row0_off = plane_off + iy0 * W_in
    row1_off = plane_off + iy1 * W_in

    p00 = tl.load(ptr_input + (row0_off + ix0))
    p01 = tl.load(ptr_input + (row0_off + ix1))
    p10 = tl.load(ptr_input + (row1_off + ix0))
    p11 = tl.load(ptr_input + (row1_off + ix1))

    # Blend along x on both rows, then along y; NaN coordinates give 0.
    upper = p00 * (1.0 - tx) + p01 * tx
    lower = p10 * (1.0 - tx) + p11 * tx
    result = tl.where(gx_is_nan | gy_is_nan, 0.0, upper * (1.0 - ty) + lower * ty)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, result)

844 

845 

846# ============================================================================ 

847# Bicubic Interpolation Kernels (4D) 

848# ============================================================================ 

849 

850 

851@libentry() 

852@triton.autotune( 

853 configs=runtime.get_tuned_config("grid_sample_2d_bicubic"), 

854 key=["N", "C", "H_out", "W_out"], 

855) 

856@triton.jit 

857def grid_sample_2d_bicubic_zeros_kernel( 

858 ptr_output, 

859 ptr_input, 

860 ptr_grid, 

861 N, 

862 C, 

863 H_in, 

864 W_in, 

865 H_out, 

866 W_out, 

867 align_corners: tl.constexpr, 

868 BLOCK_SIZE: tl.constexpr, 

869): 

870 """ 

871 Grid sample kernel for 2D bicubic interpolation with zeros padding. 

872 

873 Uses Keys' cubic kernel with a=-0.5. Loads 4x4 neighborhood (16 pixels). 

874 """ 

875 pid = tl.program_id(0) 

876 nc = pid // (H_out * W_out) 

877 hw = pid % (H_out * W_out) 

878 

879 n = nc // C 

880 c = nc % C 

881 h_out = hw // W_out 

882 w_out = hw % W_out 

883 

884 # Load grid coordinates 

885 grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2 

886 grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32) 

887 grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32) 

888 

889 # Handle NaN 

890 grid_x_nan = grid_x != grid_x 

891 grid_y_nan = grid_y != grid_y 

892 grid_x = tl.where(grid_x_nan, -2.0, grid_x) 

893 grid_y = tl.where(grid_y_nan, -2.0, grid_y) 

894 

895 # Denormalize to pixel space 

896 if align_corners: 

897 x = (grid_x + 1.0) * (W_in - 1) / 2.0 

898 y = (grid_y + 1.0) * (H_in - 1) / 2.0 

899 else: 

900 x = (grid_x + 1.0) * W_in / 2.0 - 0.5 

901 y = (grid_y + 1.0) * H_in / 2.0 - 0.5 

902 

903 # Find 4x4 neighborhood 

904 x0 = tl.floor(x) - 1 

905 y0 = tl.floor(y) - 1 

906 

907 # Convert to int 

908 x0_int = tl.cast(x0, tl.int32) 

909 y0_int = tl.cast(y0, tl.int32) 

910 

911 # Compute interpolation weights using Keys' cubic kernel (a = -0.75) 

912 # W(x) = (a+2)|x|³ - (a+3)|x|² + 1, for |x| ≤ 1 

913 # W(x) = a|x|³ - 5a|x|² + 8a|x| - 4a, for 1 < |x| < 2 

914 # W(x) = 0, otherwise 

915 a = -0.75 

916 

917 # X weights 

918 dx0 = x0 - x 

919 wx0 = tl.abs(dx0) 

920 weight_x0 = tl.where( 

921 wx0 < 1.0, 

922 ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1, 

923 tl.where(wx0 < 2.0, ((wx0 - 5) * wx0 + 8) * wx0 * a - 4 * a, 0.0), 

924 ) 

925 

926 dx1 = x0 + 1 - x 

927 wx1 = tl.abs(dx1) 

928 weight_x1 = tl.where( 

929 wx1 < 1.0, 

930 ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1, 

931 tl.where(wx1 < 2.0, ((wx1 - 5) * wx1 + 8) * wx1 * a - 4 * a, 0.0), 

932 ) 

933 

934 dx2 = x0 + 2 - x 

935 wx2 = tl.abs(dx2) 

936 weight_x2 = tl.where( 

937 wx2 < 1.0, 

938 ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1, 

939 tl.where(wx2 < 2.0, ((wx2 - 5) * wx2 + 8) * wx2 * a - 4 * a, 0.0), 

940 ) 

941 

942 dx3 = x0 + 3 - x 

943 wx3 = tl.abs(dx3) 

944 weight_x3 = tl.where( 

945 wx3 < 1.0, 

946 ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1, 

947 tl.where(wx3 < 2.0, ((wx3 - 5) * wx3 + 8) * wx3 * a - 4 * a, 0.0), 

948 ) 

949 

950 # Y weights 

951 dy0 = y0 - y 

952 wy0 = tl.abs(dy0) 

953 weight_y0 = tl.where( 

954 wy0 < 1.0, 

955 ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1, 

956 tl.where(wy0 < 2.0, ((wy0 - 5) * wy0 + 8) * wy0 * a - 4 * a, 0.0), 

957 ) 

958 

959 dy1 = y0 + 1 - y 

960 wy1 = tl.abs(dy1) 

961 weight_y1 = tl.where( 

962 wy1 < 1.0, 

963 ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1, 

964 tl.where(wy1 < 2.0, ((wy1 - 5) * wy1 + 8) * wy1 * a - 4 * a, 0.0), 

965 ) 

966 

967 dy2 = y0 + 2 - y 

968 wy2 = tl.abs(dy2) 

969 weight_y2 = tl.where( 

970 wy2 < 1.0, 

971 ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1, 

972 tl.where(wy2 < 2.0, ((wy2 - 5) * wy2 + 8) * wy2 * a - 4 * a, 0.0), 

973 ) 

974 

975 dy3 = y0 + 3 - y 

976 wy3 = tl.abs(dy3) 

977 weight_y3 = tl.where( 

978 wy3 < 1.0, 

979 ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1, 

980 tl.where(wy3 < 2.0, ((wy3 - 5) * wy3 + 8) * wy3 * a - 4 * a, 0.0), 

981 ) 

982 

983 # Load 4x4 neighborhood with zeros padding (unrolled loop) 

984 input_base = n * C * H_in * W_in + c * H_in * W_in 

985 

986 # Initialize accumulator 

987 val = 0.0 

988 

989 # Row 0 

990 y_idx0 = y0_int 

991 y_in_bounds0 = (y_idx0 >= 0) & (y_idx0 < H_in) 

992 

993 x_idx00 = x0_int 

994 x_in_bounds00 = (x_idx00 >= 0) & (x_idx00 < W_in) 

995 offset00 = input_base + y_idx0 * W_in + x_idx00 

996 val00 = tl.load( 

997 ptr_input + offset00, 

998 mask=x_in_bounds00 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan, 

999 other=0.0, 

1000 ).to(tl.float32) 

1001 val += val00 * weight_x0 * weight_y0 

1002 

1003 x_idx01 = x0_int + 1 

1004 x_in_bounds01 = (x_idx01 >= 0) & (x_idx01 < W_in) 

1005 offset01 = input_base + y_idx0 * W_in + x_idx01 

1006 val01 = tl.load( 

1007 ptr_input + offset01, 

1008 mask=x_in_bounds01 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan, 

1009 other=0.0, 

1010 ).to(tl.float32) 

1011 val += val01 * weight_x1 * weight_y0 

1012 

1013 x_idx02 = x0_int + 2 

1014 x_in_bounds02 = (x_idx02 >= 0) & (x_idx02 < W_in) 

1015 offset02 = input_base + y_idx0 * W_in + x_idx02 

1016 val02 = tl.load( 

1017 ptr_input + offset02, 

1018 mask=x_in_bounds02 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan, 

1019 other=0.0, 

1020 ).to(tl.float32) 

1021 val += val02 * weight_x2 * weight_y0 

1022 

1023 x_idx03 = x0_int + 3 

1024 x_in_bounds03 = (x_idx03 >= 0) & (x_idx03 < W_in) 

1025 offset03 = input_base + y_idx0 * W_in + x_idx03 

1026 val03 = tl.load( 

1027 ptr_input + offset03, 

1028 mask=x_in_bounds03 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan, 

1029 other=0.0, 

1030 ).to(tl.float32) 

1031 val += val03 * weight_x3 * weight_y0 

1032 

1033 # Row 1 

1034 y_idx1 = y0_int + 1 

1035 y_in_bounds1 = (y_idx1 >= 0) & (y_idx1 < H_in) 

1036 

1037 x_idx10 = x0_int 

1038 x_in_bounds10 = (x_idx10 >= 0) & (x_idx10 < W_in) 

1039 offset10 = input_base + y_idx1 * W_in + x_idx10 

1040 val10 = tl.load( 

1041 ptr_input + offset10, 

1042 mask=x_in_bounds10 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan, 

1043 other=0.0, 

1044 ).to(tl.float32) 

1045 val += val10 * weight_x0 * weight_y1 

1046 

1047 x_idx11 = x0_int + 1 

1048 x_in_bounds11 = (x_idx11 >= 0) & (x_idx11 < W_in) 

1049 offset11 = input_base + y_idx1 * W_in + x_idx11 

1050 val11 = tl.load( 

1051 ptr_input + offset11, 

1052 mask=x_in_bounds11 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan, 

1053 other=0.0, 

1054 ).to(tl.float32) 

1055 val += val11 * weight_x1 * weight_y1 

1056 

1057 x_idx12 = x0_int + 2 

1058 x_in_bounds12 = (x_idx12 >= 0) & (x_idx12 < W_in) 

1059 offset12 = input_base + y_idx1 * W_in + x_idx12 

1060 val12 = tl.load( 

1061 ptr_input + offset12, 

1062 mask=x_in_bounds12 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan, 

1063 other=0.0, 

1064 ).to(tl.float32) 

1065 val += val12 * weight_x2 * weight_y1 

1066 

1067 x_idx13 = x0_int + 3 

1068 x_in_bounds13 = (x_idx13 >= 0) & (x_idx13 < W_in) 

1069 offset13 = input_base + y_idx1 * W_in + x_idx13 

1070 val13 = tl.load( 

1071 ptr_input + offset13, 

1072 mask=x_in_bounds13 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan, 

1073 other=0.0, 

1074 ).to(tl.float32) 

1075 val += val13 * weight_x3 * weight_y1 

1076 

1077 # Row 2 

1078 y_idx2 = y0_int + 2 

1079 y_in_bounds2 = (y_idx2 >= 0) & (y_idx2 < H_in) 

1080 

1081 x_idx20 = x0_int 

1082 x_in_bounds20 = (x_idx20 >= 0) & (x_idx20 < W_in) 

1083 offset20 = input_base + y_idx2 * W_in + x_idx20 

1084 val20 = tl.load( 

1085 ptr_input + offset20, 

1086 mask=x_in_bounds20 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan, 

1087 other=0.0, 

1088 ).to(tl.float32) 

1089 val += val20 * weight_x0 * weight_y2 

1090 

1091 x_idx21 = x0_int + 1 

1092 x_in_bounds21 = (x_idx21 >= 0) & (x_idx21 < W_in) 

1093 offset21 = input_base + y_idx2 * W_in + x_idx21 

1094 val21 = tl.load( 

1095 ptr_input + offset21, 

1096 mask=x_in_bounds21 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan, 

1097 other=0.0, 

1098 ).to(tl.float32) 

1099 val += val21 * weight_x1 * weight_y2 

1100 

1101 x_idx22 = x0_int + 2 

1102 x_in_bounds22 = (x_idx22 >= 0) & (x_idx22 < W_in) 

1103 offset22 = input_base + y_idx2 * W_in + x_idx22 

1104 val22 = tl.load( 

1105 ptr_input + offset22, 

1106 mask=x_in_bounds22 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan, 

1107 other=0.0, 

1108 ).to(tl.float32) 

1109 val += val22 * weight_x2 * weight_y2 

1110 

1111 x_idx23 = x0_int + 3 

1112 x_in_bounds23 = (x_idx23 >= 0) & (x_idx23 < W_in) 

1113 offset23 = input_base + y_idx2 * W_in + x_idx23 

1114 val23 = tl.load( 

1115 ptr_input + offset23, 

1116 mask=x_in_bounds23 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan, 

1117 other=0.0, 

1118 ).to(tl.float32) 

1119 val += val23 * weight_x3 * weight_y2 

1120 

1121 # Row 3 

1122 y_idx3 = y0_int + 3 

1123 y_in_bounds3 = (y_idx3 >= 0) & (y_idx3 < H_in) 

1124 

1125 x_idx30 = x0_int 

1126 x_in_bounds30 = (x_idx30 >= 0) & (x_idx30 < W_in) 

1127 offset30 = input_base + y_idx3 * W_in + x_idx30 

1128 val30 = tl.load( 

1129 ptr_input + offset30, 

1130 mask=x_in_bounds30 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan, 

1131 other=0.0, 

1132 ).to(tl.float32) 

1133 val += val30 * weight_x0 * weight_y3 

1134 

1135 x_idx31 = x0_int + 1 

1136 x_in_bounds31 = (x_idx31 >= 0) & (x_idx31 < W_in) 

1137 offset31 = input_base + y_idx3 * W_in + x_idx31 

1138 val31 = tl.load( 

1139 ptr_input + offset31, 

1140 mask=x_in_bounds31 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan, 

1141 other=0.0, 

1142 ).to(tl.float32) 

1143 val += val31 * weight_x1 * weight_y3 

1144 

1145 x_idx32 = x0_int + 2 

1146 x_in_bounds32 = (x_idx32 >= 0) & (x_idx32 < W_in) 

1147 offset32 = input_base + y_idx3 * W_in + x_idx32 

1148 val32 = tl.load( 

1149 ptr_input + offset32, 

1150 mask=x_in_bounds32 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan, 

1151 other=0.0, 

1152 ).to(tl.float32) 

1153 val += val32 * weight_x2 * weight_y3 

1154 

1155 x_idx33 = x0_int + 3 

1156 x_in_bounds33 = (x_idx33 >= 0) & (x_idx33 < W_in) 

1157 offset33 = input_base + y_idx3 * W_in + x_idx33 

1158 val33 = tl.load( 

1159 ptr_input + offset33, 

1160 mask=x_in_bounds33 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan, 

1161 other=0.0, 

1162 ).to(tl.float32) 

1163 val += val33 * weight_x3 * weight_y3 

1164 

1165 # Store output 

1166 output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out 

1167 tl.store(ptr_output + output_offset, val) 

1168 

1169 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bicubic"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bicubic_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Bicubic 2D grid sampling of a single output element with border padding.

    The program id is decoded into (n, c, h_out, w_out). The (x, y) pair for
    that location is read from the grid, mapped to input pixel space and used
    to blend a 4x4 neighbourhood of the input with Keys' cubic convolution
    weights (a = -0.75). Tap indices that fall outside the input are clamped
    to the nearest edge row/column. If either grid coordinate is NaN, 0 is
    stored.

    Offsets are computed as for a row-major (N, C, H_in, W_in) input, a
    (N, H_out, W_out, 2) grid and a (N, C, H_out, W_out) output.

    Args:
        ptr_output: Pointer to the output buffer.
        ptr_input: Pointer to the input buffer.
        ptr_grid: Pointer to the sampling grid; x is stored before y.
        N: Batch size; not referenced in the body (part of the autotune key).
        C: Number of channels.
        H_in, W_in: Input height and width.
        H_out, W_out: Output height and width.
        align_corners: Compile-time flag selecting the denormalisation formula.
        BLOCK_SIZE: Compile-time parameter; not referenced in the body.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    plane_out = H_out * W_out
    nc = pid // plane_out
    hw = pid % plane_out
    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the normalised sampling location for this output element.
    grid_off = (n * plane_out + h_out * W_out + w_out) * 2
    grid_x = tl.load(ptr_grid + grid_off).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # NaN is the only value that differs from itself. Swap it for a finite
    # placeholder so the arithmetic below stays well defined; the result is
    # zeroed at the end.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    any_nan = grid_x_nan | grid_y_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Map [-1, 1] grid coordinates to input pixel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Top-left tap of the 4x4 window sits one pixel before floor(x), floor(y).
    x0 = tl.floor(x) - 1
    y0 = tl.floor(y) - 1
    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)

    # Keys' cubic convolution kernel with a = -0.75:
    #   W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1      for |t| < 1
    #   W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a    for 1 <= |t| < 2
    #   W(t) = 0                                otherwise
    a = -0.75

    # Horizontal weights, one per column of the window.
    tx0 = tl.abs(x0 - x)
    weight_x0 = tl.where(
        tx0 < 1.0,
        ((a + 2) * tx0 - (a + 3)) * tx0 * tx0 + 1,
        tl.where(tx0 < 2.0, ((tx0 - 5) * tx0 + 8) * tx0 * a - 4 * a, 0.0),
    )
    tx1 = tl.abs(x0 + 1 - x)
    weight_x1 = tl.where(
        tx1 < 1.0,
        ((a + 2) * tx1 - (a + 3)) * tx1 * tx1 + 1,
        tl.where(tx1 < 2.0, ((tx1 - 5) * tx1 + 8) * tx1 * a - 4 * a, 0.0),
    )
    tx2 = tl.abs(x0 + 2 - x)
    weight_x2 = tl.where(
        tx2 < 1.0,
        ((a + 2) * tx2 - (a + 3)) * tx2 * tx2 + 1,
        tl.where(tx2 < 2.0, ((tx2 - 5) * tx2 + 8) * tx2 * a - 4 * a, 0.0),
    )
    tx3 = tl.abs(x0 + 3 - x)
    weight_x3 = tl.where(
        tx3 < 1.0,
        ((a + 2) * tx3 - (a + 3)) * tx3 * tx3 + 1,
        tl.where(tx3 < 2.0, ((tx3 - 5) * tx3 + 8) * tx3 * a - 4 * a, 0.0),
    )

    # Vertical weights, one per row of the window.
    ty0 = tl.abs(y0 - y)
    weight_y0 = tl.where(
        ty0 < 1.0,
        ((a + 2) * ty0 - (a + 3)) * ty0 * ty0 + 1,
        tl.where(ty0 < 2.0, ((ty0 - 5) * ty0 + 8) * ty0 * a - 4 * a, 0.0),
    )
    ty1 = tl.abs(y0 + 1 - y)
    weight_y1 = tl.where(
        ty1 < 1.0,
        ((a + 2) * ty1 - (a + 3)) * ty1 * ty1 + 1,
        tl.where(ty1 < 2.0, ((ty1 - 5) * ty1 + 8) * ty1 * a - 4 * a, 0.0),
    )
    ty2 = tl.abs(y0 + 2 - y)
    weight_y2 = tl.where(
        ty2 < 1.0,
        ((a + 2) * ty2 - (a + 3)) * ty2 * ty2 + 1,
        tl.where(ty2 < 2.0, ((ty2 - 5) * ty2 + 8) * ty2 * a - 4 * a, 0.0),
    )
    ty3 = tl.abs(y0 + 3 - y)
    weight_y3 = tl.where(
        ty3 < 1.0,
        ((a + 2) * ty3 - (a + 3)) * ty3 * ty3 + 1,
        tl.where(ty3 < 2.0, ((ty3 - 5) * ty3 + 8) * ty3 * a - 4 * a, 0.0),
    )

    # Border padding: clamp every tap index into the valid range. The column
    # indices are identical for all four rows, so they are computed once.
    col0 = tl.maximum(0, tl.minimum(x0_int, W_in - 1))
    col1 = tl.maximum(0, tl.minimum(x0_int + 1, W_in - 1))
    col2 = tl.maximum(0, tl.minimum(x0_int + 2, W_in - 1))
    col3 = tl.maximum(0, tl.minimum(x0_int + 3, W_in - 1))

    # Offset of the first element of each (clamped) row in the (n, c) plane.
    input_base = (n * C + c) * H_in * W_in
    row0 = input_base + tl.maximum(0, tl.minimum(y0_int, H_in - 1)) * W_in
    row1 = input_base + tl.maximum(0, tl.minimum(y0_int + 1, H_in - 1)) * W_in
    row2 = input_base + tl.maximum(0, tl.minimum(y0_int + 2, H_in - 1)) * W_in
    row3 = input_base + tl.maximum(0, tl.minimum(y0_int + 3, H_in - 1)) * W_in

    # Accumulate the 16 weighted taps row by row, left to right.
    val = 0.0

    val += tl.load(ptr_input + (row0 + col0)).to(tl.float32) * weight_x0 * weight_y0
    val += tl.load(ptr_input + (row0 + col1)).to(tl.float32) * weight_x1 * weight_y0
    val += tl.load(ptr_input + (row0 + col2)).to(tl.float32) * weight_x2 * weight_y0
    val += tl.load(ptr_input + (row0 + col3)).to(tl.float32) * weight_x3 * weight_y0

    val += tl.load(ptr_input + (row1 + col0)).to(tl.float32) * weight_x0 * weight_y1
    val += tl.load(ptr_input + (row1 + col1)).to(tl.float32) * weight_x1 * weight_y1
    val += tl.load(ptr_input + (row1 + col2)).to(tl.float32) * weight_x2 * weight_y1
    val += tl.load(ptr_input + (row1 + col3)).to(tl.float32) * weight_x3 * weight_y1

    val += tl.load(ptr_input + (row2 + col0)).to(tl.float32) * weight_x0 * weight_y2
    val += tl.load(ptr_input + (row2 + col1)).to(tl.float32) * weight_x1 * weight_y2
    val += tl.load(ptr_input + (row2 + col2)).to(tl.float32) * weight_x2 * weight_y2
    val += tl.load(ptr_input + (row2 + col3)).to(tl.float32) * weight_x3 * weight_y2

    val += tl.load(ptr_input + (row3 + col0)).to(tl.float32) * weight_x0 * weight_y3
    val += tl.load(ptr_input + (row3 + col1)).to(tl.float32) * weight_x1 * weight_y3
    val += tl.load(ptr_input + (row3 + col2)).to(tl.float32) * weight_x2 * weight_y3
    val += tl.load(ptr_input + (row3 + col3)).to(tl.float32) * weight_x3 * weight_y3

    # A NaN grid coordinate yields a zero output.
    val = tl.where(any_nan, 0.0, val)

    # Write the result to its (n, c, h_out, w_out) slot.
    output_offset = (n * C + c) * plane_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, val)

1400 

1401 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bicubic"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bicubic_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Grid sample kernel for 2D bicubic interpolation with reflection padding.

    The program id is decoded into (n, c, h_out, w_out). The grid coordinate
    for that location is folded back into [-1, 1] with a period-4 triangle
    wave, mapped to input pixel space, and used to blend a 4x4 neighbourhood
    of the input with Keys' cubic convolution weights (a = -0.75). Taps of
    the window that still fall outside the input are clamped to the nearest
    edge. If either grid coordinate is NaN, 0 is stored.

    NOTE(review): only the sampling centre is reflected; individual
    out-of-range taps are clamped rather than reflected. Confirm whether
    tap-wise reflection is required for parity with the reference operator.

    Offsets are computed as for a row-major (N, C, H_in, W_in) input, a
    (N, H_out, W_out, 2) grid and a (N, C, H_out, W_out) output.

    Args:
        ptr_output: Pointer to the output buffer.
        ptr_input: Pointer to the input buffer.
        ptr_grid: Pointer to the sampling grid; x is stored before y.
        N: Batch size; not referenced in the body (part of the autotune key).
        C: Number of channels.
        H_in, W_in: Input height and width.
        H_out, W_out: Output height and width.
        align_corners: Compile-time flag selecting the denormalisation formula.
        BLOCK_SIZE: Compile-time parameter; not referenced in the body.
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Load grid coordinates
    grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)

    # Handle NaN: replace with a finite placeholder so the arithmetic below
    # stays well defined; the result is zeroed before the store.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Reflection padding in GRID space: shift to [0, 2], wrap with period 4
    # (adding 4 when the remainder is negative), then mirror the upper half.
    grid_x_shifted = grid_x + 1.0
    grid_x_mod = grid_x_shifted % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x_refl_mod = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod)
    grid_x_refl = grid_x_refl_mod - 1.0

    grid_y_shifted = grid_y + 1.0
    grid_y_mod = grid_y_shifted % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y_refl_mod = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod)
    grid_y_refl = grid_y_refl_mod - 1.0

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x_refl + 1.0) * (W_in - 1) / 2.0
        y = (grid_y_refl + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x_refl + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y_refl + 1.0) * H_in / 2.0 - 0.5

    # Find 4x4 neighborhood: the first tap sits one pixel before floor(x/y).
    x0 = tl.floor(x) - 1
    y0 = tl.floor(y) - 1
    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)

    # The window origin must NOT be clamped here. The weights below are
    # derived from the unclamped x0/y0, so clamping the origin would shift
    # all four taps while keeping the old weights (e.g. x in [0, 1) would read
    # columns 0..3 weighted as columns -1..2). Each tap is clamped on its own
    # in the accumulation loop, which keeps every load in bounds.

    # Compute Keys' cubic weights (a = -0.75):
    #   W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1      for |t| < 1
    #   W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a    for 1 <= |t| < 2
    #   W(t) = 0                                otherwise
    a = -0.75

    # Pre-compute X weights
    dx0 = x0 - x
    wx0 = tl.abs(dx0)
    weight_x0 = tl.where(
        wx0 < 1.0,
        ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
        tl.where(wx0 < 2.0, ((wx0 - 5) * wx0 + 8) * wx0 * a - 4 * a, 0.0),
    )

    dx1 = x0 + 1 - x
    wx1 = tl.abs(dx1)
    weight_x1 = tl.where(
        wx1 < 1.0,
        ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
        tl.where(wx1 < 2.0, ((wx1 - 5) * wx1 + 8) * wx1 * a - 4 * a, 0.0),
    )

    dx2 = x0 + 2 - x
    wx2 = tl.abs(dx2)
    weight_x2 = tl.where(
        wx2 < 1.0,
        ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
        tl.where(wx2 < 2.0, ((wx2 - 5) * wx2 + 8) * wx2 * a - 4 * a, 0.0),
    )

    dx3 = x0 + 3 - x
    wx3 = tl.abs(dx3)
    weight_x3 = tl.where(
        wx3 < 1.0,
        ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
        tl.where(wx3 < 2.0, ((wx3 - 5) * wx3 + 8) * wx3 * a - 4 * a, 0.0),
    )

    # Pre-compute Y weights
    dy0 = y0 - y
    wy0 = tl.abs(dy0)
    weight_y0 = tl.where(
        wy0 < 1.0,
        ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
        tl.where(wy0 < 2.0, ((wy0 - 5) * wy0 + 8) * wy0 * a - 4 * a, 0.0),
    )

    dy1 = y0 + 1 - y
    wy1 = tl.abs(dy1)
    weight_y1 = tl.where(
        wy1 < 1.0,
        ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
        tl.where(wy1 < 2.0, ((wy1 - 5) * wy1 + 8) * wy1 * a - 4 * a, 0.0),
    )

    dy2 = y0 + 2 - y
    wy2 = tl.abs(dy2)
    weight_y2 = tl.where(
        wy2 < 1.0,
        ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
        tl.where(wy2 < 2.0, ((wy2 - 5) * wy2 + 8) * wy2 * a - 4 * a, 0.0),
    )

    dy3 = y0 + 3 - y
    wy3 = tl.abs(dy3)
    weight_y3 = tl.where(
        wy3 < 1.0,
        ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
        tl.where(wy3 < 2.0, ((wy3 - 5) * wy3 + 8) * wy3 * a - 4 * a, 0.0),
    )

    # Load 4x4 neighborhood; taps outside the input are clamped to the edge.
    input_base = n * C * H_in * W_in + c * H_in * W_in
    val = 0.0

    for i in range(4):
        y_idx = y0_int + i
        y_idx_clamped = tl.maximum(0, tl.minimum(y_idx, H_in - 1))
        # Select the weight that belongs to row i.
        weight_y = tl.where(
            i == 0,
            weight_y0,
            tl.where(i == 1, weight_y1, tl.where(i == 2, weight_y2, weight_y3)),
        )

        for j in range(4):
            x_idx = x0_int + j
            x_idx_clamped = tl.maximum(0, tl.minimum(x_idx, W_in - 1))
            # Select the weight that belongs to column j.
            weight_x = tl.where(
                j == 0,
                weight_x0,
                tl.where(j == 1, weight_x1, tl.where(j == 2, weight_x2, weight_x3)),
            )

            offset = input_base + y_idx_clamped * W_in + x_idx_clamped
            pixel_val = tl.load(ptr_input + offset).to(tl.float32)
            val += pixel_val * weight_x * weight_y

    # Handle NaN: a NaN grid coordinate yields a zero output.
    val = tl.where(grid_x_nan | grid_y_nan, 0.0, val)

    # Store output
    output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, val)

1577 

1578 

1579# ============================================================================ 

1580# 5D Support Kernels (Volumetric Data) 

1581# ============================================================================ 

1582 

1583 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Nearest-neighbour 3D grid sampling of a single output element, zeros padding.

    The program id is decoded into (n, c, d_out, h_out, w_out). The (x, y, z)
    triple for that location is read from the grid, mapped to input voxel
    space and rounded to the nearest integer with ties going to the even
    neighbour. The addressed input voxel is copied to the output; 0 is stored
    when the voxel lies outside the input or any grid coordinate is NaN.

    Offsets are computed as for a row-major (N, C, D_in, H_in, W_in) input, a
    (N, D_out, H_out, W_out, 3) grid and a (N, C, D_out, H_out, W_out) output.

    Args:
        ptr_output: Pointer to the output buffer.
        ptr_input: Pointer to the input buffer.
        ptr_grid: Pointer to the sampling grid; stored in x, y, z order.
        N: Batch size; not referenced in the body (part of the autotune key).
        C: Number of channels.
        D_in, H_in, W_in: Input depth, height and width.
        D_out, H_out, W_out: Output depth, height and width.
        align_corners: Compile-time flag selecting the denormalisation formula.
        BLOCK_SIZE: Compile-time parameter; not referenced in the body.
    """
    # Decode the flat program id into (n, c, d_out, h_out, w_out).
    pid = tl.program_id(0)
    plane_out = H_out * W_out
    volume_out = D_out * H_out * W_out
    ncd = pid // volume_out
    dhw = pid % volume_out
    n = ncd // C
    c = ncd % C
    d_out = dhw // plane_out
    hw = dhw % plane_out
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the normalised (x, y, z) sampling location.
    grid_off = (n * volume_out + d_out * plane_out + h_out * W_out + w_out) * 3
    grid_x = tl.load(ptr_grid + grid_off).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_off + 1).to(tl.float32)
    grid_z = tl.load(ptr_grid + grid_off + 2).to(tl.float32)

    # NaN is the only value that differs from itself. Replace it with a finite
    # placeholder and remember to mask the load.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    any_nan = grid_x_nan | grid_y_nan | grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Map [-1, 1] grid coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Round half to even, one axis at a time: an exact .5 fraction goes to
    # whichever of floor / floor + 1 is even, everything else rounds normally.
    x_lo = tl.floor(x)
    x_frac = x - x_lo
    x_lo_even = tl.cast(x_lo, tl.int32) % 2 == 0
    x_near = tl.where(x_frac < 0.5, x_lo, x_lo + 1)
    x_tie = tl.where(x_lo_even, x_lo, x_lo + 1)
    x_idx = tl.cast(tl.where(x_frac == 0.5, x_tie, x_near), tl.int32)

    y_lo = tl.floor(y)
    y_frac = y - y_lo
    y_lo_even = tl.cast(y_lo, tl.int32) % 2 == 0
    y_near = tl.where(y_frac < 0.5, y_lo, y_lo + 1)
    y_tie = tl.where(y_lo_even, y_lo, y_lo + 1)
    y_idx = tl.cast(tl.where(y_frac == 0.5, y_tie, y_near), tl.int32)

    z_lo = tl.floor(z)
    z_frac = z - z_lo
    z_lo_even = tl.cast(z_lo, tl.int32) % 2 == 0
    z_near = tl.where(z_frac < 0.5, z_lo, z_lo + 1)
    z_tie = tl.where(z_lo_even, z_lo, z_lo + 1)
    z_idx = tl.cast(tl.where(z_frac == 0.5, z_tie, z_near), tl.int32)

    # Zeros padding: only load when the voxel is inside the input volume and
    # no grid coordinate was NaN.
    inside_x = (x_idx >= 0) & (x_idx < W_in)
    inside_y = (y_idx >= 0) & (y_idx < H_in)
    inside_z = (z_idx >= 0) & (z_idx < D_in)
    mask = inside_x & inside_y & inside_z & ~any_nan

    # Offset of the selected voxel in the (N, C, D, H, W) input.
    plane_in = H_in * W_in
    input_offset = (n * C + c) * D_in * plane_in + z_idx * plane_in + y_idx * W_in + x_idx
    val = tl.load(ptr_input + input_offset, mask=mask, other=0.0).to(tl.float32)

    # Write the result to its (n, c, d_out, h_out, w_out) slot.
    output_offset = (
        (n * C + c) * volume_out + d_out * plane_out + h_out * W_out + w_out
    )
    tl.store(ptr_output + output_offset, val)

1713 

1714 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Nearest-neighbour 3D grid sampling of a single output element, border padding.

    The program id is decoded into (n, c, d_out, h_out, w_out). The (x, y, z)
    triple for that location is read from the grid, mapped to input voxel
    space, rounded to the nearest integer with ties going to the even
    neighbour, and clamped into the input volume. The addressed voxel is
    copied to the output, except that 0 is stored when any grid coordinate
    is NaN.

    Offsets are computed as for a row-major (N, C, D_in, H_in, W_in) input, a
    (N, D_out, H_out, W_out, 3) grid and a (N, C, D_out, H_out, W_out) output.

    Args:
        ptr_output: Pointer to the output buffer.
        ptr_input: Pointer to the input buffer.
        ptr_grid: Pointer to the sampling grid; stored in x, y, z order.
        N: Batch size; not referenced in the body (part of the autotune key).
        C: Number of channels.
        D_in, H_in, W_in: Input depth, height and width.
        D_out, H_out, W_out: Output depth, height and width.
        align_corners: Compile-time flag selecting the denormalisation formula.
        BLOCK_SIZE: Compile-time parameter; not referenced in the body.
    """
    # Decode the flat program id into (n, c, d_out, h_out, w_out).
    pid = tl.program_id(0)
    plane_out = H_out * W_out
    volume_out = D_out * H_out * W_out
    ncd = pid // volume_out
    dhw = pid % volume_out
    n = ncd // C
    c = ncd % C
    d_out = dhw // plane_out
    hw = dhw % plane_out
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the normalised (x, y, z) sampling location.
    grid_off = (n * volume_out + d_out * plane_out + h_out * W_out + w_out) * 3
    grid_x = tl.load(ptr_grid + grid_off).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_off + 1).to(tl.float32)
    grid_z = tl.load(ptr_grid + grid_off + 2).to(tl.float32)

    # NaN is the only value that differs from itself. Replace it with a finite
    # placeholder; the loaded value is discarded for such elements below.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    any_nan = grid_x_nan | grid_y_nan | grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Map [-1, 1] grid coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Round half to even per axis, then clamp into the valid index range
    # (border padding).
    x_lo = tl.floor(x)
    x_frac = x - x_lo
    x_lo_even = tl.cast(x_lo, tl.int32) % 2 == 0
    x_near = tl.where(x_frac < 0.5, x_lo, x_lo + 1)
    x_tie = tl.where(x_lo_even, x_lo, x_lo + 1)
    x_idx = tl.cast(tl.where(x_frac == 0.5, x_tie, x_near), tl.int32)
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))

    y_lo = tl.floor(y)
    y_frac = y - y_lo
    y_lo_even = tl.cast(y_lo, tl.int32) % 2 == 0
    y_near = tl.where(y_frac < 0.5, y_lo, y_lo + 1)
    y_tie = tl.where(y_lo_even, y_lo, y_lo + 1)
    y_idx = tl.cast(tl.where(y_frac == 0.5, y_tie, y_near), tl.int32)
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))

    z_lo = tl.floor(z)
    z_frac = z - z_lo
    z_lo_even = tl.cast(z_lo, tl.int32) % 2 == 0
    z_near = tl.where(z_frac < 0.5, z_lo, z_lo + 1)
    z_tie = tl.where(z_lo_even, z_lo, z_lo + 1)
    z_idx = tl.cast(tl.where(z_frac == 0.5, z_tie, z_near), tl.int32)
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # Walk the pointer to the selected voxel one dimension at a time:
    # batch, channel, depth slice, row, column.
    src = ptr_input + n * C * D_in * H_in * W_in
    src = src + c * D_in * H_in * W_in
    src = src + z_idx * H_in * W_in
    src = src + y_idx * W_in
    src = src + x_idx
    loaded = tl.load(src).to(tl.float32)

    # A NaN grid coordinate yields a zero output.
    val = tl.where(any_nan, 0.0, loaded)

    # Write the result to its (n, c, d_out, h_out, w_out) slot.
    output_offset = (
        (n * C + c) * volume_out + d_out * plane_out + h_out * W_out + w_out
    )
    tl.store(ptr_output + output_offset, val)

1839 

1840 

1841@libentry() 

1842@triton.autotune( 

1843 configs=runtime.get_tuned_config("grid_sample_3d_nearest"), 

1844 key=["N", "C", "D_out", "H_out", "W_out"], 

1845) 

1846@triton.jit 

1847def grid_sample_3d_nearest_reflection_kernel( 

1848 ptr_output, 

1849 ptr_input, 

1850 ptr_grid, 

1851 N, 

1852 C, 

1853 D_in, 

1854 H_in, 

1855 W_in, 

1856 D_out, 

1857 H_out, 

1858 W_out, 

1859 align_corners: tl.constexpr, 

1860 BLOCK_SIZE: tl.constexpr, 

1861): 

1862 """ 

1863 Grid sample kernel for 3D nearest neighbor interpolation with reflection padding. 

1864 """ 

1865 pid = tl.program_id(0) 

1866 ncd = pid // (D_out * H_out * W_out) 

1867 dhw = pid % (D_out * H_out * W_out) 

1868 

1869 n = ncd // C 

1870 c = ncd % C 

1871 d_out = dhw // (H_out * W_out) 

1872 hw = dhw % (H_out * W_out) 

1873 h_out = hw // W_out 

1874 w_out = hw % W_out 

1875 

1876 # Load 3D grid coordinates 

1877 grid_idx = ( 

1878 n * D_out * H_out * W_out * 3 

1879 + d_out * H_out * W_out * 3 

1880 + h_out * W_out * 3 

1881 + w_out * 3 

1882 ) 

1883 grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32) 

1884 grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32) 

1885 grid_z = tl.load(ptr_grid + grid_idx + 2).to(tl.float32) 

1886 

1887 # Handle NaN 

1888 grid_x_nan = grid_x != grid_x 

1889 grid_y_nan = grid_y != grid_y 

1890 grid_z_nan = grid_z != grid_z 

1891 grid_x = tl.where(grid_x_nan, -2.0, grid_x) 

1892 grid_y = tl.where(grid_y_nan, -2.0, grid_y) 

1893 grid_z = tl.where(grid_z_nan, -2.0, grid_z) 

1894 

1895 # Reflection padding in GRID space (triangle wave with period 4) 

1896 grid_x_shifted = grid_x + 1.0 

1897 grid_x_mod = grid_x_shifted % 4.0 

1898 grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod) 

1899 grid_x_refl_mod = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) 

1900 grid_x_refl = grid_x_refl_mod - 1.0 

1901 

1902 grid_y_shifted = grid_y + 1.0 

1903 grid_y_mod = grid_y_shifted % 4.0 

1904 grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod) 

1905 grid_y_refl_mod = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) 

1906 grid_y_refl = grid_y_refl_mod - 1.0 

1907 

1908 grid_z_shifted = grid_z + 1.0 

1909 grid_z_mod = grid_z_shifted % 4.0 

1910 grid_z_mod = tl.where(grid_z_mod < 0, grid_z_mod + 4.0, grid_z_mod) 

1911 grid_z_refl_mod = tl.where(grid_z_mod <= 2.0, grid_z_mod, 4.0 - grid_z_mod) 

1912 grid_z_refl = grid_z_refl_mod - 1.0 

1913 

1914 # Denormalize to pixel space 

1915 if align_corners: 

1916 x = (grid_x_refl + 1.0) * (W_in - 1) / 2.0 

1917 y = (grid_y_refl + 1.0) * (H_in - 1) / 2.0 

1918 z = (grid_z_refl + 1.0) * (D_in - 1) / 2.0 

1919 else: 

1920 x = (grid_x_refl + 1.0) * W_in / 2.0 - 0.5 

1921 y = (grid_y_refl + 1.0) * H_in / 2.0 - 0.5 

1922 z = (grid_z_refl + 1.0) * D_in / 2.0 - 0.5 

1923 

1924 # Banker's rounding 

1925 x_floor = tl.floor(x) 

1926 y_floor = tl.floor(y) 

1927 z_floor = tl.floor(z) 

1928 x_frac = x - x_floor 

1929 y_frac = y - y_floor 

1930 z_frac = z - z_floor 

1931 x_is_half = x_frac == 0.5 

1932 y_is_half = y_frac == 0.5 

1933 z_is_half = z_frac == 0.5 

1934 x_floor_int = tl.cast(x_floor, tl.int32) 

1935 y_floor_int = tl.cast(y_floor, tl.int32) 

1936 z_floor_int = tl.cast(z_floor, tl.int32) 

1937 x_is_even = x_floor_int % 2 == 0 

1938 y_is_even = y_floor_int % 2 == 0 

1939 z_is_even = z_floor_int % 2 == 0 

1940 x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1) 

1941 y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1) 

1942 z_round = tl.where(z_frac < 0.5, z_floor, z_floor + 1) 

1943 x_idx = tl.cast( 

1944 tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round), 

1945 tl.int32, 

1946 ) 

1947 y_idx = tl.cast( 

1948 tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round), 

1949 tl.int32, 

1950 ) 

1951 z_idx = tl.cast( 

1952 tl.where(z_is_half, tl.where(z_is_even, z_floor, z_floor + 1), z_round), 

1953 tl.int32, 

1954 ) 

1955 

1956 # Clamp for safety 

1957 x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1)) 

1958 y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1)) 

1959 z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1)) 

1960 

1961 # Load input pixel 

1962 val = tl.where( 

1963 grid_x_nan | grid_y_nan | grid_z_nan, 

1964 0.0, 

1965 tl.load( 

1966 ptr_input 

1967 + n * C * D_in * H_in * W_in 

1968 + c * D_in * H_in * W_in 

1969 + z_idx * H_in * W_in 

1970 + y_idx * W_in 

1971 + x_idx 

1972 ).to(tl.float32), 

1973 ) 

1974 

1975 # Store output 

1976 output_offset = ( 

1977 n * C * D_out * H_out * W_out 

1978 + c * D_out * H_out * W_out 

1979 + d_out * H_out * W_out 

1980 + h_out * W_out 

1981 + w_out 

1982 ) 

1983 tl.store(ptr_output + output_offset, val) 

1984 

1985 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Trilinear 3D grid sampling with zeros padding; one output voxel per program.

    The flat program id is decomposed into (n, c, d_out, h_out, w_out). The
    sampling location is read from the grid, the eight surrounding input
    voxels are loaded, and they are blended along x, then y, then z.

    Args:
        ptr_output: Output pointer, addressed as a dense
            (N, C, D_out, H_out, W_out) buffer.
        ptr_input: Input pointer, addressed as a dense
            (N, C, D_in, H_in, W_in) buffer.
        ptr_grid: Grid pointer, addressed as a dense (N, D_out, H_out, W_out, 3)
            buffer whose last axis holds (x, y, z).
        N: Batch size. Not read in the body; it is part of the autotune key.
        C: Number of channels.
        D_in, H_in, W_in: Input spatial extents.
        D_out, H_out, W_out: Output spatial extents.
        align_corners: Compile-time flag selecting the denormalization formula.
        BLOCK_SIZE: Compile-time parameter that the body does not read
            (presumably supplied by the tuned configs -- confirm).

    Corners that fall outside the input contribute 0. If any of the three grid
    coordinates is NaN, every corner load is masked off and the result is 0.
    """
    out_volume = D_out * H_out * W_out
    in_plane = H_in * W_in

    pid = tl.program_id(0)
    nc_idx = pid // out_volume  # fused (n, c) index, equal to n * C + c
    spatial_idx = pid % out_volume  # flattened (d_out, h_out, w_out)
    n = nc_idx // C

    # The grid has no channel axis: three consecutive values per (n, d, h, w).
    coord_ptr = ptr_grid + (n * out_volume + spatial_idx) * 3
    grid_x = tl.load(coord_ptr).to(tl.float32)
    grid_y = tl.load(coord_ptr + 1).to(tl.float32)
    grid_z = tl.load(coord_ptr + 2).to(tl.float32)

    # x != x is true only for NaN. NaNs are swapped for the sentinel -2.0 so the
    # arithmetic below stays finite; the loads are masked with `no_nan`.
    x_is_nan = grid_x != grid_x
    y_is_nan = grid_y != grid_y
    z_is_nan = grid_z != grid_z
    no_nan = ~(x_is_nan | y_is_nan | z_is_nan)
    grid_x = tl.where(x_is_nan, -2.0, grid_x)
    grid_y = tl.where(y_is_nan, -2.0, grid_y)
    grid_z = tl.where(z_is_nan, -2.0, grid_z)

    # Map normalized coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Lower corner of the enclosing 2x2x2 cell and the fractional weights.
    x_lo = tl.floor(x)
    y_lo = tl.floor(y)
    z_lo = tl.floor(z)
    wx = x - x_lo
    wy = y - y_lo
    wz = z - z_lo

    ix0 = tl.cast(x_lo, tl.int32)
    iy0 = tl.cast(y_lo, tl.int32)
    iz0 = tl.cast(z_lo, tl.int32)
    ix1 = tl.cast(x_lo + 1, tl.int32)
    iy1 = tl.cast(y_lo + 1, tl.int32)
    iz1 = tl.cast(z_lo + 1, tl.int32)

    # Per-axis validity of each corner index (zeros padding).
    x0_ok = (ix0 >= 0) & (ix0 < W_in)
    x1_ok = (ix1 >= 0) & (ix1 < W_in)
    y0_ok = (iy0 >= 0) & (iy0 < H_in)
    y1_ok = (iy1 >= 0) & (iy1 < H_in)
    z0_ok = (iz0 >= 0) & (iz0 < D_in)
    z1_ok = (iz1 >= 0) & (iz1 < D_in)

    # The four (z, y) rows are shared by the two x corners on each row.
    mask_z0y0 = z0_ok & y0_ok & no_nan
    mask_z0y1 = z0_ok & y1_ok & no_nan
    mask_z1y0 = z1_ok & y0_ok & no_nan
    mask_z1y1 = z1_ok & y1_ok & no_nan

    channel_base = nc_idx * (D_in * in_plane)
    row_z0y0 = channel_base + iz0 * in_plane + iy0 * W_in
    row_z0y1 = channel_base + iz0 * in_plane + iy1 * W_in
    row_z1y0 = channel_base + iz1 * in_plane + iy0 * W_in
    row_z1y1 = channel_base + iz1 * in_plane + iy1 * W_in

    # Corner naming is p<z><y><x>.
    p000 = tl.load(
        ptr_input + (row_z0y0 + ix0), mask=x0_ok & mask_z0y0, other=0.0
    ).to(tl.float32)
    p001 = tl.load(
        ptr_input + (row_z0y0 + ix1), mask=x1_ok & mask_z0y0, other=0.0
    ).to(tl.float32)
    p010 = tl.load(
        ptr_input + (row_z0y1 + ix0), mask=x0_ok & mask_z0y1, other=0.0
    ).to(tl.float32)
    p011 = tl.load(
        ptr_input + (row_z0y1 + ix1), mask=x1_ok & mask_z0y1, other=0.0
    ).to(tl.float32)
    p100 = tl.load(
        ptr_input + (row_z1y0 + ix0), mask=x0_ok & mask_z1y0, other=0.0
    ).to(tl.float32)
    p101 = tl.load(
        ptr_input + (row_z1y0 + ix1), mask=x1_ok & mask_z1y0, other=0.0
    ).to(tl.float32)
    p110 = tl.load(
        ptr_input + (row_z1y1 + ix0), mask=x0_ok & mask_z1y1, other=0.0
    ).to(tl.float32)
    p111 = tl.load(
        ptr_input + (row_z1y1 + ix1), mask=x1_ok & mask_z1y1, other=0.0
    ).to(tl.float32)

    # Blend along x, then y, then z.
    wx_rem = 1.0 - wx
    wy_rem = 1.0 - wy
    z0_y0 = p000 * wx_rem + p001 * wx
    z0_y1 = p010 * wx_rem + p011 * wx
    z0_face = z0_y0 * wy_rem + z0_y1 * wy

    z1_y0 = p100 * wx_rem + p101 * wx
    z1_y1 = p110 * wx_rem + p111 * wx
    z1_face = z1_y0 * wy_rem + z1_y1 * wy

    val = z0_face * (1.0 - wz) + z1_face * wz

    # Dense NCDHW output offset: (n * C + c) * out_volume + spatial index.
    tl.store(ptr_output + (nc_idx * out_volume + spatial_idx), val)

2171 

2172 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with border padding.

    Each program computes one output voxel. The sampling coordinate is
    clamped into the input volume before the eight surrounding voxels are
    chosen, so every location outside the input resolves to the nearest
    border voxel.

    Args:
        ptr_output: Output pointer, addressed as a dense
            (N, C, D_out, H_out, W_out) buffer.
        ptr_input: Input pointer, addressed as a dense
            (N, C, D_in, H_in, W_in) buffer.
        ptr_grid: Grid pointer, addressed as a dense (N, D_out, H_out, W_out, 3)
            buffer whose last axis holds (x, y, z).
        N: Batch size. Not read in the body; it is part of the autotune key.
        C: Number of channels.
        D_in, H_in, W_in: Input spatial extents.
        D_out, H_out, W_out: Output spatial extents.
        align_corners: Compile-time flag selecting the denormalization formula.
        BLOCK_SIZE: Compile-time parameter that the body does not read
            (presumably supplied by the tuned configs -- confirm).

    If any of the three grid coordinates is NaN the output voxel is 0.
    """
    pid = tl.program_id(0)
    ncd = pid // (D_out * H_out * W_out)
    dhw = pid % (D_out * H_out * W_out)

    n = ncd // C
    c = ncd % C
    d_out = dhw // (H_out * W_out)
    hw = dhw % (H_out * W_out)
    h_out = hw // W_out
    w_out = hw % W_out

    # Load 3D grid coordinates (three consecutive values per output voxel)
    grid_idx = (
        n * D_out * H_out * W_out * 3
        + d_out * H_out * W_out * 3
        + h_out * W_out * 3
        + w_out * 3
    )
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
    grid_z = tl.load(ptr_grid + grid_idx + 2).to(tl.float32)

    # Handle NaN: remember which coordinates were NaN, then substitute a
    # finite sentinel so the arithmetic below is well defined.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Border padding: clamp the continuous coordinate into [0, size - 1]
    # BEFORE taking floor and weights. Clamping only the integer indices
    # afterwards is not enough:
    #   * an infinite coordinate gives weight = inf - inf = NaN;
    #   * a very large coordinate reaches the float -> int32 cast out of range;
    #   * an out-of-range sample is blended as p * (1 - w) + p * w, which can
    #     differ from p by rounding.
    # In-range coordinates are unaffected by the clamp.
    x = tl.minimum(tl.maximum(x, 0.0), W_in - 1)
    y = tl.minimum(tl.maximum(y, 0.0), H_in - 1)
    z = tl.minimum(tl.maximum(z, 0.0), D_in - 1)

    # Find 8 corner indices
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1

    # Compute weights
    wx = x - x0
    wy = y - y0
    wz = z - z0

    # Convert to int and clamp. The upper corner can still be `size` when the
    # coordinate sits exactly on the last voxel; its weight is 0 there.
    x0_int = tl.maximum(0, tl.minimum(tl.cast(x0, tl.int32), W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(tl.cast(x1, tl.int32), W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(tl.cast(y0, tl.int32), H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(tl.cast(y1, tl.int32), H_in - 1))
    z0_int = tl.maximum(0, tl.minimum(tl.cast(z0, tl.int32), D_in - 1))
    z1_int = tl.maximum(0, tl.minimum(tl.cast(z1, tl.int32), D_in - 1))

    # Load 8 corner pixels (no mask needed due to clamping); naming is p<z><y><x>
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in

    p000 = tl.load(
        ptr_input + input_base + z0_int * H_in * W_in + y0_int * W_in + x0_int
    ).to(tl.float32)
    p001 = tl.load(
        ptr_input + input_base + z0_int * H_in * W_in + y0_int * W_in + x1_int
    ).to(tl.float32)
    p010 = tl.load(
        ptr_input + input_base + z0_int * H_in * W_in + y1_int * W_in + x0_int
    ).to(tl.float32)
    p011 = tl.load(
        ptr_input + input_base + z0_int * H_in * W_in + y1_int * W_in + x1_int
    ).to(tl.float32)
    p100 = tl.load(
        ptr_input + input_base + z1_int * H_in * W_in + y0_int * W_in + x0_int
    ).to(tl.float32)
    p101 = tl.load(
        ptr_input + input_base + z1_int * H_in * W_in + y0_int * W_in + x1_int
    ).to(tl.float32)
    p110 = tl.load(
        ptr_input + input_base + z1_int * H_in * W_in + y1_int * W_in + x0_int
    ).to(tl.float32)
    p111 = tl.load(
        ptr_input + input_base + z1_int * H_in * W_in + y1_int * W_in + x1_int
    ).to(tl.float32)

    # Trilinear interpolation: along x, then y, then z
    c000 = p000 * (1.0 - wx) + p001 * wx
    c001 = p010 * (1.0 - wx) + p011 * wx
    front = c000 * (1.0 - wy) + c001 * wy

    c100 = p100 * (1.0 - wx) + p101 * wx
    c101 = p110 * (1.0 - wx) + p111 * wx
    back = c100 * (1.0 - wy) + c101 * wy

    # A NaN grid coordinate forces the result to 0.
    val = tl.where(
        grid_x_nan | grid_y_nan | grid_z_nan, 0.0, front * (1.0 - wz) + back * wz
    )

    # Store output
    output_offset = (
        n * C * D_out * H_out * W_out
        + c * D_out * H_out * W_out
        + d_out * H_out * W_out
        + h_out * W_out
        + w_out
    )
    tl.store(ptr_output + output_offset, val)

2308 

2309 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Trilinear 3D grid sampling with reflection padding; one output voxel per program.

    Each normalized grid coordinate is folded back into [-1, 1] with a
    period-4 triangle wave before it is mapped to input voxel coordinates.
    The eight surrounding voxels are then loaded with indices clamped to the
    input extents, and blended along x, then y, then z.

    Args:
        ptr_output: Output pointer, addressed as a dense
            (N, C, D_out, H_out, W_out) buffer.
        ptr_input: Input pointer, addressed as a dense
            (N, C, D_in, H_in, W_in) buffer.
        ptr_grid: Grid pointer, addressed as a dense (N, D_out, H_out, W_out, 3)
            buffer whose last axis holds (x, y, z).
        N: Batch size. Not read in the body; it is part of the autotune key.
        C: Number of channels.
        D_in, H_in, W_in: Input spatial extents.
        D_out, H_out, W_out: Output spatial extents.
        align_corners: Compile-time flag selecting the denormalization formula.
        BLOCK_SIZE: Compile-time parameter that the body does not read
            (presumably supplied by the tuned configs -- confirm).

    If any of the three grid coordinates is NaN the output voxel is 0.
    """
    out_volume = D_out * H_out * W_out
    in_plane = H_in * W_in

    pid = tl.program_id(0)
    nc_idx = pid // out_volume  # fused (n, c) index, equal to n * C + c
    spatial_idx = pid % out_volume  # flattened (d_out, h_out, w_out)
    n = nc_idx // C

    # The grid has no channel axis: three consecutive values per (n, d, h, w).
    coord_ptr = ptr_grid + (n * out_volume + spatial_idx) * 3
    grid_x = tl.load(coord_ptr).to(tl.float32)
    grid_y = tl.load(coord_ptr + 1).to(tl.float32)
    grid_z = tl.load(coord_ptr + 2).to(tl.float32)

    # x != x is true only for NaN. NaNs are replaced by the sentinel -2.0 and
    # the final value is zeroed through `any_nan`.
    x_is_nan = grid_x != grid_x
    y_is_nan = grid_y != grid_y
    z_is_nan = grid_z != grid_z
    any_nan = x_is_nan | y_is_nan | z_is_nan
    grid_x = tl.where(x_is_nan, -2.0, grid_x)
    grid_y = tl.where(y_is_nan, -2.0, grid_y)
    grid_z = tl.where(z_is_nan, -2.0, grid_z)

    # Reflection in grid space. Shift [-1, 1] to [0, 2], reduce modulo 4 into
    # [0, 4), mirror the upper half (2, 4) back onto (0, 2), then shift back.
    x_phase = (grid_x + 1.0) % 4.0
    x_phase = tl.where(x_phase < 0, x_phase + 4.0, x_phase)
    grid_x_refl = tl.where(x_phase <= 2.0, x_phase, 4.0 - x_phase) - 1.0

    y_phase = (grid_y + 1.0) % 4.0
    y_phase = tl.where(y_phase < 0, y_phase + 4.0, y_phase)
    grid_y_refl = tl.where(y_phase <= 2.0, y_phase, 4.0 - y_phase) - 1.0

    z_phase = (grid_z + 1.0) % 4.0
    z_phase = tl.where(z_phase < 0, z_phase + 4.0, z_phase)
    grid_z_refl = tl.where(z_phase <= 2.0, z_phase, 4.0 - z_phase) - 1.0

    # Map the reflected coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x_refl + 1.0) * (W_in - 1) / 2.0
        y = (grid_y_refl + 1.0) * (H_in - 1) / 2.0
        z = (grid_z_refl + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x_refl + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y_refl + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z_refl + 1.0) * D_in / 2.0 - 0.5

    # Lower corner of the enclosing 2x2x2 cell and the fractional weights.
    x_lo = tl.floor(x)
    y_lo = tl.floor(y)
    z_lo = tl.floor(z)
    wx = x - x_lo
    wy = y - y_lo
    wz = z - z_lo

    # Integer corner indices, clamped so every load stays inside the input.
    ix0 = tl.maximum(0, tl.minimum(tl.cast(x_lo, tl.int32), W_in - 1))
    ix1 = tl.maximum(0, tl.minimum(tl.cast(x_lo + 1, tl.int32), W_in - 1))
    iy0 = tl.maximum(0, tl.minimum(tl.cast(y_lo, tl.int32), H_in - 1))
    iy1 = tl.maximum(0, tl.minimum(tl.cast(y_lo + 1, tl.int32), H_in - 1))
    iz0 = tl.maximum(0, tl.minimum(tl.cast(z_lo, tl.int32), D_in - 1))
    iz1 = tl.maximum(0, tl.minimum(tl.cast(z_lo + 1, tl.int32), D_in - 1))

    # Pointers to the start of the four (z, y) rows touched by the cell.
    channel_ptr = ptr_input + nc_idx * (D_in * in_plane)
    row_z0y0 = channel_ptr + iz0 * in_plane + iy0 * W_in
    row_z0y1 = channel_ptr + iz0 * in_plane + iy1 * W_in
    row_z1y0 = channel_ptr + iz1 * in_plane + iy0 * W_in
    row_z1y1 = channel_ptr + iz1 * in_plane + iy1 * W_in

    # Corner naming is p<z><y><x>.
    p000 = tl.load(row_z0y0 + ix0).to(tl.float32)
    p001 = tl.load(row_z0y0 + ix1).to(tl.float32)
    p010 = tl.load(row_z0y1 + ix0).to(tl.float32)
    p011 = tl.load(row_z0y1 + ix1).to(tl.float32)
    p100 = tl.load(row_z1y0 + ix0).to(tl.float32)
    p101 = tl.load(row_z1y0 + ix1).to(tl.float32)
    p110 = tl.load(row_z1y1 + ix0).to(tl.float32)
    p111 = tl.load(row_z1y1 + ix1).to(tl.float32)

    # Blend along x, then y, then z.
    wx_rem = 1.0 - wx
    wy_rem = 1.0 - wy
    z0_y0 = p000 * wx_rem + p001 * wx
    z0_y1 = p010 * wx_rem + p011 * wx
    z0_face = z0_y0 * wy_rem + z0_y1 * wy

    z1_y0 = p100 * wx_rem + p101 * wx
    z1_y1 = p110 * wx_rem + p111 * wx
    z1_face = z1_y0 * wy_rem + z1_y1 * wy

    val = tl.where(any_nan, 0.0, z0_face * (1.0 - wz) + z1_face * wz)

    # Dense NCDHW output offset: (n * C + c) * out_volume + spatial index.
    tl.store(ptr_output + (nc_idx * out_volume + spatial_idx), val)

2464 

2465 

2466# ============================================================================ 

2467# 3D Tiled Kernels for Medium-to-Large 5D Inputs (3D Blocking: D×H×W) 

2468# ============================================================================ 

2469 

2470 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled nearest-neighbor 3D grid sampling with zeros padding.

    Program axis 0 selects the fused (batch, channel) index and axis 1 selects
    one BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels. All voxels of the
    tile are handled with vectorized loads and stores.

    Args:
        ptr_output: Output pointer, addressed as a dense
            (N, C, D_out, H_out, W_out) buffer.
        ptr_input: Input pointer, addressed as a dense
            (N, C, D_in, H_in, W_in) buffer.
        ptr_grid: Grid pointer, addressed as a dense (N, D_out, H_out, W_out, 3)
            buffer whose last axis holds (x, y, z).
        N: Batch size. Not read in the body; it is part of the autotune key.
        C: Number of channels.
        D_in, H_in, W_in: Input spatial extents.
        D_out, H_out, W_out: Output spatial extents.
        align_corners: Compile-time flag selecting the denormalization formula.
        BLOCK_D, BLOCK_H, BLOCK_W: Compile-time tile extents.

    Voxels whose rounded sample index falls outside the input, or whose grid
    coordinate contains a NaN, are written as 0.
    """
    pid_nc = tl.program_id(0)  # fused (n, c) index, equal to n * C + c
    pid_dhw = tl.program_id(1)  # flattened tile index
    n = pid_nc // C

    # Unflatten the tile index into per-axis tile coordinates.
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_dhw // tiles_hw
    tile_hw = pid_dhw % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Output coordinates covered by this tile, shaped for 3D broadcasting.
    d_out_3d = (tile_d * BLOCK_D + tl.arange(0, BLOCK_D))[:, None, None]
    h_out_3d = (tile_h * BLOCK_H + tl.arange(0, BLOCK_H))[None, :, None]
    w_out_3d = (tile_w * BLOCK_W + tl.arange(0, BLOCK_W))[None, None, :]

    # Lanes past the output extents (partial edge tiles) are disabled.
    tile_mask = (d_out_3d < D_out) & (h_out_3d < H_out) & (w_out_3d < W_out)

    # Flattened (d, h, w) position of every voxel in the tile.
    voxel_lin = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Grid values are interleaved as x, y, z triples per output voxel.
    coord_offsets = n * D_out * H_out * W_out * 3 + voxel_lin * 3
    grid_x = tl.load(ptr_grid + coord_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + (coord_offsets + 1), mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + (coord_offsets + 2), mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # x != x is true only for NaN. NaNs are replaced by the sentinel -2.0
    # and excluded from the input load below.
    x_is_nan = grid_x != grid_x
    y_is_nan = grid_y != grid_y
    z_is_nan = grid_z != grid_z
    no_nan = ~(x_is_nan | y_is_nan | z_is_nan)
    grid_x = tl.where(x_is_nan, -2.0, grid_x)
    grid_y = tl.where(y_is_nan, -2.0, grid_y)
    grid_z = tl.where(z_is_nan, -2.0, grid_z)

    # Map normalized coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Round half to even: an exact .5 fraction goes to whichever neighbor is
    # even, anything else goes to the nearer integer.
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    z_floor = tl.floor(z)
    x_frac = x - x_floor
    y_frac = y - y_floor
    z_frac = z - z_floor
    x_up = x_floor + 1
    y_up = y_floor + 1
    z_up = z_floor + 1

    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0
    z_floor_even = tl.cast(z_floor, tl.int32) % 2 == 0

    x_tie = tl.where(x_floor_even, x_floor, x_up)
    y_tie = tl.where(y_floor_even, y_floor, y_up)
    z_tie = tl.where(z_floor_even, z_floor, z_up)
    x_near = tl.where(x_frac < 0.5, x_floor, x_up)
    y_near = tl.where(y_frac < 0.5, y_floor, y_up)
    z_near = tl.where(z_frac < 0.5, z_floor, z_up)

    x_idx = tl.cast(tl.where(x_frac == 0.5, x_tie, x_near), tl.int32)
    y_idx = tl.cast(tl.where(y_frac == 0.5, y_tie, y_near), tl.int32)
    z_idx = tl.cast(tl.where(z_frac == 0.5, z_tie, z_near), tl.int32)

    # Zeros padding: only in-range indices are actually read.
    x_inside = (x_idx >= 0) & (x_idx < W_in)
    y_inside = (y_idx >= 0) & (y_idx < H_in)
    z_inside = (z_idx >= 0) & (z_idx < D_in)
    valid_mask = tile_mask & x_inside & y_inside & z_inside & no_nan

    # Gather one input voxel per output voxel; masked lanes read as 0.
    input_offsets = (
        pid_nc * (D_in * H_in * W_in) + z_idx * H_in * W_in + y_idx * W_in + x_idx
    )
    vals = tl.load(ptr_input + input_offsets, mask=valid_mask, other=0.0)

    # Scatter the tile into the dense NCDHW output.
    output_offsets = pid_nc * (D_out * H_out * W_out) + voxel_lin
    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

2663 

2664 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled nearest-neighbor 3D grid sampling with border padding.

    Program axis 0 selects the fused (batch, channel) index and axis 1 selects
    one BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels. Rounded sample
    indices are clamped to the input extents, so locations outside the input
    read the nearest border voxel.

    Args:
        ptr_output: Output pointer, addressed as a dense
            (N, C, D_out, H_out, W_out) buffer.
        ptr_input: Input pointer, addressed as a dense
            (N, C, D_in, H_in, W_in) buffer.
        ptr_grid: Grid pointer, addressed as a dense (N, D_out, H_out, W_out, 3)
            buffer whose last axis holds (x, y, z).
        N: Batch size. Not read in the body; it is part of the autotune key.
        C: Number of channels.
        D_in, H_in, W_in: Input spatial extents.
        D_out, H_out, W_out: Output spatial extents.
        align_corners: Compile-time flag selecting the denormalization formula.
        BLOCK_D, BLOCK_H, BLOCK_W: Compile-time tile extents.

    Voxels whose grid coordinate contains a NaN are written as 0.
    """
    pid_nc = tl.program_id(0)  # fused (n, c) index, equal to n * C + c
    pid_dhw = tl.program_id(1)  # flattened tile index
    n = pid_nc // C

    # Unflatten the tile index into per-axis tile coordinates.
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_dhw // tiles_hw
    tile_hw = pid_dhw % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Output coordinates covered by this tile, shaped for 3D broadcasting.
    d_out_3d = (tile_d * BLOCK_D + tl.arange(0, BLOCK_D))[:, None, None]
    h_out_3d = (tile_h * BLOCK_H + tl.arange(0, BLOCK_H))[None, :, None]
    w_out_3d = (tile_w * BLOCK_W + tl.arange(0, BLOCK_W))[None, None, :]

    # Lanes past the output extents (partial edge tiles) are disabled.
    tile_mask = (d_out_3d < D_out) & (h_out_3d < H_out) & (w_out_3d < W_out)

    # Flattened (d, h, w) position of every voxel in the tile.
    voxel_lin = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Grid values are interleaved as x, y, z triples per output voxel.
    coord_offsets = n * D_out * H_out * W_out * 3 + voxel_lin * 3
    grid_x = tl.load(ptr_grid + coord_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + (coord_offsets + 1), mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + (coord_offsets + 2), mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # x != x is true only for NaN. NaNs are replaced by the sentinel -2.0
    # and the affected lanes are zeroed after the gather.
    x_is_nan = grid_x != grid_x
    y_is_nan = grid_y != grid_y
    z_is_nan = grid_z != grid_z
    any_nan = x_is_nan | y_is_nan | z_is_nan
    grid_x = tl.where(x_is_nan, -2.0, grid_x)
    grid_y = tl.where(y_is_nan, -2.0, grid_y)
    grid_z = tl.where(z_is_nan, -2.0, grid_z)

    # Map normalized coordinates to input voxel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Round half to even: an exact .5 fraction goes to whichever neighbor is
    # even, anything else goes to the nearer integer.
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    z_floor = tl.floor(z)
    x_frac = x - x_floor
    y_frac = y - y_floor
    z_frac = z - z_floor
    x_up = x_floor + 1
    y_up = y_floor + 1
    z_up = z_floor + 1

    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0
    z_floor_even = tl.cast(z_floor, tl.int32) % 2 == 0

    x_tie = tl.where(x_floor_even, x_floor, x_up)
    y_tie = tl.where(y_floor_even, y_floor, y_up)
    z_tie = tl.where(z_floor_even, z_floor, z_up)
    x_near = tl.where(x_frac < 0.5, x_floor, x_up)
    y_near = tl.where(y_frac < 0.5, y_floor, y_up)
    z_near = tl.where(z_frac < 0.5, z_floor, z_up)

    x_idx = tl.cast(tl.where(x_frac == 0.5, x_tie, x_near), tl.int32)
    y_idx = tl.cast(tl.where(y_frac == 0.5, y_tie, y_near), tl.int32)
    z_idx = tl.cast(tl.where(z_frac == 0.5, z_tie, z_near), tl.int32)

    # Border padding: pull every index back inside the input extents.
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # After clamping, only edge-tile lanes and NaN lanes need masking.
    valid_mask = tile_mask & ~any_nan

    # Gather one input voxel per output voxel.
    input_offsets = (
        pid_nc * (D_in * H_in * W_in) + z_idx * H_in * W_in + y_idx * W_in + x_idx
    )
    gathered = tl.load(ptr_input + input_offsets, mask=valid_mask, other=0.0)

    # Clamping cannot repair a NaN coordinate, so those lanes are forced to 0.
    vals = tl.where(any_nan, 0.0, gathered)

    # Scatter the tile into the dense NCDHW output.
    output_offsets = pid_nc * (D_out * H_out * W_out) + voxel_lin
    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

2828 

2829 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D nearest neighbor interpolation with reflection padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).

    Reflection padding: normalized coordinates outside [-1, 1] are folded back
    with a triangle wave of period 4, denormalized to voxel space, and then
    clipped to [0, size - 1] before rounding. The clip matters when
    align_corners is False: the reflected range then maps to
    [-0.5, size - 0.5], and without it a coordinate of size - 0.5 could round
    to index `size`, which is out of bounds and would yield 0 instead of the
    border voxel.

    The index arithmetic assumes contiguous tensors laid out as
    input (N, C, D_in, H_in, W_in), grid (N, D_out, H_out, W_out, 3) with the
    last axis ordered (x, y, z), and output (N, C, D_out, H_out, W_out).
    Output voxels whose grid entry contains a NaN are written as 0.

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 refer to voxel centers (True) or corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_dhw for spatial tile
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    # Promote to int64 so the per-(n, c) base offsets below cannot wrap in
    # 32-bit arithmetic when the tensors hold 2**31 or more elements.
    n = (pid_nc // C).to(tl.int64)
    c = (pid_nc % C).to(tl.int64)

    # Decompose flattened 3D tile index to (d, h, w) block indices
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    num_h_blocks = tl.cdiv(H_out, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    d_block_idx = pid_dhw // num_hw_blocks
    hw_block_idx = pid_dhw % num_hw_blocks
    h_block_idx = hw_block_idx // num_w_blocks
    w_block_idx = hw_block_idx % num_w_blocks

    # Voxel coordinates covered by this tile
    d_offsets = d_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Reshape for 3D broadcasting: (BLOCK_D, BLOCK_H, BLOCK_W)
    d_out_3d = d_offsets[:, None, None]
    h_out_3d = h_offsets[None, :, None]
    w_out_3d = w_offsets[None, None, :]

    # Flattened position of every tile voxel inside one output volume
    spatial_offsets = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Load the (x, y, z) grid triple for the entire tile (vectorized)
    grid_base = n * D_out * H_out * W_out * 3
    grid_offsets = grid_base + spatial_offsets * 3
    grid_x = tl.load(ptr_grid + grid_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + grid_offsets + 2, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN: remember the affected lanes and substitute a finite
    # placeholder so the arithmetic below stays well defined. Those lanes are
    # excluded from the input load and therefore produce 0.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Apply triangle wave reflection with period 4 (before denormalization).
    # Shift to [0, 4), mirror the upper half, shift back to [-1, 1]. The
    # `< 0` fix-up is needed because the float remainder keeps the sign of
    # the dividend.
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    grid_z_mod = (grid_z + 1.0) % 4.0
    grid_z_mod = tl.where(grid_z_mod < 0, grid_z_mod + 4.0, grid_z_mod)
    grid_z = tl.where(grid_z_mod <= 2.0, grid_z_mod, 4.0 - grid_z_mod) - 1.0

    # Denormalize reflected coordinates to voxel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Clip to the valid voxel range. With align_corners=False the reflected
    # coordinate lies in [-0.5, size - 0.5]; rounding size - 0.5 could pick
    # index `size`, so clip first and round afterwards.
    x = tl.minimum(tl.maximum(x, 0.0), W_in - 1.0)
    y = tl.minimum(tl.maximum(y, 0.0), H_in - 1.0)
    z = tl.minimum(tl.maximum(z, 0.0), D_in - 1.0)

    # Apply banker's rounding (round half to even), vectorized across tile
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    z_floor = tl.floor(z)
    x_frac = x - x_floor
    y_frac = y - y_floor
    z_frac = z - z_floor

    x_is_half = x_frac == 0.5
    y_is_half = y_frac == 0.5
    z_is_half = z_frac == 0.5
    x_floor_int = tl.cast(x_floor, tl.int32)
    y_floor_int = tl.cast(y_floor, tl.int32)
    z_floor_int = tl.cast(z_floor, tl.int32)

    x_is_even = x_floor_int % 2 == 0
    y_is_even = y_floor_int % 2 == 0
    z_is_even = z_floor_int % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)
    z_round = tl.where(z_frac < 0.5, z_floor, z_floor + 1)

    x_idx = tl.cast(
        tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round),
        tl.int32,
    )
    y_idx = tl.cast(
        tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round),
        tl.int32,
    )
    z_idx = tl.cast(
        tl.where(z_is_half, tl.where(z_is_even, z_floor, z_floor + 1), z_round),
        tl.int32,
    )

    # The clip above keeps indices in range; the bounds test is kept as a
    # cheap guard so a stray index can never turn into an out-of-bounds read.
    x_in_bounds = (x_idx >= 0) & (x_idx < W_in)
    y_in_bounds = (y_idx >= 0) & (y_idx < H_in)
    z_in_bounds = (z_idx >= 0) & (z_idx < D_in)
    valid_mask = (
        tile_mask
        & x_in_bounds
        & y_in_bounds
        & z_in_bounds
        & ~grid_x_nan
        & ~grid_y_nan
        & ~grid_z_nan
    )

    # Load input voxels for entire tile
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
    input_offsets = input_base + z_idx * H_in * W_in + y_idx * W_in + x_idx

    vals = tl.load(ptr_input + input_offsets, mask=valid_mask, other=0.0)

    # Store to output
    output_base = n * C * D_out * H_out * W_out + c * D_out * H_out * W_out
    output_offsets = output_base + spatial_offsets

    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

3021 

3022 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with zeros padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).
    Trilinear interpolation uses 8 corner points (2x2x2 cube) for each output
    voxel; corners that fall outside the input volume contribute 0.

    The index arithmetic assumes contiguous tensors laid out as
    input (N, C, D_in, H_in, W_in), grid (N, D_out, H_out, W_out, 3) with the
    last axis ordered (x, y, z), and output (N, C, D_out, H_out, W_out).
    Output voxels whose grid entry contains a NaN are written as 0.

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 refer to voxel centers (True) or corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_dhw for spatial tile
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    # Promote to int64 so the per-(n, c) base offsets below cannot wrap in
    # 32-bit arithmetic when the tensors hold 2**31 or more elements.
    n = (pid_nc // C).to(tl.int64)
    c = (pid_nc % C).to(tl.int64)

    # Decompose flattened 3D tile index to (d, h, w) block indices
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    num_h_blocks = tl.cdiv(H_out, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    d_block_idx = pid_dhw // num_hw_blocks
    hw_block_idx = pid_dhw % num_hw_blocks
    h_block_idx = hw_block_idx // num_w_blocks
    w_block_idx = hw_block_idx % num_w_blocks

    # Voxel coordinates covered by this tile
    d_offsets = d_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Reshape for 3D broadcasting: (BLOCK_D, BLOCK_H, BLOCK_W)
    d_out_3d = d_offsets[:, None, None]
    h_out_3d = h_offsets[None, :, None]
    w_out_3d = w_offsets[None, None, :]

    # Flattened position of every tile voxel inside one output volume
    spatial_offsets = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Load the (x, y, z) grid triple for the entire tile (vectorized)
    grid_base = n * D_out * H_out * W_out * 3
    grid_offsets = grid_base + spatial_offsets * 3
    grid_x = tl.load(ptr_grid + grid_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + grid_offsets + 2, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN: remember the affected lanes and substitute a finite
    # placeholder so the arithmetic below stays well defined. Those lanes are
    # excluded from every corner load and therefore interpolate to 0.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Denormalize to voxel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Lower and upper corner coordinates of the enclosing 2x2x2 cell
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1

    # Interpolation weights (fractional distance from the lower corner)
    wx = x - x0
    wy = y - y0
    wz = z - z0

    # Convert to integers
    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    z0_int = tl.cast(z0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y1_int = tl.cast(y1, tl.int32)
    z1_int = tl.cast(z1, tl.int32)

    # Boundary checks: zeros padding means out-of-range corners read as 0
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)
    z0_in = (z0_int >= 0) & (z0_int < D_in)
    z1_in = (z1_int >= 0) & (z1_int < D_in)

    # Part of the load mask shared by all eight corners
    lane_mask = tile_mask & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan

    # Offsets shared between corners: two depth planes and two rows
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
    plane_z0 = input_base + z0_int * H_in * W_in
    plane_z1 = input_base + z1_int * H_in * W_in
    row_y0 = y0_int * W_in
    row_y1 = y1_int * W_in

    # Load 8 corners (vectorized). Naming is p<z><y><x>.
    # p000: (x=0, y=0, z=0)
    p000 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x0_int),
        mask=lane_mask & x0_in & y0_in & z0_in,
        other=0.0,
    ).to(tl.float32)

    # p001: (x=1, y=0, z=0)
    p001 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x1_int),
        mask=lane_mask & x1_in & y0_in & z0_in,
        other=0.0,
    ).to(tl.float32)

    # p010: (x=0, y=1, z=0)
    p010 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x0_int),
        mask=lane_mask & x0_in & y1_in & z0_in,
        other=0.0,
    ).to(tl.float32)

    # p011: (x=1, y=1, z=0)
    p011 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x1_int),
        mask=lane_mask & x1_in & y1_in & z0_in,
        other=0.0,
    ).to(tl.float32)

    # p100: (x=0, y=0, z=1)
    p100 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x0_int),
        mask=lane_mask & x0_in & y0_in & z1_in,
        other=0.0,
    ).to(tl.float32)

    # p101: (x=1, y=0, z=1)
    p101 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x1_int),
        mask=lane_mask & x1_in & y0_in & z1_in,
        other=0.0,
    ).to(tl.float32)

    # p110: (x=0, y=1, z=1)
    p110 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x0_int),
        mask=lane_mask & x0_in & y1_in & z1_in,
        other=0.0,
    ).to(tl.float32)

    # p111: (x=1, y=1, z=1)
    p111 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x1_int),
        mask=lane_mask & x1_in & y1_in & z1_in,
        other=0.0,
    ).to(tl.float32)

    # 3-stage trilinear interpolation
    # Stage 1: Interpolate along x
    c000 = p000 * (1.0 - wx) + p001 * wx  # z=0, y=0
    c001 = p010 * (1.0 - wx) + p011 * wx  # z=0, y=1
    c010 = p100 * (1.0 - wx) + p101 * wx  # z=1, y=0
    c011 = p110 * (1.0 - wx) + p111 * wx  # z=1, y=1

    # Stage 2: Interpolate along y
    c00 = c000 * (1.0 - wy) + c001 * wy  # z=0
    c01 = c010 * (1.0 - wy) + c011 * wy  # z=1

    # Stage 3: Interpolate along z (final)
    vals = c00 * (1.0 - wz) + c01 * wz

    # Store to output
    output_base = n * C * D_out * H_out * W_out + c * D_out * H_out * W_out
    output_offsets = output_base + spatial_offsets

    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

3292 

3293 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with border padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).

    Border padding: the denormalized sampling coordinate is clipped to
    [0, size - 1] before the 2x2x2 corner cell and its weights are derived, so
    samples outside the volume take the value of the nearest border voxel.
    Clipping the coordinate itself (rather than only the integer corner
    indices) keeps the weights finite for infinite grid values and keeps the
    float-to-int conversion in range for very large ones.

    The index arithmetic assumes contiguous tensors laid out as
    input (N, C, D_in, H_in, W_in), grid (N, D_out, H_out, W_out, 3) with the
    last axis ordered (x, y, z), and output (N, C, D_out, H_out, W_out).
    Output voxels whose grid entry contains a NaN are written as 0.

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 refer to voxel centers (True) or corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_dhw for spatial tile
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    # Promote to int64 so the per-(n, c) base offsets below cannot wrap in
    # 32-bit arithmetic when the tensors hold 2**31 or more elements.
    n = (pid_nc // C).to(tl.int64)
    c = (pid_nc % C).to(tl.int64)

    # Decompose flattened 3D tile index to (d, h, w) block indices
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    num_h_blocks = tl.cdiv(H_out, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    d_block_idx = pid_dhw // num_hw_blocks
    hw_block_idx = pid_dhw % num_hw_blocks
    h_block_idx = hw_block_idx // num_w_blocks
    w_block_idx = hw_block_idx % num_w_blocks

    # Voxel coordinates covered by this tile
    d_offsets = d_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Reshape for 3D broadcasting: (BLOCK_D, BLOCK_H, BLOCK_W)
    d_out_3d = d_offsets[:, None, None]
    h_out_3d = h_offsets[None, :, None]
    w_out_3d = w_offsets[None, None, :]

    # Flattened position of every tile voxel inside one output volume
    spatial_offsets = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Load the (x, y, z) grid triple for the entire tile (vectorized)
    grid_base = n * D_out * H_out * W_out * 3
    grid_offsets = grid_base + spatial_offsets * 3
    grid_x = tl.load(ptr_grid + grid_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + grid_offsets + 2, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN: remember the affected lanes and substitute a finite
    # placeholder so the arithmetic below stays well defined.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    any_nan = grid_x_nan | grid_y_nan | grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Denormalize to voxel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Border padding: clip the coordinate to the valid voxel range before
    # deriving corners and weights (avoids inf - inf = NaN weights).
    x = tl.minimum(tl.maximum(x, 0.0), W_in - 1.0)
    y = tl.minimum(tl.maximum(y, 0.0), H_in - 1.0)
    z = tl.minimum(tl.maximum(z, 0.0), D_in - 1.0)

    # Lower and upper corner coordinates of the enclosing 2x2x2 cell
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1

    # Interpolation weights (fractional distance from the lower corner)
    wx = x - x0
    wy = y - y0
    wz = z - z0

    # Convert to integers and clamp. The upper corner can still equal `size`
    # when the coordinate sits on the last voxel; its weight is 0 there, and
    # the clamp redirects the read to the border voxel.
    x0_int = tl.maximum(0, tl.minimum(tl.cast(x0, tl.int32), W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(tl.cast(x1, tl.int32), W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(tl.cast(y0, tl.int32), H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(tl.cast(y1, tl.int32), H_in - 1))
    z0_int = tl.maximum(0, tl.minimum(tl.cast(z0, tl.int32), D_in - 1))
    z1_int = tl.maximum(0, tl.minimum(tl.cast(z1, tl.int32), D_in - 1))

    # All corner indices are in range, so the load mask only has to exclude
    # lanes outside the tile and lanes with NaN grid coordinates.
    lane_mask = tile_mask & ~any_nan

    # Offsets shared between corners: two depth planes and two rows
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
    plane_z0 = input_base + z0_int * H_in * W_in
    plane_z1 = input_base + z1_int * H_in * W_in
    row_y0 = y0_int * W_in
    row_y1 = y1_int * W_in

    # Load 8 corners (vectorized). Naming is p<z><y><x>.
    p000 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x0_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p001 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x1_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p010 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x0_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p011 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x1_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p100 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x0_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p101 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x1_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p110 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x0_int), mask=lane_mask, other=0.0
    ).to(tl.float32)
    p111 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x1_int), mask=lane_mask, other=0.0
    ).to(tl.float32)

    # 3-stage trilinear interpolation
    # Stage 1: Interpolate along x
    c000 = p000 * (1.0 - wx) + p001 * wx  # z=0, y=0
    c001 = p010 * (1.0 - wx) + p011 * wx  # z=0, y=1
    c010 = p100 * (1.0 - wx) + p101 * wx  # z=1, y=0
    c011 = p110 * (1.0 - wx) + p111 * wx  # z=1, y=1

    # Stage 2: Interpolate along y
    c00 = c000 * (1.0 - wy) + c001 * wy  # z=0
    c01 = c010 * (1.0 - wy) + c011 * wy  # z=1

    # Stage 3: Interpolate along z (final)
    vals = c00 * (1.0 - wz) + c01 * wz

    # Handle NaN: force lanes with NaN grid coordinates to 0
    vals = tl.where(any_nan, 0.0, vals)

    # Store to output
    output_base = n * C * D_out * H_out * W_out + c * D_out * H_out * W_out
    output_offsets = output_base + spatial_offsets

    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

3497 

3498 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with reflection padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).

    Reflection padding: normalized coordinates outside [-1, 1] are folded back
    with a triangle wave of period 4, denormalized to voxel space, and then
    clipped to [0, size - 1] before the 2x2x2 corner cell is derived. The clip
    matters when align_corners is False: the reflected range then maps to
    [-0.5, size - 0.5], and without it samples in the outer half-voxel would
    blend the border voxel with an out-of-bounds (zero) corner and come out
    attenuated.

    The index arithmetic assumes contiguous tensors laid out as
    input (N, C, D_in, H_in, W_in), grid (N, D_out, H_out, W_out, 3) with the
    last axis ordered (x, y, z), and output (N, C, D_out, H_out, W_out).
    Output voxels whose grid entry contains a NaN are written as 0.

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 refer to voxel centers (True) or corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_dhw for spatial tile
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    # Promote to int64 so the per-(n, c) base offsets below cannot wrap in
    # 32-bit arithmetic when the tensors hold 2**31 or more elements.
    n = (pid_nc // C).to(tl.int64)
    c = (pid_nc % C).to(tl.int64)

    # Decompose flattened 3D tile index to (d, h, w) block indices
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    num_h_blocks = tl.cdiv(H_out, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    d_block_idx = pid_dhw // num_hw_blocks
    hw_block_idx = pid_dhw % num_hw_blocks
    h_block_idx = hw_block_idx // num_w_blocks
    w_block_idx = hw_block_idx % num_w_blocks

    # Voxel coordinates covered by this tile
    d_offsets = d_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Reshape for 3D broadcasting: (BLOCK_D, BLOCK_H, BLOCK_W)
    d_out_3d = d_offsets[:, None, None]
    h_out_3d = h_offsets[None, :, None]
    w_out_3d = w_offsets[None, None, :]

    # Flattened position of every tile voxel inside one output volume
    spatial_offsets = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Load the (x, y, z) grid triple for the entire tile (vectorized)
    grid_base = n * D_out * H_out * W_out * 3
    grid_offsets = grid_base + spatial_offsets * 3
    grid_x = tl.load(ptr_grid + grid_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + grid_offsets + 2, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN: remember the affected lanes and substitute a finite
    # placeholder so the arithmetic below stays well defined. Those lanes are
    # excluded from every corner load and therefore interpolate to 0.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Apply triangle wave reflection with period 4 (before denormalization).
    # Shift to [0, 4), mirror the upper half, shift back to [-1, 1]. The
    # `< 0` fix-up is needed because the float remainder keeps the sign of
    # the dividend.
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    grid_z_mod = (grid_z + 1.0) % 4.0
    grid_z_mod = tl.where(grid_z_mod < 0, grid_z_mod + 4.0, grid_z_mod)
    grid_z = tl.where(grid_z_mod <= 2.0, grid_z_mod, 4.0 - grid_z_mod) - 1.0

    # Denormalize reflected coordinates to voxel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Clip to the valid voxel range so that samples in the outer half-voxel
    # (align_corners=False) resolve to the border voxel instead of blending
    # with an out-of-bounds corner.
    x = tl.minimum(tl.maximum(x, 0.0), W_in - 1.0)
    y = tl.minimum(tl.maximum(y, 0.0), H_in - 1.0)
    z = tl.minimum(tl.maximum(z, 0.0), D_in - 1.0)

    # Lower and upper corner coordinates of the enclosing 2x2x2 cell
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1

    # Interpolation weights (fractional distance from the lower corner)
    wx = x - x0
    wy = y - y0
    wz = z - z0

    # Convert to integers
    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    z0_int = tl.cast(z0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y1_int = tl.cast(y1, tl.int32)
    z1_int = tl.cast(z1, tl.int32)

    # Boundary checks. Still required after the clip: when a coordinate sits
    # exactly on the last voxel the upper corner equals `size`; it is masked
    # out here and carries weight 0.
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)
    z0_in = (z0_int >= 0) & (z0_int < D_in)
    z1_in = (z1_int >= 0) & (z1_int < D_in)

    # Part of the load mask shared by all eight corners
    lane_mask = tile_mask & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan

    # Offsets shared between corners: two depth planes and two rows
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
    plane_z0 = input_base + z0_int * H_in * W_in
    plane_z1 = input_base + z1_int * H_in * W_in
    row_y0 = y0_int * W_in
    row_y1 = y1_int * W_in

    # Load 8 corners (vectorized). Naming is p<z><y><x>.
    p000 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x0_int),
        mask=lane_mask & x0_in & y0_in & z0_in,
        other=0.0,
    ).to(tl.float32)
    p001 = tl.load(
        ptr_input + (plane_z0 + row_y0 + x1_int),
        mask=lane_mask & x1_in & y0_in & z0_in,
        other=0.0,
    ).to(tl.float32)
    p010 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x0_int),
        mask=lane_mask & x0_in & y1_in & z0_in,
        other=0.0,
    ).to(tl.float32)
    p011 = tl.load(
        ptr_input + (plane_z0 + row_y1 + x1_int),
        mask=lane_mask & x1_in & y1_in & z0_in,
        other=0.0,
    ).to(tl.float32)
    p100 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x0_int),
        mask=lane_mask & x0_in & y0_in & z1_in,
        other=0.0,
    ).to(tl.float32)
    p101 = tl.load(
        ptr_input + (plane_z1 + row_y0 + x1_int),
        mask=lane_mask & x1_in & y0_in & z1_in,
        other=0.0,
    ).to(tl.float32)
    p110 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x0_int),
        mask=lane_mask & x0_in & y1_in & z1_in,
        other=0.0,
    ).to(tl.float32)
    p111 = tl.load(
        ptr_input + (plane_z1 + row_y1 + x1_int),
        mask=lane_mask & x1_in & y1_in & z1_in,
        other=0.0,
    ).to(tl.float32)

    # 3-stage trilinear interpolation
    # Stage 1: Interpolate along x
    c000 = p000 * (1.0 - wx) + p001 * wx  # z=0, y=0
    c001 = p010 * (1.0 - wx) + p011 * wx  # z=0, y=1
    c010 = p100 * (1.0 - wx) + p101 * wx  # z=1, y=0
    c011 = p110 * (1.0 - wx) + p111 * wx  # z=1, y=1

    # Stage 2: Interpolate along y
    c00 = c000 * (1.0 - wy) + c001 * wy  # z=0
    c01 = c010 * (1.0 - wy) + c011 * wy  # z=1

    # Stage 3: Interpolate along z (final)
    vals = c00 * (1.0 - wz) + c01 * wz

    # Store to output
    output_base = n * C * D_out * H_out * W_out + c * D_out * H_out * W_out
    output_offsets = output_base + spatial_offsets

    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)

3774 

3775 

3776# ============================================================================ 

3777# Tiled Kernels for Medium-to-Large Inputs (Multi-dimensional Blocking) 

3778# ============================================================================ 

3779 

3780 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: nearest-neighbour interpolation, zeros padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Samples whose
    rounded pixel index falls outside the input, or whose grid coordinate is
    NaN, produce 0.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c, so it directly indexes the (n, c) plane; only the
    # batch index is needed separately (for the grid, which has no channel dim).
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NaN != NaN. Replace NaNs by a finite sentinel so the arithmetic below is
    # well defined; the NaN flags are re-applied in the load mask.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Round half to even ("banker's rounding")
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor

    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)

    # Exact .5 ties go to whichever neighbour is even
    x_tie = tl.where(x_floor_even, x_floor, x_floor + 1)
    y_tie = tl.where(y_floor_even, y_floor, y_floor + 1)

    x_idx = tl.cast(tl.where(x_frac == 0.5, x_tie, x_round), tl.int32)
    y_idx = tl.cast(tl.where(y_frac == 0.5, y_tie, y_round), tl.int32)

    # Zeros padding: only in-range, non-NaN samples are read
    x_in_bounds = (x_idx >= 0) & (x_idx < W_in)
    y_in_bounds = (y_idx >= 0) & (y_idx < H_in)
    valid_mask = tile_mask & x_in_bounds & y_in_bounds & ~grid_x_nan & ~grid_y_nan

    input_base = pid_nc * H_in * W_in
    input_offsets = input_base + y_idx * W_in + x_idx
    vals = tl.load(ptr_input + input_offsets, mask=valid_mask, other=0.0)

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

3924 

3925 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: bilinear interpolation, zeros padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Each of the four
    neighbouring input pixels contributes 0 when it lies outside the input or
    when the grid coordinate is NaN.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c indexes the (n, c) plane; the grid only needs n.
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NaN != NaN. Replace NaNs by a finite sentinel so the weights stay finite;
    # the NaN flags are re-applied in the load masks.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 2x2 neighbourhood, kept in float for the weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)

    x0_int = tl.cast(x0, tl.int32)
    x1_int = tl.cast(x0 + 1, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    y1_int = tl.cast(y0 + 1, tl.int32)

    # Per-axis in-range flags; combined per corner below
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)

    # Fractional distance from the top-left corner
    wx = x - x0
    wy = y - y0

    # Part of the mask shared by all four corner loads
    sample_ok = tile_mask & ~grid_x_nan & ~grid_y_nan

    input_base = pid_nc * H_in * W_in
    row0 = input_base + y0_int * W_in
    row1 = input_base + y1_int * W_in

    p00 = tl.load(ptr_input + (row0 + x0_int), mask=sample_ok & x0_in & y0_in, other=0.0)
    p01 = tl.load(ptr_input + (row0 + x1_int), mask=sample_ok & x1_in & y0_in, other=0.0)
    p10 = tl.load(ptr_input + (row1 + x0_int), mask=sample_ok & x0_in & y1_in, other=0.0)
    p11 = tl.load(ptr_input + (row1 + x1_int), mask=sample_ok & x1_in & y1_in, other=0.0)

    # Interpolate along x on both rows, then along y
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

4090 

4091 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: nearest-neighbour interpolation, border padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). The rounded
    pixel index is clamped to [0, W_in - 1] x [0, H_in - 1], so every sample
    reads an input pixel. NaN grid coordinates are treated as -1.0.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c indexes the (n, c) plane; the grid only needs n.
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NaN != NaN: map NaN coordinates to the sentinel -1.0
    grid_x = tl.where(grid_x != grid_x, -1.0, grid_x)
    grid_y = tl.where(grid_y != grid_y, -1.0, grid_y)

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Round half to even ("banker's rounding")
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor

    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)

    # Exact .5 ties go to whichever neighbour is even
    x_tie = tl.where(x_floor_even, x_floor, x_floor + 1)
    y_tie = tl.where(y_floor_even, y_floor, y_floor + 1)

    x_idx_unclamped = tl.cast(tl.where(x_frac == 0.5, x_tie, x_round), tl.int32)
    y_idx_unclamped = tl.cast(tl.where(y_frac == 0.5, y_tie, y_round), tl.int32)

    # Border padding: pull out-of-range indices back onto the nearest edge
    x_idx = tl.maximum(0, tl.minimum(x_idx_unclamped, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx_unclamped, H_in - 1))

    # After clamping every index is in range, so only the tile mask is needed
    input_base = pid_nc * H_in * W_in
    input_offsets = input_base + y_idx * W_in + x_idx
    vals = tl.load(ptr_input + input_offsets, mask=tile_mask, other=0.0)

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

4211 

4212 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: bilinear interpolation, border padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). The four corner
    indices are clamped to [0, W_in - 1] x [0, H_in - 1], while the weights
    come from the unclamped coordinate. NaN grid coordinates are treated as
    -1.0.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c indexes the (n, c) plane; the grid only needs n.
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NaN != NaN: map NaN coordinates to the sentinel -1.0
    grid_x = tl.where(grid_x != grid_x, -1.0, grid_x)
    grid_y = tl.where(grid_y != grid_y, -1.0, grid_y)

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 2x2 neighbourhood, kept in float for the weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)

    # Fractional distance from the (unclamped) top-left corner
    wx = x - x0
    wy = y - y0

    # Border padding: clamp all four corner indices onto the input
    x0_int = tl.maximum(0, tl.minimum(tl.cast(x0, tl.int32), W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(tl.cast(x0 + 1, tl.int32), W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(tl.cast(y0, tl.int32), H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(tl.cast(y0 + 1, tl.int32), H_in - 1))

    # After clamping every index is in range, so only the tile mask is needed
    input_base = pid_nc * H_in * W_in
    row0 = input_base + y0_int * W_in
    row1 = input_base + y1_int * W_in

    p00 = tl.load(ptr_input + (row0 + x0_int), mask=tile_mask, other=0.0)
    p01 = tl.load(ptr_input + (row0 + x1_int), mask=tile_mask, other=0.0)
    p10 = tl.load(ptr_input + (row1 + x0_int), mask=tile_mask, other=0.0)
    p11 = tl.load(ptr_input + (row1 + x1_int), mask=tile_mask, other=0.0)

    # Interpolate along x on both rows, then along y
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

4338 

4339 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: nearest-neighbour interpolation, reflection padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Grid
    coordinates are folded into [-1, 1] with a period-4 triangle wave before
    denormalization; the rounded index is then clamped to the input.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c indexes the (n, c) plane; the grid only needs n.
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NOTE(review): NaN grid values are not special-cased here, unlike the
    # zeros/border tiled kernels in this file - confirm this is intended.

    # Triangle-wave reflection in grid space: shift to [0, 4), fold the upper
    # half back, then shift back to [-1, 1]. The `< 0` fix-up is needed because
    # the remainder can be negative for negative inputs.
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (x + 1.0) * (W_in - 1) / 2.0
        y = (y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (x + 1.0) * W_in / 2.0 - 0.5
        y = (y + 1.0) * H_in / 2.0 - 0.5

    # Round half to even ("banker's rounding")
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor

    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)

    # Exact .5 ties go to whichever neighbour is even
    x_tie = tl.where(x_floor_even, x_floor, x_floor + 1)
    y_tie = tl.where(y_floor_even, y_floor, y_floor + 1)

    x_idx_unclamped = tl.cast(tl.where(x_frac == 0.5, x_tie, x_round), tl.int32)
    y_idx_unclamped = tl.cast(tl.where(y_frac == 0.5, y_tie, y_round), tl.int32)

    # Safety clamp so the load below always addresses the input plane
    x_idx = tl.maximum(0, tl.minimum(x_idx_unclamped, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx_unclamped, H_in - 1))

    input_base = pid_nc * H_in * W_in
    input_offsets = input_base + y_idx * W_in + x_idx
    vals = tl.load(ptr_input + input_offsets, mask=tile_mask, other=0.0)

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

4469 

4470 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 2D grid sample: bilinear interpolation, reflection padding.

    Each program handles one (batch, channel) plane (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Grid
    coordinates are folded into [-1, 1] with a period-4 triangle wave before
    denormalization; the four corner indices are then clamped to the input.

    All offsets are computed as for densely packed row-major tensors:
    input (N, C, H_in, W_in), grid (N, H_out, W_out, 2) with x stored before
    y, output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (unused in the body; part of the autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: If true, grid -1/1 map to the centres of the corner
            pixels; otherwise to their outer edges
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Promote the (n, c) program id to 64-bit: the per-plane base offsets below
    # multiply up to four sizes together and would wrap in int32 for tensors
    # with more than 2**31 - 1 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)

    # pid_nc == n * C + c indexes the (n, c) plane; the grid only needs n.
    n = pid_nc // C

    # Locate this program's tile inside the output plane
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Tiles on the bottom/right edge may extend past the output
    tile_mask = (h_offsets < H_out)[:, None] & (w_offsets < W_out)[None, :]

    # (BLOCK_H, BLOCK_W) linear position of every pixel inside its plane
    pix_offsets = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Grid coordinates for the tile; x at even slots, y at odd slots
    grid_base = n * H_out * W_out * 2
    grid_xy_offsets = grid_base + pix_offsets * 2
    grid_x = tl.load(ptr_grid + grid_xy_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_xy_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # NOTE(review): NaN grid values are not special-cased here, unlike the
    # zeros/border tiled kernels in this file - confirm this is intended.

    # Triangle-wave reflection in grid space: shift to [0, 4), fold the upper
    # half back, then shift back to [-1, 1]. The `< 0` fix-up is needed because
    # the remainder can be negative for negative inputs.
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    # Denormalize from [-1, 1] to pixel space
    if align_corners:
        x = (x + 1.0) * (W_in - 1) / 2.0
        y = (y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (x + 1.0) * W_in / 2.0 - 0.5
        y = (y + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 2x2 neighbourhood, kept in float for the weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)

    # Fractional distance from the (unclamped) top-left corner
    wx = x - x0
    wy = y - y0

    # Safety clamp so the four loads below always address the input plane
    x0_int = tl.maximum(0, tl.minimum(tl.cast(x0, tl.int32), W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(tl.cast(x0 + 1, tl.int32), W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(tl.cast(y0, tl.int32), H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(tl.cast(y0 + 1, tl.int32), H_in - 1))

    input_base = pid_nc * H_in * W_in
    row0 = input_base + y0_int * W_in
    row1 = input_base + y1_int * W_in

    p00 = tl.load(ptr_input + (row0 + x0_int), mask=tile_mask, other=0.0)
    p01 = tl.load(ptr_input + (row0 + x1_int), mask=tile_mask, other=0.0)
    p10 = tl.load(ptr_input + (row1 + x0_int), mask=tile_mask, other=0.0)
    p11 = tl.load(ptr_input + (row1 + x1_int), mask=tile_mask, other=0.0)

    # Interpolate along x on both rows, then along y
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    output_base = pid_nc * H_out * W_out
    tl.store(ptr_output + (output_base + pix_offsets), vals, mask=tile_mask)

4605 

4606 

4607# ============================================================================ 

4608# Main Dispatch Function 

4609# ============================================================================ 

4610 

4611 

4612def grid_sample( 

4613 input: torch.Tensor, 

4614 grid: torch.Tensor, 

4615 mode: str = "bilinear", 

4616 padding_mode: str = "zeros", 

4617 align_corners: bool = False, 

4618) -> torch.Tensor: 

4619 """ 

4620 Grid sample operation with spatial interpolation. 

4621 

4622 Computes the output using input values and pixel locations from grid. 

4623 Grid specifies sampling pixel locations normalized by input spatial dimensions. 

4624 

4625 Args: 

4626 input: Input tensor of shape (N, C, H_in, W_in) or (N, C, D_in, H_in, W_in) 

4627 grid: Grid tensor of shape (N, H_out, W_out, 2) or (N, D_out, H_out, W_out, 3) 

4628 Values should be in range [-1, 1], normalized by input spatial dimensions 

4629 mode: Interpolation mode - 'bilinear', 'nearest', or 'bicubic' (4D only) 

4630 padding_mode: Padding mode for out-of-bound grid locations 

4631 - 'zeros': use 0 for out-of-bound locations 

4632 - 'border': use border values 

4633 - 'reflection': reflect by border 

4634 align_corners: If True, extrema (-1, 1) refer to center points of corner pixels 

4635 If False, extrema refer to corner points of corner pixels 

4636 

4637 Returns: 

4638 Output tensor of shape (N, C, H_out, W_out) or (N, C, D_out, H_out, W_out) 

4639 

4640 Examples: 

4641 >>> input = torch.randn(1, 3, 32, 32).cuda() 

4642 >>> grid = torch.randn(1, 64, 64, 2).cuda() 

4643 >>> output = grid_sample(input, grid, mode='bilinear') 

4644 >>> print(output.shape) 

4645 torch.Size([1, 3, 64, 64]) 

4646 """ 

4647 # Validate inputs 

4648 _validate_grid_sample_input(input, grid, mode, padding_mode) 

4649 

4650 # Get tensor properties 

4651 dtype = input.dtype 

4652 device = input.device 

4653 

4654 is_3d = input.dim() == 5 

4655 

4656 # Handle 4D inputs (N, C, H_in, W_in) 

4657 if not is_3d: 

4658 N, C, H_in, W_in = input.shape 

4659 _, H_out, W_out, _ = grid.shape 

4660 

4661 # Allocate output tensor 

4662 output = torch.empty((N, C, H_out, W_out), dtype=dtype, device=device) 

4663 

4664 # Adaptive kernel selection based on output size 

4665 # Use tiled kernels for medium-to-large outputs (>= 32x32 = 1024 pixels) 

4666 # Use original per-pixel kernels for small outputs (< 32x32) 

4667 output_pixels = H_out * W_out 

4668 USE_TILED_THRESHOLD = 1024 

4669 

4670 use_tiled = output_pixels >= USE_TILED_THRESHOLD 

4671 

4672 # Select kernel based on mode, padding mode, and output size 

4673 if mode == "nearest": 

4674 if use_tiled: 

4675 # Use tiled kernels for medium-to-large outputs 

4676 if padding_mode == "zeros": 

4677 kernel = grid_sample_2d_nearest_zeros_tiled_kernel 

4678 elif padding_mode == "border": 

4679 kernel = grid_sample_2d_nearest_border_tiled_kernel 

4680 else: # reflection 

4681 kernel = grid_sample_2d_nearest_reflection_tiled_kernel 

4682 else: 

4683 # Use original kernels for small outputs 

4684 if padding_mode == "zeros": 

4685 kernel = grid_sample_2d_nearest_zeros_kernel 

4686 elif padding_mode == "border": 

4687 kernel = grid_sample_2d_nearest_border_kernel 

4688 else: # reflection 

4689 kernel = grid_sample_2d_nearest_reflection_kernel 

4690 elif mode == "bilinear": 

4691 if use_tiled: 

4692 # Use tiled kernels for medium-to-large outputs 

4693 if padding_mode == "zeros": 

4694 kernel = grid_sample_2d_bilinear_zeros_tiled_kernel 

4695 elif padding_mode == "border": 

4696 kernel = grid_sample_2d_bilinear_border_tiled_kernel 

4697 else: # reflection 

4698 kernel = grid_sample_2d_bilinear_reflection_tiled_kernel 

4699 else: 

4700 # Use original kernels for small outputs 

4701 if padding_mode == "zeros": 

4702 kernel = grid_sample_2d_bilinear_zeros_kernel 

4703 elif padding_mode == "border": 

4704 kernel = grid_sample_2d_bilinear_border_kernel 

4705 else: # reflection 

4706 kernel = grid_sample_2d_bilinear_reflection_kernel 

4707 elif mode == "bicubic": 

4708 # Bicubic is already competitive, use original kernels for all sizes 

4709 if padding_mode == "zeros": 

4710 kernel = grid_sample_2d_bicubic_zeros_kernel 

4711 elif padding_mode == "border": 

4712 kernel = grid_sample_2d_bicubic_border_kernel 

4713 else: # reflection 

4714 kernel = grid_sample_2d_bicubic_reflection_kernel 

4715 else: # unsupported mode 

4716 logger.info(f"grid_sample mode '{mode}' not supported") 

4717 raise NotImplementedError 

4718 

4719 # Launch kernel with appropriate grid size 

4720 # For very large outputs (> 512x512), fall back to original kernels to avoid grid size issues 

4721 output_pixels = H_out * W_out 

4722 MAX_TILED_PIXELS = 512 * 512 

4723 

4724 if ( 

4725 use_tiled 

4726 and mode in ["nearest", "bilinear"] 

4727 and output_pixels <= MAX_TILED_PIXELS 

4728 ): 

4729 # Tiled kernels use 2D grid with adaptive tile size selection 

4730 # Goal: Create ~100-500 blocks total for good GPU utilization 

4731 target_total_blocks = ( 

4732 300 # Target: aim for ~300 blocks across all batches/channels 

4733 ) 

4734 min_blocks_per_nc = 50 # Minimum: ensure enough parallelism 

4735 max_blocks_per_nc = 1000 # Maximum: avoid too many blocks 

4736 

4737 # Estimate blocks per (batch, channel) pair 

4738 nc_pairs = N * C 

4739 

4740 # Target blocks per (N, C) pair 

4741 target_blocks_per_nc = max( 

4742 min_blocks_per_nc, 

4743 min(max_blocks_per_nc, target_total_blocks // max(1, nc_pairs)), 

4744 ) 

4745 

4746 # Calculate tile dimensions to achieve target block count 

4747 # Start with square tiles 

4748 target_tile_pixels = output_pixels // target_blocks_per_nc 

4749 target_tile_side = int(max(4, min(128, int(target_tile_pixels**0.5)))) 

4750 

4751 # Snap to power-of-2 for better alignment 

4752 if target_tile_side >= 64: 

4753 block_h = block_w = 64 if target_tile_side < 96 else 128 

4754 elif target_tile_side >= 16: 

4755 block_h = block_w = 32 

4756 elif target_tile_side >= 8: 

4757 block_h = block_w = 16 

4758 else: 

4759 block_h = block_w = 8 

4760 

4761 # For bilinear, use smaller tiles due to higher memory footprint 

4762 if mode == "bilinear": 

4763 block_h = max(4, block_h // 2) 

4764 block_w = max(4, block_w // 2) 

4765 

4766 # Calculate actual grid size 

4767 num_h_blocks = (H_out + block_h - 1) // block_h 

4768 num_w_blocks = (W_out + block_w - 1) // block_w 

4769 grid_size = (N * C, num_h_blocks * num_w_blocks) 

4770 else: 

4771 # Original kernels use 1D grid (for small outputs or very large outputs) 

4772 grid_size = (N * C * H_out * W_out,) 

4773 

4774 kernel[grid_size]( 

4775 output, 

4776 input, 

4777 grid, 

4778 N, 

4779 C, 

4780 H_in, 

4781 W_in, 

4782 H_out, 

4783 W_out, 

4784 align_corners, 

4785 ) 

4786 

4787 return output 

4788 

4789 # Handle 5D inputs (N, C, D_in, H_in, W_in) 

4790 else: # is_3d == True 

4791 N, C, D_in, H_in, W_in = input.shape 

4792 _, D_out, H_out, W_out, _ = grid.shape 

4793 

4794 # Allocate output tensor 

4795 output = torch.empty((N, C, D_out, H_out, W_out), dtype=dtype, device=device) 

4796 

4797 # Adaptive kernel selection based on output size 

4798 # Use tiled kernels for medium-to-large outputs (>= 16x16x16 = 4096 voxels) 

4799 # Increased from 512 to avoid tiled kernel overhead on small outputs 

4800 output_voxels = D_out * H_out * W_out 

4801 USE_TILED_THRESHOLD_3D = 4096 # 16x16x16 

4802 

4803 use_tiled = output_voxels >= USE_TILED_THRESHOLD_3D 

4804 

4805 # Select kernel based on mode, padding mode, and output size 

4806 if mode == "nearest": 

4807 if use_tiled: 

4808 # Use tiled kernels for medium-to-large outputs 

4809 if padding_mode == "zeros": 

4810 kernel = grid_sample_3d_nearest_zeros_tiled_kernel 

4811 elif padding_mode == "border": 

4812 kernel = grid_sample_3d_nearest_border_tiled_kernel 

4813 else: # reflection 

4814 kernel = grid_sample_3d_nearest_reflection_tiled_kernel 

4815 else: 

4816 # Use original kernels for small outputs 

4817 if padding_mode == "zeros": 

4818 kernel = grid_sample_3d_nearest_zeros_kernel 

4819 elif padding_mode == "border": 

4820 kernel = grid_sample_3d_nearest_border_kernel 

4821 else: # reflection 

4822 kernel = grid_sample_3d_nearest_reflection_kernel 

4823 elif mode == "bilinear": # For 5D, bilinear means trilinear 

4824 if use_tiled: 

4825 # Use tiled kernels for medium-to-large outputs 

4826 if padding_mode == "zeros": 

4827 kernel = grid_sample_3d_trilinear_zeros_tiled_kernel 

4828 elif padding_mode == "border": 

4829 kernel = grid_sample_3d_trilinear_border_tiled_kernel 

4830 else: # reflection 

4831 kernel = grid_sample_3d_trilinear_reflection_tiled_kernel 

4832 else: 

4833 # Use original kernels for small outputs 

4834 if padding_mode == "zeros": 

4835 kernel = grid_sample_3d_trilinear_zeros_kernel 

4836 elif padding_mode == "border": 

4837 kernel = grid_sample_3d_trilinear_border_kernel 

4838 else: # reflection 

4839 kernel = grid_sample_3d_trilinear_reflection_kernel 

4840 else: # unsupported mode for 5D 

4841 logger.info(f"grid_sample mode '{mode}' not supported for 5D input") 

4842 raise NotImplementedError("Unsupported mode for 5D input") 

4843 

4844 # Launch kernel with appropriate grid size 

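4845 # For very large outputs (> 128x128x128), launch with the 1D grid instead. NOTE(review): `kernel` is not re-selected here, so a tiled kernel chosen via `use_tiled` above is still the one launched - confirm whether it should be swapped for the original kernel 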

4846 if ( 

4847 use_tiled 

4848 and mode in ["nearest", "bilinear"] 

4849 and output_voxels <= MAX_TILED_VOXELS 

4850 ): 

4851 # Tiled kernels use 2D grid with adaptive tile size selection 

4852 # Goal: Create optimal blocks for good GPU utilization (more granular for medium outputs) 

4853 nc_pairs = N * C 

4854 

4855 # More granular targeting to fix 16³ and 32³ performance 

4856 # Key: Need MORE blocks for 16³ and 32³, not fewer 

4857 if output_voxels < VOXEL_THRESHOLD_SMALL: # 16³ - 20³ 

4858 target_total_blocks = TARGET_BLOCKS_SMALL 

4859 min_blocks_per_nc = MIN_BLOCKS_NC_SMALL 

4860 max_blocks_per_nc = MAX_BLOCKS_NC_SMALL 

4861 elif output_voxels < VOXEL_THRESHOLD_MEDIUM: # 20³ - 32³ 

4862 target_total_blocks = TARGET_BLOCKS_MEDIUM 

4863 min_blocks_per_nc = MIN_BLOCKS_NC_MEDIUM 

4864 max_blocks_per_nc = MAX_BLOCKS_NC_MEDIUM 

4865 elif output_voxels < VOXEL_THRESHOLD_LARGE: # 32³ - 50³ 

4866 target_total_blocks = TARGET_BLOCKS_LARGE 

4867 min_blocks_per_nc = MIN_BLOCKS_NC_LARGE 

4868 max_blocks_per_nc = MAX_BLOCKS_NC_LARGE 

4869 elif output_voxels < VOXEL_THRESHOLD_VERY_LARGE: # 50³ - 64³ 

4870 target_total_blocks = TARGET_BLOCKS_VERY_LARGE 

4871 min_blocks_per_nc = MIN_BLOCKS_NC_VERY_LARGE 

4872 max_blocks_per_nc = MAX_BLOCKS_NC_VERY_LARGE 

4873 else: # Large outputs (>= 64³) 

4874 target_total_blocks = TARGET_BLOCKS_EXTRA_LARGE 

4875 min_blocks_per_nc = MIN_BLOCKS_NC_EXTRA_LARGE 

4876 max_blocks_per_nc = MAX_BLOCKS_NC_EXTRA_LARGE 

4877 

4878 # Channel-aware tiling: reduce targets for high channel counts to avoid too many blocks 

4879 # When C is large, we create too many blocks with the current formula 

4880 # Solution: Reduce target_total_blocks proportionally 

4881 if C > CHANNEL_COUNT_THRESHOLD: 

4882 # Scale down targets more aggressively to avoid excessive blocks when C > threshold 

4883 # Use sqrt scaling for better balance 

4884 channel_scale = ( 

4885 CHANNEL_COUNT_THRESHOLD / C 

4886 ) ** CHANNEL_SCALING_EXPONENT 

4887 target_total_blocks = max( 

4888 MIN_TARGET_TOTAL_BLOCKS, int(target_total_blocks * channel_scale) 

4889 ) 

4890 min_blocks_per_nc = max( 

4891 MIN_BLOCKS_PER_NC, int(min_blocks_per_nc * channel_scale) 

4892 ) 

4893 # Keep max_blocks_per_nc unchanged to prevent excessive blocks 

4894 

4895 # Target blocks per (N, C) pair 

4896 target_blocks_per_nc = max( 

4897 min_blocks_per_nc, 

4898 min(max_blocks_per_nc, target_total_blocks // max(1, nc_pairs)), 

4899 ) 

4900 

4901 # Calculate tile dimensions to achieve target block count 

4902 # For 3D, start with cubic tiles 

4903 total_voxels = D_out * H_out * W_out 

4904 target_tile_voxels = total_voxels // target_blocks_per_nc 

4905 target_tile_side = int( 

4906 max( 

4907 MIN_TILE_SIDE, 

4908 min(MAX_TILE_SIDE, int(target_tile_voxels ** (1.0 / 3.0))), 

4909 ) 

4910 ) 

4911 

4912 # Snap to power-of-2 for better alignment 

4913 # Minimum tile size is 4x4x4 for small outputs, 8x8x8 for large 

4914 if target_tile_side >= LARGE_TILE_THRESHOLD: 

4915 block_d = block_h = block_w = ( 

4916 LARGE_TILE_THRESHOLD 

4917 if target_tile_side < VERY_LARGE_TILE_THRESHOLD 

4918 else MAX_TILE_SIDE 

4919 ) 

4920 elif target_tile_side >= MEDIUM_TILE_THRESHOLD: 

4921 block_d = block_h = block_w = MEDIUM_TILE_THRESHOLD 

4922 elif target_tile_side >= SMALL_TILE_THRESHOLD: 

4923 block_d = block_h = block_w = SMALL_TILE_THRESHOLD 

4924 else: 

4925 block_d = block_h = block_w = MIN_TILE_SIDE 

4926 

4927 # For trilinear, use smaller tiles due to higher memory footprint (8x loads) 

4928 if mode == "bilinear": # actually trilinear in 5D 

4929 block_d = max(MIN_BLOCK_DIMENSION, block_d // 2) 

4930 block_h = max(MIN_BLOCK_DIMENSION, block_h // 2) 

4931 block_w = max(MIN_BLOCK_DIMENSION, block_w // 2) 

4932 

4933 # Calculate actual grid size 

4934 num_d_blocks = (D_out + block_d - 1) // block_d 

4935 num_h_blocks = (H_out + block_h - 1) // block_h 

4936 num_w_blocks = (W_out + block_w - 1) // block_w 

4937 grid_size = (N * C, num_d_blocks * num_h_blocks * num_w_blocks) 

4938 else: 

4939 # Original kernels use 1D grid (for small outputs or very large outputs) 

4940 grid_size = (N * C * D_out * H_out * W_out,) 

4941 

4942 # Kernel launch 

4943 kernel[grid_size]( 

4944 output, 

4945 input, 

4946 grid, 

4947 N, 

4948 C, 

4949 D_in, 

4950 H_in, 

4951 W_in, 

4952 D_out, 

4953 H_out, 

4954 W_out, 

4955 align_corners, 

4956 ) 

4957 

4958 return output