Coverage for src/flag_gems/ops/grid_sample.py: 13%
2179 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""
2Grid sample operator implementation for FlagGems.
4This module provides the grid sampling operation with various interpolation modes.
5Grid sample computes the output using input values and pixel locations from grid.
6"""
8import logging
10import torch
11import triton
12import triton.language as tl
14from flag_gems import runtime
15from flag_gems.utils import libentry
17logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------
# Grid sample tuning constants
# ----------------------------------------------------------------------------

# Upper limit on the output voxel count for which the tiled kernel is used.
MAX_TILED_VOXELS = 128**3  # 2,097,152 voxels (~2M)

# Output-size buckets (in voxels) for adaptive block targeting.
# Reference cube sizes: 16^3 = 4096, 20^3 = 8000, 32^3 = 32768,
# 50^3 = 125000, 64^3 = 262144.
VOXEL_THRESHOLD_SMALL = 8 * 1024  # small outputs (about 16^3 - 20^3)
VOXEL_THRESHOLD_MEDIUM = 32 * 1024  # medium outputs (about 20^3 - 32^3)
VOXEL_THRESHOLD_LARGE = 128 * 1024  # large outputs (about 32^3 - 50^3)
VOXEL_THRESHOLD_VERY_LARGE = 256 * 1024  # very large outputs (about 50^3 - 64^3)

# Block targets per output-size bucket. Each bucket has a total block target
# plus a lower and an upper bound on blocks per (N, C) pair.

# Small outputs: higher block count for better utilization.
TARGET_BLOCKS_SMALL = 512
MIN_BLOCKS_NC_SMALL = 64
MAX_BLOCKS_NC_SMALL = 1024

# Medium outputs: even higher block count.
TARGET_BLOCKS_MEDIUM = 768
MIN_BLOCKS_NC_MEDIUM = 128
MAX_BLOCKS_NC_MEDIUM = 2048

# Large outputs: maximum block targeting.
TARGET_BLOCKS_LARGE = 1024
MIN_BLOCKS_NC_LARGE = 128
MAX_BLOCKS_NC_LARGE = 2048

# Very large outputs: reduced block count.
TARGET_BLOCKS_VERY_LARGE = 512
MIN_BLOCKS_NC_VERY_LARGE = 64
MAX_BLOCKS_NC_VERY_LARGE = 1024

# Extra large outputs (>= 64^3): conservative block targeting.
TARGET_BLOCKS_EXTRA_LARGE = 300
MIN_BLOCKS_NC_EXTRA_LARGE = 50
MAX_BLOCKS_NC_EXTRA_LARGE = 1000

# Channel scaling: block targets are scaled down for high channel counts.
CHANNEL_COUNT_THRESHOLD = 32  # channel count above which targets are scaled down
CHANNEL_SCALING_EXPONENT = 0.7  # exponent of the channel scaling factor
MIN_TARGET_TOTAL_BLOCKS = 128  # floor for the total block target after scaling
MIN_BLOCKS_PER_NC = 16  # floor for blocks per (N, C) pair after scaling

# Tile side lengths for 3D outputs.
MIN_TILE_SIDE = 4  # smallest tile side length
MAX_TILE_SIDE = 64  # largest tile side length
LARGE_TILE_THRESHOLD = 32  # threshold for using 32- or 64-sized tiles
VERY_LARGE_TILE_THRESHOLD = 48  # threshold for using 64 instead of 32
MEDIUM_TILE_THRESHOLD = 16  # threshold for using 16-sized tiles
SMALL_TILE_THRESHOLD = 8  # threshold for using 8-sized tiles

# Trilinear reduction.
MIN_BLOCK_DIMENSION = 2  # smallest block dimension after halving for trilinear
def _validate_grid_sample_input(input, grid, mode, padding_mode):
    """
    Check that the arguments of grid_sample are mutually consistent.

    The checks run in a fixed order and the first failure is reported:
    input rank, grid rank, grid coordinate count, batch size, mode,
    bicubic/5D combination, padding mode.

    Args:
        input: Input tensor; only ``dim()`` and ``shape[0]`` are read
        grid: Grid tensor; only ``dim()``, ``shape[0]`` and ``shape[-1]`` are read
        mode: Interpolation mode ("bilinear", "nearest" or "bicubic")
        padding_mode: Padding mode ("zeros", "border" or "reflection")

    Returns:
        None when every check passes.

    Raises:
        ValueError: If inputs or parameters are invalid
    """
    # Input rank -> (expected grid layout text, coordinates per grid location).
    grid_spec = {
        4: ("(N, H_out, W_out, 2)", 2),
        5: ("(N, D_out, H_out, W_out, 3)", 3),
    }

    rank = input.dim()
    if rank not in grid_spec:
        raise ValueError("Input must be 4D or 5D")
    layout, n_coords = grid_spec[rank]

    # The grid must have the same rank as the input ...
    grid_rank = grid.dim()
    if grid_rank != rank:
        raise ValueError(
            f"For {rank}D input, grid must be {rank}D {layout}, "
            f"but got {grid_rank}D tensor"
        )

    # ... and carry one coordinate per spatial dimension in its last axis.
    last_dim = grid.shape[-1]
    if last_dim != n_coords:
        raise ValueError(
            f"For {rank}D input, grid must have {n_coords} coordinates "
            f"in last dimension, but got {last_dim}"
        )

    input_batch = input.shape[0]
    grid_batch = grid.shape[0]
    if input_batch != grid_batch:
        raise ValueError(
            "Input and grid must have same batch size, "
            f"but got {input_batch} and {grid_batch}"
        )

    # Kept as lists: their repr is part of the error messages below.
    valid_modes = ["bilinear", "nearest", "bicubic"]
    if mode not in valid_modes:
        raise ValueError(
            f"Invalid mode '{mode}'. Expected one of {valid_modes}, "
            "but note: bicubic only supports 4D input"
        )
    if rank == 5 and mode == "bicubic":
        raise ValueError("Bicubic interpolation only supports 4D input")

    valid_padding_modes = ["zeros", "border", "reflection"]
    if padding_mode not in valid_padding_modes:
        raise ValueError(
            f"Invalid padding_mode '{padding_mode}'. "
            f"Expected one of {valid_padding_modes}"
        )
140# ============================================================================
141# 2D Nearest Neighbor Kernels
142# ============================================================================
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Grid sample kernel for 2D nearest neighbor interpolation with zeros padding.

    Each program instance produces exactly one output element, i.e. one
    (n, c, h_out, w_out) combination decoded from the program id:
    1. Load the grid coordinates (normalized to [-1, 1]) for (n, h_out, w_out)
    2. Transform the coordinates to pixel space
    3. Round to the nearest pixel, ties going to the even index
    4. Load the input pixel (or 0 if out of bounds or the coordinate is NaN)
    5. Store to output

    All offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).

    NOTE(review): the program id is not bounds-checked, so the launch grid is
    presumably exactly N * C * H_out * W_out programs - confirm at the call site.

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (autotune key; not read in the kernel body)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Compile-time flag selecting the denormalization formula
        BLOCK_SIZE: Autotuned meta-parameter; not read in the kernel body
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Grid is addressed as (N, H_out, W_out, 2): x is stored first, then y.
    grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)

    # NaN is the only value that differs from itself. Remember which
    # coordinates were NaN, then replace them with the finite sentinel -2.0
    # so the arithmetic below stays well defined; the flags are folded into
    # the load mask so such locations read as 0.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Denormalize to pixel space. Only this step depends on align_corners;
    # the rounding and bounds check below are shared by both settings.
    if align_corners:
        # -1 and 1 map to the centers of the first and last pixel
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        # -1 and 1 map to the outer edges of the first and last pixel
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Banker's rounding (round half to even)
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor
    x_is_half = x_frac == 0.5
    y_is_half = y_frac == 0.5
    x_floor_int = tl.cast(x_floor, tl.int32)
    y_floor_int = tl.cast(y_floor, tl.int32)
    x_is_even = x_floor_int % 2 == 0
    y_is_even = y_floor_int % 2 == 0
    # Ordinary rounding for the non-tie case ...
    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)
    # ... and on an exact tie pick whichever neighbor is even.
    x_idx = tl.cast(
        tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round),
        tl.int32,
    )
    y_idx = tl.cast(
        tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round),
        tl.int32,
    )

    # Zeros padding: only indices inside [0, W_in) x [0, H_in) that came from
    # non-NaN coordinates are read; everything else yields 0.0.
    mask = (
        (x_idx >= 0)
        & (x_idx < W_in)
        & (y_idx >= 0)
        & (y_idx < H_in)
        & ~grid_x_nan
        & ~grid_y_nan
    )

    # Input is addressed as (N, C, H_in, W_in)
    input_offset = n * C * H_in * W_in + c * H_in * W_in + y_idx * W_in + x_idx
    val = tl.load(ptr_input + input_offset, mask=mask, other=0.0).to(tl.float32)

    # Output is addressed as (N, C, H_out, W_out)
    output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Grid sample kernel for 2D nearest neighbor interpolation with border padding.

    Each program instance produces one output element (n, c, h_out, w_out).
    The grid coordinate is denormalized, rounded half-to-even, and the
    resulting index is clamped to [0, W_in - 1] x [0, H_in - 1], so
    out-of-bound coordinates read the nearest border pixel. NaN coordinates
    are replaced by -1.0 before denormalization.

    All offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (autotune key; not read in the kernel body)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Compile-time flag selecting the denormalization formula
        BLOCK_SIZE: Autotuned meta-parameter; not read in the kernel body
    """
    # Decode the flat program id into (n, c, h_out, w_out).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Grid is addressed as (N, H_out, W_out, 2): x is stored first, then y.
    grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)

    # NaN is the only value that differs from itself; map it to -1.0.
    grid_x = tl.where(grid_x != grid_x, -1.0, grid_x)
    grid_y = tl.where(grid_y != grid_y, -1.0, grid_y)

    # Denormalize to pixel space. Only this step depends on align_corners;
    # rounding and clamping below are shared by both settings.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Banker's rounding (round half to even)
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor
    x_is_half = x_frac == 0.5
    y_is_half = y_frac == 0.5
    x_floor_int = tl.cast(x_floor, tl.int32)
    y_floor_int = tl.cast(y_floor, tl.int32)
    x_is_even = x_floor_int % 2 == 0
    y_is_even = y_floor_int % 2 == 0
    # Ordinary rounding for the non-tie case ...
    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)
    # ... and on an exact tie pick whichever neighbor is even.
    x_idx_unclamped = tl.cast(
        tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round),
        tl.int32,
    )
    y_idx_unclamped = tl.cast(
        tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round),
        tl.int32,
    )

    # Border padding: clamp to [0, W_in - 1] and [0, H_in - 1]
    x_idx = tl.maximum(0, tl.minimum(x_idx_unclamped, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx_unclamped, H_in - 1))

    # Load input pixel (no mask: the clamped index is always in bounds)
    input_offset = n * C * H_in * W_in + c * H_in * W_in + y_idx * W_in + x_idx
    val = tl.load(ptr_input + input_offset).to(tl.float32)

    # Output is addressed as (N, C, H_out, W_out)
    output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D nearest-neighbor grid sample with reflection padding.

    One program instance computes one output element (n, c, h_out, w_out).
    The normalized coordinate is folded back into [-1, 1] with a triangle
    wave of period 4, denormalized, rounded half-to-even and finally clamped
    to the valid index range. NaN coordinates are replaced by -1.0 first.

    Offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).
    N and BLOCK_SIZE are not read in the body (N is an autotune key).
    """
    pid = tl.program_id(0)
    out_plane = H_out * W_out
    nc = pid // out_plane
    hw = pid % out_plane

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) pair stored for this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # A value that is not equal to itself is NaN; substitute -1.0.
    gx = tl.where(gx != gx, -1.0, gx)
    gy = tl.where(gy != gy, -1.0, gy)

    # Reflect in grid space. Shifting by +1 moves [-1, 1] to [0, 2]; the
    # remainder modulo 4 keeps the sign of its operand, so negative
    # remainders are lifted by one period to land in [0, 4).
    phase_x = (gx + 1.0) % 4.0
    phase_x = tl.where(phase_x < 0, phase_x + 4.0, phase_x)
    # Rising edge on [0, 2], falling edge on (2, 4); then shift back.
    gx_refl = tl.where(phase_x <= 2.0, phase_x, 4.0 - phase_x) - 1.0

    phase_y = (gy + 1.0) % 4.0
    phase_y = tl.where(phase_y < 0, phase_y + 4.0, phase_y)
    gy_refl = tl.where(phase_y <= 2.0, phase_y, 4.0 - phase_y) - 1.0

    # Normalized -> pixel coordinates.
    if align_corners:
        x = (gx_refl + 1.0) * (W_in - 1) / 2.0
        y = (gy_refl + 1.0) * (H_in - 1) / 2.0
    else:
        x = (gx_refl + 1.0) * W_in / 2.0 - 0.5
        y = (gy_refl + 1.0) * H_in / 2.0 - 0.5

    # Round half to even, x axis.
    fx = tl.floor(x)
    rem_x = x - fx
    fx_even = tl.cast(fx, tl.int32) % 2 == 0
    plain_x = tl.where(rem_x < 0.5, fx, fx + 1)
    tie_x = tl.where(fx_even, fx, fx + 1)
    ix = tl.cast(tl.where(rem_x == 0.5, tie_x, plain_x), tl.int32)

    # Round half to even, y axis.
    fy = tl.floor(y)
    rem_y = y - fy
    fy_even = tl.cast(fy, tl.int32) % 2 == 0
    plain_y = tl.where(rem_y < 0.5, fy, fy + 1)
    tie_y = tl.where(fy_even, fy, fy + 1)
    iy = tl.cast(tl.where(rem_y == 0.5, tie_y, plain_y), tl.int32)

    # Safety clamp so the unmasked load below cannot leave the input plane.
    ix = tl.maximum(0, tl.minimum(ix, W_in - 1))
    iy = tl.maximum(0, tl.minimum(iy, H_in - 1))

    src_off = n * C * H_in * W_in + c * H_in * W_in + iy * W_in + ix
    val = tl.load(ptr_input + src_off).to(tl.float32)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, val)
507# ============================================================================
508# Bilinear Interpolation Kernels (4D)
509# ============================================================================
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with zeros padding.

    One program instance computes one output element (n, c, h_out, w_out):
    it reads the four pixels surrounding the sampling position and blends
    them along x, then along y. A corner outside the input contributes 0,
    and if either grid coordinate is NaN all four corners read as 0.

    Offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).
    N and BLOCK_SIZE are not read in the body (N is an autotune key).
    """
    pid = tl.program_id(0)
    out_plane = H_out * W_out
    nc = pid // out_plane
    hw = pid % out_plane

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) pair stored for this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Flag NaNs (a NaN is not equal to itself) and swap in the finite
    # sentinel -2.0 so the index arithmetic stays well defined.
    x_nan = gx != gx
    y_nan = gy != gy
    gx = tl.where(x_nan, -2.0, gx)
    gy = tl.where(y_nan, -2.0, gy)
    finite = ~(x_nan | y_nan)

    # Normalized -> pixel coordinates.
    if align_corners:
        # -1 and 1 map to the centers of the first and last pixel
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        # -1 and 1 map to the outer edges of the first and last pixel
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Surrounding integer lattice and fractional position inside the cell.
    x_lo = tl.floor(px)
    y_lo = tl.floor(py)
    x_hi = x_lo + 1
    y_hi = y_lo + 1
    frac_x = px - x_lo
    frac_y = py - y_lo

    ix0 = tl.cast(x_lo, tl.int32)
    iy0 = tl.cast(y_lo, tl.int32)
    ix1 = tl.cast(x_hi, tl.int32)
    iy1 = tl.cast(y_hi, tl.int32)

    # Per-axis bounds tests; a corner is readable only if both axes pass.
    col0_ok = (ix0 >= 0) & (ix0 < W_in)
    col1_ok = (ix1 >= 0) & (ix1 < W_in)
    row0_ok = (iy0 >= 0) & (iy0 < H_in)
    row1_ok = (iy1 >= 0) & (iy1 < H_in)

    # Start of the (n, c) plane in the input.
    plane_base = n * C * H_in * W_in + c * H_in * W_in
    off_00 = plane_base + iy0 * W_in + ix0
    off_01 = plane_base + iy0 * W_in + ix1
    off_10 = plane_base + iy1 * W_in + ix0
    off_11 = plane_base + iy1 * W_in + ix1

    # Masked loads implement the zeros padding.
    p00 = tl.load(
        ptr_input + off_00, mask=finite & col0_ok & row0_ok, other=0.0
    ).to(tl.float32)
    p01 = tl.load(
        ptr_input + off_01, mask=finite & col1_ok & row0_ok, other=0.0
    ).to(tl.float32)
    p10 = tl.load(
        ptr_input + off_10, mask=finite & col0_ok & row1_ok, other=0.0
    ).to(tl.float32)
    p11 = tl.load(
        ptr_input + off_11, mask=finite & col1_ok & row1_ok, other=0.0
    ).to(tl.float32)

    # Blend along x on both rows, then along y between the rows.
    keep_x = 1.0 - frac_x
    top = p00 * keep_x + p01 * frac_x
    bottom = p10 * keep_x + p11 * frac_x
    val = top * (1.0 - frac_y) + bottom * frac_y

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with border padding.

    One program instance computes one output element (n, c, h_out, w_out).
    The four corner indices are clamped to [0, W_in - 1] x [0, H_in - 1], so
    corners outside the input read the nearest border pixel. The blend
    weights come from the unclamped position. The result is 0.0 when either
    grid coordinate is NaN.

    Offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).
    N and BLOCK_SIZE are not read in the body (N is an autotune key).
    """
    pid = tl.program_id(0)
    out_plane = H_out * W_out
    nc = pid // out_plane
    hw = pid % out_plane

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) pair stored for this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Flag NaNs (a NaN is not equal to itself) and swap in the finite
    # sentinel -2.0 so the index arithmetic stays well defined.
    x_nan = gx != gx
    y_nan = gy != gy
    gx = tl.where(x_nan, -2.0, gx)
    gy = tl.where(y_nan, -2.0, gy)

    # Normalized -> pixel coordinates.
    if align_corners:
        px = (gx + 1.0) * (W_in - 1) / 2.0
        py = (gy + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx + 1.0) * W_in / 2.0 - 0.5
        py = (gy + 1.0) * H_in / 2.0 - 0.5

    # Surrounding integer lattice and fractional position inside the cell
    # (taken before clamping).
    x_lo = tl.floor(px)
    y_lo = tl.floor(py)
    x_hi = x_lo + 1
    y_hi = y_lo + 1
    frac_x = px - x_lo
    frac_y = py - y_lo

    # Integer corner indices, clamped onto the input (border padding).
    ix0 = tl.maximum(0, tl.minimum(tl.cast(x_lo, tl.int32), W_in - 1))
    ix1 = tl.maximum(0, tl.minimum(tl.cast(x_hi, tl.int32), W_in - 1))
    iy0 = tl.maximum(0, tl.minimum(tl.cast(y_lo, tl.int32), H_in - 1))
    iy1 = tl.maximum(0, tl.minimum(tl.cast(y_hi, tl.int32), H_in - 1))

    # Start of the (n, c) plane in the input.
    plane_base = n * C * H_in * W_in + c * H_in * W_in
    off_00 = plane_base + iy0 * W_in + ix0
    off_01 = plane_base + iy0 * W_in + ix1
    off_10 = plane_base + iy1 * W_in + ix0
    off_11 = plane_base + iy1 * W_in + ix1

    # Unmasked loads: the clamped indices are always inside the plane.
    p00 = tl.load(ptr_input + off_00)
    p01 = tl.load(ptr_input + off_01)
    p10 = tl.load(ptr_input + off_10)
    p11 = tl.load(ptr_input + off_11)

    # Blend along x on both rows, then along y between the rows.
    keep_x = 1.0 - frac_x
    top = p00 * keep_x + p01 * frac_x
    bottom = p10 * keep_x + p11 * frac_x
    blended = top * (1.0 - frac_y) + bottom * frac_y

    # A NaN coordinate produces 0.0 instead of the blended value.
    val = tl.where(x_nan | y_nan, 0.0, blended)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    2D bilinear grid sample with reflection padding.

    One program instance computes one output element (n, c, h_out, w_out).
    The normalized coordinate is folded back into [-1, 1] with a triangle
    wave of period 4 before denormalization; the four corner indices are
    then clamped to the input and blended. The result is 0.0 when either
    grid coordinate is NaN.

    Offsets are computed as for densely packed tensors: input
    (N, C, H_in, W_in), grid (N, H_out, W_out, 2), output (N, C, H_out, W_out).
    N and BLOCK_SIZE are not read in the body (N is an autotune key).
    """
    pid = tl.program_id(0)
    out_plane = H_out * W_out
    nc = pid // out_plane
    hw = pid % out_plane

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Fetch the (x, y) pair stored for this output location.
    grid_off = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)

    # Flag NaNs (a NaN is not equal to itself) and swap in the finite
    # sentinel -2.0 so the arithmetic below stays well defined.
    x_nan = gx != gx
    y_nan = gy != gy
    gx = tl.where(x_nan, -2.0, gx)
    gy = tl.where(y_nan, -2.0, gy)

    # Reflect in grid space. Shifting by +1 moves [-1, 1] to [0, 2];
    # negative remainders modulo 4 are lifted by one period into [0, 4).
    phase_x = (gx + 1.0) % 4.0
    phase_x = tl.where(phase_x < 0, phase_x + 4.0, phase_x)
    # Rising edge on [0, 2], falling edge on (2, 4); then shift back.
    gx_refl = tl.where(phase_x <= 2.0, phase_x, 4.0 - phase_x) - 1.0

    phase_y = (gy + 1.0) % 4.0
    phase_y = tl.where(phase_y < 0, phase_y + 4.0, phase_y)
    gy_refl = tl.where(phase_y <= 2.0, phase_y, 4.0 - phase_y) - 1.0

    # Normalized -> pixel coordinates.
    if align_corners:
        px = (gx_refl + 1.0) * (W_in - 1) / 2.0
        py = (gy_refl + 1.0) * (H_in - 1) / 2.0
    else:
        px = (gx_refl + 1.0) * W_in / 2.0 - 0.5
        py = (gy_refl + 1.0) * H_in / 2.0 - 0.5

    # Surrounding integer lattice and fractional position inside the cell
    # (taken before clamping).
    x_lo = tl.floor(px)
    y_lo = tl.floor(py)
    x_hi = x_lo + 1
    y_hi = y_lo + 1
    frac_x = px - x_lo
    frac_y = py - y_lo

    # Integer corner indices, clamped so the unmasked loads stay in the plane.
    ix0 = tl.maximum(0, tl.minimum(tl.cast(x_lo, tl.int32), W_in - 1))
    ix1 = tl.maximum(0, tl.minimum(tl.cast(x_hi, tl.int32), W_in - 1))
    iy0 = tl.maximum(0, tl.minimum(tl.cast(y_lo, tl.int32), H_in - 1))
    iy1 = tl.maximum(0, tl.minimum(tl.cast(y_hi, tl.int32), H_in - 1))

    # Start of the (n, c) plane in the input.
    plane_base = n * C * H_in * W_in + c * H_in * W_in
    off_00 = plane_base + iy0 * W_in + ix0
    off_01 = plane_base + iy0 * W_in + ix1
    off_10 = plane_base + iy1 * W_in + ix0
    off_11 = plane_base + iy1 * W_in + ix1

    p00 = tl.load(ptr_input + off_00)
    p01 = tl.load(ptr_input + off_01)
    p10 = tl.load(ptr_input + off_10)
    p11 = tl.load(ptr_input + off_11)

    # Blend along x on both rows, then along y between the rows.
    keep_x = 1.0 - frac_x
    top = p00 * keep_x + p01 * frac_x
    bottom = p10 * keep_x + p11 * frac_x
    blended = top * (1.0 - frac_y) + bottom * frac_y

    # A NaN coordinate produces 0.0 instead of the blended value.
    val = tl.where(x_nan | y_nan, 0.0, blended)

    dst_off = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + dst_off, val)
846# ============================================================================
847# Bicubic Interpolation Kernels (4D)
848# ============================================================================
851@libentry()
852@triton.autotune(
853 configs=runtime.get_tuned_config("grid_sample_2d_bicubic"),
854 key=["N", "C", "H_out", "W_out"],
855)
856@triton.jit
857def grid_sample_2d_bicubic_zeros_kernel(
858 ptr_output,
859 ptr_input,
860 ptr_grid,
861 N,
862 C,
863 H_in,
864 W_in,
865 H_out,
866 W_out,
867 align_corners: tl.constexpr,
868 BLOCK_SIZE: tl.constexpr,
869):
870 """
871 Grid sample kernel for 2D bicubic interpolation with zeros padding.
873 Uses Keys' cubic kernel with a=-0.5. Loads 4x4 neighborhood (16 pixels).
874 """
875 pid = tl.program_id(0)
876 nc = pid // (H_out * W_out)
877 hw = pid % (H_out * W_out)
879 n = nc // C
880 c = nc % C
881 h_out = hw // W_out
882 w_out = hw % W_out
884 # Load grid coordinates
885 grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
886 grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
887 grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
889 # Handle NaN
890 grid_x_nan = grid_x != grid_x
891 grid_y_nan = grid_y != grid_y
892 grid_x = tl.where(grid_x_nan, -2.0, grid_x)
893 grid_y = tl.where(grid_y_nan, -2.0, grid_y)
895 # Denormalize to pixel space
896 if align_corners:
897 x = (grid_x + 1.0) * (W_in - 1) / 2.0
898 y = (grid_y + 1.0) * (H_in - 1) / 2.0
899 else:
900 x = (grid_x + 1.0) * W_in / 2.0 - 0.5
901 y = (grid_y + 1.0) * H_in / 2.0 - 0.5
903 # Find 4x4 neighborhood
904 x0 = tl.floor(x) - 1
905 y0 = tl.floor(y) - 1
907 # Convert to int
908 x0_int = tl.cast(x0, tl.int32)
909 y0_int = tl.cast(y0, tl.int32)
911 # Compute interpolation weights using Keys' cubic kernel (a = -0.75)
912 # W(x) = (a+2)|x|³ - (a+3)|x|² + 1, for |x| ≤ 1
913 # W(x) = a|x|³ - 5a|x|² + 8a|x| - 4a, for 1 < |x| < 2
914 # W(x) = 0, otherwise
915 a = -0.75
917 # X weights
918 dx0 = x0 - x
919 wx0 = tl.abs(dx0)
920 weight_x0 = tl.where(
921 wx0 < 1.0,
922 ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
923 tl.where(wx0 < 2.0, ((wx0 - 5) * wx0 + 8) * wx0 * a - 4 * a, 0.0),
924 )
926 dx1 = x0 + 1 - x
927 wx1 = tl.abs(dx1)
928 weight_x1 = tl.where(
929 wx1 < 1.0,
930 ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
931 tl.where(wx1 < 2.0, ((wx1 - 5) * wx1 + 8) * wx1 * a - 4 * a, 0.0),
932 )
934 dx2 = x0 + 2 - x
935 wx2 = tl.abs(dx2)
936 weight_x2 = tl.where(
937 wx2 < 1.0,
938 ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
939 tl.where(wx2 < 2.0, ((wx2 - 5) * wx2 + 8) * wx2 * a - 4 * a, 0.0),
940 )
942 dx3 = x0 + 3 - x
943 wx3 = tl.abs(dx3)
944 weight_x3 = tl.where(
945 wx3 < 1.0,
946 ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
947 tl.where(wx3 < 2.0, ((wx3 - 5) * wx3 + 8) * wx3 * a - 4 * a, 0.0),
948 )
950 # Y weights
951 dy0 = y0 - y
952 wy0 = tl.abs(dy0)
953 weight_y0 = tl.where(
954 wy0 < 1.0,
955 ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
956 tl.where(wy0 < 2.0, ((wy0 - 5) * wy0 + 8) * wy0 * a - 4 * a, 0.0),
957 )
959 dy1 = y0 + 1 - y
960 wy1 = tl.abs(dy1)
961 weight_y1 = tl.where(
962 wy1 < 1.0,
963 ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
964 tl.where(wy1 < 2.0, ((wy1 - 5) * wy1 + 8) * wy1 * a - 4 * a, 0.0),
965 )
967 dy2 = y0 + 2 - y
968 wy2 = tl.abs(dy2)
969 weight_y2 = tl.where(
970 wy2 < 1.0,
971 ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
972 tl.where(wy2 < 2.0, ((wy2 - 5) * wy2 + 8) * wy2 * a - 4 * a, 0.0),
973 )
975 dy3 = y0 + 3 - y
976 wy3 = tl.abs(dy3)
977 weight_y3 = tl.where(
978 wy3 < 1.0,
979 ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
980 tl.where(wy3 < 2.0, ((wy3 - 5) * wy3 + 8) * wy3 * a - 4 * a, 0.0),
981 )
983 # Load 4x4 neighborhood with zeros padding (unrolled loop)
984 input_base = n * C * H_in * W_in + c * H_in * W_in
986 # Initialize accumulator
987 val = 0.0
989 # Row 0
990 y_idx0 = y0_int
991 y_in_bounds0 = (y_idx0 >= 0) & (y_idx0 < H_in)
993 x_idx00 = x0_int
994 x_in_bounds00 = (x_idx00 >= 0) & (x_idx00 < W_in)
995 offset00 = input_base + y_idx0 * W_in + x_idx00
996 val00 = tl.load(
997 ptr_input + offset00,
998 mask=x_in_bounds00 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan,
999 other=0.0,
1000 ).to(tl.float32)
1001 val += val00 * weight_x0 * weight_y0
1003 x_idx01 = x0_int + 1
1004 x_in_bounds01 = (x_idx01 >= 0) & (x_idx01 < W_in)
1005 offset01 = input_base + y_idx0 * W_in + x_idx01
1006 val01 = tl.load(
1007 ptr_input + offset01,
1008 mask=x_in_bounds01 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan,
1009 other=0.0,
1010 ).to(tl.float32)
1011 val += val01 * weight_x1 * weight_y0
1013 x_idx02 = x0_int + 2
1014 x_in_bounds02 = (x_idx02 >= 0) & (x_idx02 < W_in)
1015 offset02 = input_base + y_idx0 * W_in + x_idx02
1016 val02 = tl.load(
1017 ptr_input + offset02,
1018 mask=x_in_bounds02 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan,
1019 other=0.0,
1020 ).to(tl.float32)
1021 val += val02 * weight_x2 * weight_y0
1023 x_idx03 = x0_int + 3
1024 x_in_bounds03 = (x_idx03 >= 0) & (x_idx03 < W_in)
1025 offset03 = input_base + y_idx0 * W_in + x_idx03
1026 val03 = tl.load(
1027 ptr_input + offset03,
1028 mask=x_in_bounds03 & y_in_bounds0 & ~grid_x_nan & ~grid_y_nan,
1029 other=0.0,
1030 ).to(tl.float32)
1031 val += val03 * weight_x3 * weight_y0
1033 # Row 1
1034 y_idx1 = y0_int + 1
1035 y_in_bounds1 = (y_idx1 >= 0) & (y_idx1 < H_in)
1037 x_idx10 = x0_int
1038 x_in_bounds10 = (x_idx10 >= 0) & (x_idx10 < W_in)
1039 offset10 = input_base + y_idx1 * W_in + x_idx10
1040 val10 = tl.load(
1041 ptr_input + offset10,
1042 mask=x_in_bounds10 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan,
1043 other=0.0,
1044 ).to(tl.float32)
1045 val += val10 * weight_x0 * weight_y1
1047 x_idx11 = x0_int + 1
1048 x_in_bounds11 = (x_idx11 >= 0) & (x_idx11 < W_in)
1049 offset11 = input_base + y_idx1 * W_in + x_idx11
1050 val11 = tl.load(
1051 ptr_input + offset11,
1052 mask=x_in_bounds11 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan,
1053 other=0.0,
1054 ).to(tl.float32)
1055 val += val11 * weight_x1 * weight_y1
1057 x_idx12 = x0_int + 2
1058 x_in_bounds12 = (x_idx12 >= 0) & (x_idx12 < W_in)
1059 offset12 = input_base + y_idx1 * W_in + x_idx12
1060 val12 = tl.load(
1061 ptr_input + offset12,
1062 mask=x_in_bounds12 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan,
1063 other=0.0,
1064 ).to(tl.float32)
1065 val += val12 * weight_x2 * weight_y1
1067 x_idx13 = x0_int + 3
1068 x_in_bounds13 = (x_idx13 >= 0) & (x_idx13 < W_in)
1069 offset13 = input_base + y_idx1 * W_in + x_idx13
1070 val13 = tl.load(
1071 ptr_input + offset13,
1072 mask=x_in_bounds13 & y_in_bounds1 & ~grid_x_nan & ~grid_y_nan,
1073 other=0.0,
1074 ).to(tl.float32)
1075 val += val13 * weight_x3 * weight_y1
1077 # Row 2
1078 y_idx2 = y0_int + 2
1079 y_in_bounds2 = (y_idx2 >= 0) & (y_idx2 < H_in)
1081 x_idx20 = x0_int
1082 x_in_bounds20 = (x_idx20 >= 0) & (x_idx20 < W_in)
1083 offset20 = input_base + y_idx2 * W_in + x_idx20
1084 val20 = tl.load(
1085 ptr_input + offset20,
1086 mask=x_in_bounds20 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan,
1087 other=0.0,
1088 ).to(tl.float32)
1089 val += val20 * weight_x0 * weight_y2
1091 x_idx21 = x0_int + 1
1092 x_in_bounds21 = (x_idx21 >= 0) & (x_idx21 < W_in)
1093 offset21 = input_base + y_idx2 * W_in + x_idx21
1094 val21 = tl.load(
1095 ptr_input + offset21,
1096 mask=x_in_bounds21 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan,
1097 other=0.0,
1098 ).to(tl.float32)
1099 val += val21 * weight_x1 * weight_y2
1101 x_idx22 = x0_int + 2
1102 x_in_bounds22 = (x_idx22 >= 0) & (x_idx22 < W_in)
1103 offset22 = input_base + y_idx2 * W_in + x_idx22
1104 val22 = tl.load(
1105 ptr_input + offset22,
1106 mask=x_in_bounds22 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan,
1107 other=0.0,
1108 ).to(tl.float32)
1109 val += val22 * weight_x2 * weight_y2
1111 x_idx23 = x0_int + 3
1112 x_in_bounds23 = (x_idx23 >= 0) & (x_idx23 < W_in)
1113 offset23 = input_base + y_idx2 * W_in + x_idx23
1114 val23 = tl.load(
1115 ptr_input + offset23,
1116 mask=x_in_bounds23 & y_in_bounds2 & ~grid_x_nan & ~grid_y_nan,
1117 other=0.0,
1118 ).to(tl.float32)
1119 val += val23 * weight_x3 * weight_y2
1121 # Row 3
1122 y_idx3 = y0_int + 3
1123 y_in_bounds3 = (y_idx3 >= 0) & (y_idx3 < H_in)
1125 x_idx30 = x0_int
1126 x_in_bounds30 = (x_idx30 >= 0) & (x_idx30 < W_in)
1127 offset30 = input_base + y_idx3 * W_in + x_idx30
1128 val30 = tl.load(
1129 ptr_input + offset30,
1130 mask=x_in_bounds30 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan,
1131 other=0.0,
1132 ).to(tl.float32)
1133 val += val30 * weight_x0 * weight_y3
1135 x_idx31 = x0_int + 1
1136 x_in_bounds31 = (x_idx31 >= 0) & (x_idx31 < W_in)
1137 offset31 = input_base + y_idx3 * W_in + x_idx31
1138 val31 = tl.load(
1139 ptr_input + offset31,
1140 mask=x_in_bounds31 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan,
1141 other=0.0,
1142 ).to(tl.float32)
1143 val += val31 * weight_x1 * weight_y3
1145 x_idx32 = x0_int + 2
1146 x_in_bounds32 = (x_idx32 >= 0) & (x_idx32 < W_in)
1147 offset32 = input_base + y_idx3 * W_in + x_idx32
1148 val32 = tl.load(
1149 ptr_input + offset32,
1150 mask=x_in_bounds32 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan,
1151 other=0.0,
1152 ).to(tl.float32)
1153 val += val32 * weight_x2 * weight_y3
1155 x_idx33 = x0_int + 3
1156 x_in_bounds33 = (x_idx33 >= 0) & (x_idx33 < W_in)
1157 offset33 = input_base + y_idx3 * W_in + x_idx33
1158 val33 = tl.load(
1159 ptr_input + offset33,
1160 mask=x_in_bounds33 & y_in_bounds3 & ~grid_x_nan & ~grid_y_nan,
1161 other=0.0,
1162 ).to(tl.float32)
1163 val += val33 * weight_x3 * weight_y3
1165 # Store output
1166 output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
1167 tl.store(ptr_output + output_offset, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bicubic"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bicubic_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Bicubic 2D grid sampling with border padding (out-of-range taps use the edge pixel).

    Each program instance computes exactly one output element. The flat
    program id is decomposed into (n, c, h_out, w_out), the sampling location
    is read from the grid, and a 4x4 neighbourhood of the input is blended
    with Keys' cubic convolution weights (a = -0.75).

    Args:
        ptr_output: Output buffer, addressed as a contiguous (N, C, H_out, W_out) array.
        ptr_input: Input buffer, addressed as a contiguous (N, C, H_in, W_in) array.
        ptr_grid: Grid buffer, addressed as a contiguous (N, H_out, W_out, 2)
            array holding normalized (x, y) pairs.
        N: Batch size; only used as an autotune key, not referenced in the body.
        C: Number of channels.
        H_in, W_in: Input height and width.
        H_out, W_out: Output height and width.
        align_corners: If true, -1/+1 map to the centres of the corner pixels
            (0 and size - 1); otherwise they map to -0.5 and size - 0.5.
        BLOCK_SIZE: Autotune meta-parameter; not referenced in the body.

    A NaN in either grid coordinate yields an output value of 0.
    """
    # Decompose the flat program id into (batch, channel, row, column).
    flat_id = tl.program_id(0)
    plane_id = flat_id // (H_out * W_out)
    pixel_id = flat_id % (H_out * W_out)

    n = plane_id // C
    c = plane_id % C
    h_out = pixel_id // W_out
    w_out = pixel_id % W_out

    # Read the normalized sampling location for this output pixel.
    grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)

    # NaN is the only value that differs from itself. Remember the flags and
    # substitute a finite placeholder so the index arithmetic stays defined.
    x_is_nan = grid_x != grid_x
    y_is_nan = grid_y != grid_y
    grid_x = tl.where(x_is_nan, -2.0, grid_x)
    grid_y = tl.where(y_is_nan, -2.0, grid_y)

    # Map normalized coordinates to input pixel coordinates.
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 4x4 neighbourhood (one pixel before floor(x), floor(y)).
    left = tl.floor(x) - 1
    top = tl.floor(y) - 1
    left_int = tl.cast(left, tl.int32)
    top_int = tl.cast(top, tl.int32)

    # Keys' cubic convolution coefficient.
    #   W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1      for |t| < 1
    #   W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a    for 1 <= |t| < 2
    #   W(t) = 0                                otherwise
    a = -0.75

    # Horizontal weights: distance from the sample to each of the four columns.
    tx0 = tl.abs(left - x)
    weight_x0 = tl.where(
        tx0 < 1.0,
        ((a + 2) * tx0 - (a + 3)) * tx0 * tx0 + 1,
        tl.where(tx0 < 2.0, ((tx0 - 5) * tx0 + 8) * tx0 * a - 4 * a, 0.0),
    )
    tx1 = tl.abs(left + 1 - x)
    weight_x1 = tl.where(
        tx1 < 1.0,
        ((a + 2) * tx1 - (a + 3)) * tx1 * tx1 + 1,
        tl.where(tx1 < 2.0, ((tx1 - 5) * tx1 + 8) * tx1 * a - 4 * a, 0.0),
    )
    tx2 = tl.abs(left + 2 - x)
    weight_x2 = tl.where(
        tx2 < 1.0,
        ((a + 2) * tx2 - (a + 3)) * tx2 * tx2 + 1,
        tl.where(tx2 < 2.0, ((tx2 - 5) * tx2 + 8) * tx2 * a - 4 * a, 0.0),
    )
    tx3 = tl.abs(left + 3 - x)
    weight_x3 = tl.where(
        tx3 < 1.0,
        ((a + 2) * tx3 - (a + 3)) * tx3 * tx3 + 1,
        tl.where(tx3 < 2.0, ((tx3 - 5) * tx3 + 8) * tx3 * a - 4 * a, 0.0),
    )

    # Vertical weights: distance from the sample to each of the four rows.
    ty0 = tl.abs(top - y)
    weight_y0 = tl.where(
        ty0 < 1.0,
        ((a + 2) * ty0 - (a + 3)) * ty0 * ty0 + 1,
        tl.where(ty0 < 2.0, ((ty0 - 5) * ty0 + 8) * ty0 * a - 4 * a, 0.0),
    )
    ty1 = tl.abs(top + 1 - y)
    weight_y1 = tl.where(
        ty1 < 1.0,
        ((a + 2) * ty1 - (a + 3)) * ty1 * ty1 + 1,
        tl.where(ty1 < 2.0, ((ty1 - 5) * ty1 + 8) * ty1 * a - 4 * a, 0.0),
    )
    ty2 = tl.abs(top + 2 - y)
    weight_y2 = tl.where(
        ty2 < 1.0,
        ((a + 2) * ty2 - (a + 3)) * ty2 * ty2 + 1,
        tl.where(ty2 < 2.0, ((ty2 - 5) * ty2 + 8) * ty2 * a - 4 * a, 0.0),
    )
    ty3 = tl.abs(top + 3 - y)
    weight_y3 = tl.where(
        ty3 < 1.0,
        ((a + 2) * ty3 - (a + 3)) * ty3 * ty3 + 1,
        tl.where(ty3 < 2.0, ((ty3 - 5) * ty3 + 8) * ty3 * a - 4 * a, 0.0),
    )

    # Border padding: every tap index is clamped into [0, size - 1], so all
    # loads below are in range and need no mask.
    col0 = tl.maximum(0, tl.minimum(left_int, W_in - 1))
    col1 = tl.maximum(0, tl.minimum(left_int + 1, W_in - 1))
    col2 = tl.maximum(0, tl.minimum(left_int + 2, W_in - 1))
    col3 = tl.maximum(0, tl.minimum(left_int + 3, W_in - 1))

    row0 = tl.maximum(0, tl.minimum(top_int, H_in - 1))
    row1 = tl.maximum(0, tl.minimum(top_int + 1, H_in - 1))
    row2 = tl.maximum(0, tl.minimum(top_int + 2, H_in - 1))
    row3 = tl.maximum(0, tl.minimum(top_int + 3, H_in - 1))

    # Flat offset of the start of each (clamped) row inside the (n, c) plane.
    input_base = n * C * H_in * W_in + c * H_in * W_in
    row_off0 = input_base + row0 * W_in
    row_off1 = input_base + row1 * W_in
    row_off2 = input_base + row2 * W_in
    row_off3 = input_base + row3 * W_in

    # Accumulate the 16 weighted taps, row by row, left to right.
    acc = 0.0

    acc += tl.load(ptr_input + (row_off0 + col0)).to(tl.float32) * weight_x0 * weight_y0
    acc += tl.load(ptr_input + (row_off0 + col1)).to(tl.float32) * weight_x1 * weight_y0
    acc += tl.load(ptr_input + (row_off0 + col2)).to(tl.float32) * weight_x2 * weight_y0
    acc += tl.load(ptr_input + (row_off0 + col3)).to(tl.float32) * weight_x3 * weight_y0

    acc += tl.load(ptr_input + (row_off1 + col0)).to(tl.float32) * weight_x0 * weight_y1
    acc += tl.load(ptr_input + (row_off1 + col1)).to(tl.float32) * weight_x1 * weight_y1
    acc += tl.load(ptr_input + (row_off1 + col2)).to(tl.float32) * weight_x2 * weight_y1
    acc += tl.load(ptr_input + (row_off1 + col3)).to(tl.float32) * weight_x3 * weight_y1

    acc += tl.load(ptr_input + (row_off2 + col0)).to(tl.float32) * weight_x0 * weight_y2
    acc += tl.load(ptr_input + (row_off2 + col1)).to(tl.float32) * weight_x1 * weight_y2
    acc += tl.load(ptr_input + (row_off2 + col2)).to(tl.float32) * weight_x2 * weight_y2
    acc += tl.load(ptr_input + (row_off2 + col3)).to(tl.float32) * weight_x3 * weight_y2

    acc += tl.load(ptr_input + (row_off3 + col0)).to(tl.float32) * weight_x0 * weight_y3
    acc += tl.load(ptr_input + (row_off3 + col1)).to(tl.float32) * weight_x1 * weight_y3
    acc += tl.load(ptr_input + (row_off3 + col2)).to(tl.float32) * weight_x2 * weight_y3
    acc += tl.load(ptr_input + (row_off3 + col3)).to(tl.float32) * weight_x3 * weight_y3

    # A NaN sampling location produces 0 instead of a blended value.
    acc = tl.where(x_is_nan | y_is_nan, 0.0, acc)

    # Write the result to its (n, c, h_out, w_out) slot.
    output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, acc)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bicubic"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bicubic_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Grid sample kernel for 2D bicubic interpolation with reflection padding.

    Each program instance computes one output element. The normalized grid
    coordinate is first folded back into [-1, 1] with a period-4 triangle
    wave (reflection in grid space), then converted to pixel space, and a 4x4
    neighbourhood is blended with Keys' cubic convolution weights (a = -0.75).
    Individual taps that still fall outside the input are clamped to the
    nearest edge pixel.

    Args:
        ptr_output: Output buffer, addressed as a contiguous (N, C, H_out, W_out) array.
        ptr_input: Input buffer, addressed as a contiguous (N, C, H_in, W_in) array.
        ptr_grid: Grid buffer, addressed as a contiguous (N, H_out, W_out, 2)
            array holding normalized (x, y) pairs.
        N: Batch size; only used as an autotune key, not referenced in the body.
        C: Number of channels.
        H_in, W_in: Input height and width.
        H_out, W_out: Output height and width.
        align_corners: If true, -1/+1 map to pixel centres 0 and size - 1;
            otherwise they map to -0.5 and size - 0.5.
        BLOCK_SIZE: Autotune meta-parameter; not referenced in the body.

    A NaN in either grid coordinate yields an output value of 0.
    """
    # Decompose the flat program id into (batch, channel, row, column).
    pid = tl.program_id(0)
    nc = pid // (H_out * W_out)
    hw = pid % (H_out * W_out)

    n = nc // C
    c = nc % C
    h_out = hw // W_out
    w_out = hw % W_out

    # Load grid coordinates
    grid_idx = n * H_out * W_out * 2 + h_out * W_out * 2 + w_out * 2
    grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
    grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)

    # NaN != NaN: remember the flags and substitute a finite placeholder so
    # the arithmetic below stays well defined; the result is zeroed at the end.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Reflection padding in GRID space: shift to [0, 4) (wrapping a possibly
    # negative remainder), fold the upper half back, then shift to [-1, 1].
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x_refl = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y_refl = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x_refl + 1.0) * (W_in - 1) / 2.0
        y = (grid_y_refl + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x_refl + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y_refl + 1.0) * H_in / 2.0 - 0.5

    # Top-left corner of the 4x4 neighbourhood.
    x0 = tl.floor(x) - 1
    y0 = tl.floor(y) - 1
    # The integer base must describe the same pixel as x0 / y0, because the
    # weights below are derived from x0 / y0. It is therefore NOT clamped
    # here: clamping the base would slide the whole 4-tap window relative to
    # its weights whenever the sample lies in the first pixel interval (e.g.
    # x == 0 has weights (0, 1, 0, 0) and must read pixel 0, not pixel 1).
    # Out-of-range indices are handled per tap inside the loops.
    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)

    # Keys' cubic convolution coefficient.
    #   W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1      for |t| < 1
    #   W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a    for 1 <= |t| < 2
    #   W(t) = 0                                otherwise
    a = -0.75

    # Pre-compute X weights (distance from the sample to each tap column).
    wx0 = tl.abs(x0 - x)
    weight_x0 = tl.where(
        wx0 < 1.0,
        ((a + 2) * wx0 - (a + 3)) * wx0 * wx0 + 1,
        tl.where(wx0 < 2.0, ((wx0 - 5) * wx0 + 8) * wx0 * a - 4 * a, 0.0),
    )

    wx1 = tl.abs(x0 + 1 - x)
    weight_x1 = tl.where(
        wx1 < 1.0,
        ((a + 2) * wx1 - (a + 3)) * wx1 * wx1 + 1,
        tl.where(wx1 < 2.0, ((wx1 - 5) * wx1 + 8) * wx1 * a - 4 * a, 0.0),
    )

    wx2 = tl.abs(x0 + 2 - x)
    weight_x2 = tl.where(
        wx2 < 1.0,
        ((a + 2) * wx2 - (a + 3)) * wx2 * wx2 + 1,
        tl.where(wx2 < 2.0, ((wx2 - 5) * wx2 + 8) * wx2 * a - 4 * a, 0.0),
    )

    wx3 = tl.abs(x0 + 3 - x)
    weight_x3 = tl.where(
        wx3 < 1.0,
        ((a + 2) * wx3 - (a + 3)) * wx3 * wx3 + 1,
        tl.where(wx3 < 2.0, ((wx3 - 5) * wx3 + 8) * wx3 * a - 4 * a, 0.0),
    )

    # Pre-compute Y weights (distance from the sample to each tap row).
    wy0 = tl.abs(y0 - y)
    weight_y0 = tl.where(
        wy0 < 1.0,
        ((a + 2) * wy0 - (a + 3)) * wy0 * wy0 + 1,
        tl.where(wy0 < 2.0, ((wy0 - 5) * wy0 + 8) * wy0 * a - 4 * a, 0.0),
    )

    wy1 = tl.abs(y0 + 1 - y)
    weight_y1 = tl.where(
        wy1 < 1.0,
        ((a + 2) * wy1 - (a + 3)) * wy1 * wy1 + 1,
        tl.where(wy1 < 2.0, ((wy1 - 5) * wy1 + 8) * wy1 * a - 4 * a, 0.0),
    )

    wy2 = tl.abs(y0 + 2 - y)
    weight_y2 = tl.where(
        wy2 < 1.0,
        ((a + 2) * wy2 - (a + 3)) * wy2 * wy2 + 1,
        tl.where(wy2 < 2.0, ((wy2 - 5) * wy2 + 8) * wy2 * a - 4 * a, 0.0),
    )

    wy3 = tl.abs(y0 + 3 - y)
    weight_y3 = tl.where(
        wy3 < 1.0,
        ((a + 2) * wy3 - (a + 3)) * wy3 * wy3 + 1,
        tl.where(wy3 < 2.0, ((wy3 - 5) * wy3 + 8) * wy3 * a - 4 * a, 0.0),
    )

    # Blend the 4x4 neighbourhood.
    input_base = n * C * H_in * W_in + c * H_in * W_in
    val = 0.0

    # NOTE(review): taps that are still out of range after the coordinate has
    # been reflected are clamped to the nearest edge here rather than being
    # reflected individually - confirm this matches the reference
    # grid_sample behaviour for bicubic + reflection.
    for i in range(4):
        # Per-tap clamp keeps every load inside the (n, c) plane.
        y_idx_clamped = tl.maximum(0, tl.minimum(y0_int + i, H_in - 1))
        weight_y = tl.where(
            i == 0,
            weight_y0,
            tl.where(i == 1, weight_y1, tl.where(i == 2, weight_y2, weight_y3)),
        )

        for j in range(4):
            x_idx_clamped = tl.maximum(0, tl.minimum(x0_int + j, W_in - 1))
            weight_x = tl.where(
                j == 0,
                weight_x0,
                tl.where(j == 1, weight_x1, tl.where(j == 2, weight_x2, weight_x3)),
            )

            offset = input_base + y_idx_clamped * W_in + x_idx_clamped
            pixel_val = tl.load(ptr_input + offset).to(tl.float32)
            val += pixel_val * weight_x * weight_y

    # A NaN sampling location produces 0 instead of a blended value.
    val = tl.where(grid_x_nan | grid_y_nan, 0.0, val)

    # Store output
    output_offset = n * C * H_out * W_out + c * H_out * W_out + h_out * W_out + w_out
    tl.store(ptr_output + output_offset, val)
1579# ============================================================================
1580# 5D Support Kernels (Volumetric Data)
1581# ============================================================================
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_zeros_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Nearest-neighbour 3D grid sampling with zeros padding.

    Each program instance produces one output voxel: it reads the normalized
    (x, y, z) location from the grid, converts it to input voxel coordinates,
    rounds each coordinate to the nearest integer (ties go to the even
    integer) and copies that input voxel. Locations that round outside the
    input, or that contain a NaN coordinate, produce 0.

    Args:
        ptr_output: Output buffer, addressed as a contiguous
            (N, C, D_out, H_out, W_out) array.
        ptr_input: Input buffer, addressed as a contiguous
            (N, C, D_in, H_in, W_in) array.
        ptr_grid: Grid buffer, addressed as a contiguous
            (N, D_out, H_out, W_out, 3) array of normalized (x, y, z) triples.
        N: Batch size; only used as an autotune key, not referenced in the body.
        C: Number of channels.
        D_in, H_in, W_in: Input depth, height and width.
        D_out, H_out, W_out: Output depth, height and width.
        align_corners: If true, -1/+1 map to voxel centres 0 and size - 1;
            otherwise they map to -0.5 and size - 0.5.
        BLOCK_SIZE: Autotune meta-parameter; not referenced in the body.
    """
    # Decompose the flat program id into (batch, channel, depth, row, column).
    flat_id = tl.program_id(0)
    volume_id = flat_id // (D_out * H_out * W_out)
    voxel_id = flat_id % (D_out * H_out * W_out)

    n = volume_id // C
    c = volume_id % C
    d_out = voxel_id // (H_out * W_out)
    plane_rem = voxel_id % (H_out * W_out)
    h_out = plane_rem // W_out
    w_out = plane_rem % W_out

    # Read the normalized (x, y, z) sampling location.
    grid_idx = (
        n * D_out * H_out * W_out * 3
        + d_out * H_out * W_out * 3
        + h_out * W_out * 3
        + w_out * 3
    )
    gx = tl.load(ptr_grid + grid_idx).to(tl.float32)
    gy = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
    gz = tl.load(ptr_grid + grid_idx + 2).to(tl.float32)

    # NaN is the only value that differs from itself; swap in a finite
    # placeholder and keep the flags for the load mask.
    x_is_nan = gx != gx
    y_is_nan = gy != gy
    z_is_nan = gz != gz
    gx = tl.where(x_is_nan, -2.0, gx)
    gy = tl.where(y_is_nan, -2.0, gy)
    gz = tl.where(z_is_nan, -2.0, gz)

    # Map normalized coordinates to input voxel coordinates.
    if align_corners:
        x = (gx + 1.0) * (W_in - 1) / 2.0
        y = (gy + 1.0) * (H_in - 1) / 2.0
        z = (gz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx + 1.0) * W_in / 2.0 - 0.5
        y = (gy + 1.0) * H_in / 2.0 - 0.5
        z = (gz + 1.0) * D_in / 2.0 - 0.5

    # Round half to even, one axis at a time: an exact .5 fraction picks
    # whichever of floor / floor + 1 is even, anything else rounds normally.
    x_lo = tl.floor(x)
    x_hi = x_lo + 1
    x_rem = x - x_lo
    x_lo_even = tl.cast(x_lo, tl.int32) % 2 == 0
    x_tie = tl.where(x_lo_even, x_lo, x_hi)
    x_plain = tl.where(x_rem < 0.5, x_lo, x_hi)
    x_idx = tl.cast(tl.where(x_rem == 0.5, x_tie, x_plain), tl.int32)

    y_lo = tl.floor(y)
    y_hi = y_lo + 1
    y_rem = y - y_lo
    y_lo_even = tl.cast(y_lo, tl.int32) % 2 == 0
    y_tie = tl.where(y_lo_even, y_lo, y_hi)
    y_plain = tl.where(y_rem < 0.5, y_lo, y_hi)
    y_idx = tl.cast(tl.where(y_rem == 0.5, y_tie, y_plain), tl.int32)

    z_lo = tl.floor(z)
    z_hi = z_lo + 1
    z_rem = z - z_lo
    z_lo_even = tl.cast(z_lo, tl.int32) % 2 == 0
    z_tie = tl.where(z_lo_even, z_lo, z_hi)
    z_plain = tl.where(z_rem < 0.5, z_lo, z_hi)
    z_idx = tl.cast(tl.where(z_rem == 0.5, z_tie, z_plain), tl.int32)

    # Zeros padding: only load when the voxel is inside the input and no
    # coordinate was NaN; otherwise the masked load returns 0.
    x_inside = (x_idx >= 0) & (x_idx < W_in)
    y_inside = (y_idx >= 0) & (y_idx < H_in)
    z_inside = (z_idx >= 0) & (z_idx < D_in)
    any_nan = x_is_nan | y_is_nan | z_is_nan
    load_mask = x_inside & y_inside & z_inside & ~any_nan

    input_offset = (
        n * C * D_in * H_in * W_in
        + c * D_in * H_in * W_in
        + z_idx * H_in * W_in
        + y_idx * W_in
        + x_idx
    )
    sample = tl.load(ptr_input + input_offset, mask=load_mask, other=0.0).to(
        tl.float32
    )

    # Write the result to its (n, c, d_out, h_out, w_out) slot.
    output_offset = (
        n * C * D_out * H_out * W_out
        + c * D_out * H_out * W_out
        + d_out * H_out * W_out
        + h_out * W_out
        + w_out
    )
    tl.store(ptr_output + output_offset, sample)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Nearest-neighbour 3D grid sampling with border padding.

    Each program instance produces one output voxel: it reads the normalized
    (x, y, z) location from the grid, converts it to input voxel coordinates,
    rounds each coordinate to the nearest integer (ties go to the even
    integer), clamps the indices into the input volume and copies that voxel.
    A NaN in any grid coordinate produces 0.

    Args:
        ptr_output: Output buffer, addressed as a contiguous
            (N, C, D_out, H_out, W_out) array.
        ptr_input: Input buffer, addressed as a contiguous
            (N, C, D_in, H_in, W_in) array.
        ptr_grid: Grid buffer, addressed as a contiguous
            (N, D_out, H_out, W_out, 3) array of normalized (x, y, z) triples.
        N: Batch size; only used as an autotune key, not referenced in the body.
        C: Number of channels.
        D_in, H_in, W_in: Input depth, height and width.
        D_out, H_out, W_out: Output depth, height and width.
        align_corners: If true, -1/+1 map to voxel centres 0 and size - 1;
            otherwise they map to -0.5 and size - 0.5.
        BLOCK_SIZE: Autotune meta-parameter; not referenced in the body.
    """
    # Decompose the flat program id into (batch, channel, depth, row, column).
    flat_id = tl.program_id(0)
    volume_id = flat_id // (D_out * H_out * W_out)
    voxel_id = flat_id % (D_out * H_out * W_out)

    n = volume_id // C
    c = volume_id % C
    d_out = voxel_id // (H_out * W_out)
    plane_rem = voxel_id % (H_out * W_out)
    h_out = plane_rem // W_out
    w_out = plane_rem % W_out

    # Read the normalized (x, y, z) sampling location.
    grid_idx = (
        n * D_out * H_out * W_out * 3
        + d_out * H_out * W_out * 3
        + h_out * W_out * 3
        + w_out * 3
    )
    gx = tl.load(ptr_grid + grid_idx).to(tl.float32)
    gy = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
    gz = tl.load(ptr_grid + grid_idx + 2).to(tl.float32)

    # NaN is the only value that differs from itself; swap in a finite
    # placeholder so the index math stays defined and zero the result later.
    x_is_nan = gx != gx
    y_is_nan = gy != gy
    z_is_nan = gz != gz
    gx = tl.where(x_is_nan, -2.0, gx)
    gy = tl.where(y_is_nan, -2.0, gy)
    gz = tl.where(z_is_nan, -2.0, gz)

    # Map normalized coordinates to input voxel coordinates.
    if align_corners:
        x = (gx + 1.0) * (W_in - 1) / 2.0
        y = (gy + 1.0) * (H_in - 1) / 2.0
        z = (gz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx + 1.0) * W_in / 2.0 - 0.5
        y = (gy + 1.0) * H_in / 2.0 - 0.5
        z = (gz + 1.0) * D_in / 2.0 - 0.5

    # Round half to even per axis, then clamp into the volume (border
    # padding), so the load below is always in range.
    x_lo = tl.floor(x)
    x_hi = x_lo + 1
    x_rem = x - x_lo
    x_lo_even = tl.cast(x_lo, tl.int32) % 2 == 0
    x_tie = tl.where(x_lo_even, x_lo, x_hi)
    x_plain = tl.where(x_rem < 0.5, x_lo, x_hi)
    x_idx = tl.cast(tl.where(x_rem == 0.5, x_tie, x_plain), tl.int32)
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))

    y_lo = tl.floor(y)
    y_hi = y_lo + 1
    y_rem = y - y_lo
    y_lo_even = tl.cast(y_lo, tl.int32) % 2 == 0
    y_tie = tl.where(y_lo_even, y_lo, y_hi)
    y_plain = tl.where(y_rem < 0.5, y_lo, y_hi)
    y_idx = tl.cast(tl.where(y_rem == 0.5, y_tie, y_plain), tl.int32)
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))

    z_lo = tl.floor(z)
    z_hi = z_lo + 1
    z_rem = z - z_lo
    z_lo_even = tl.cast(z_lo, tl.int32) % 2 == 0
    z_tie = tl.where(z_lo_even, z_lo, z_hi)
    z_plain = tl.where(z_rem < 0.5, z_lo, z_hi)
    z_idx = tl.cast(tl.where(z_rem == 0.5, z_tie, z_plain), tl.int32)
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # Fetch the selected voxel; a NaN sampling location yields 0 instead.
    fetched = tl.load(
        ptr_input
        + n * C * D_in * H_in * W_in
        + c * D_in * H_in * W_in
        + z_idx * H_in * W_in
        + y_idx * W_in
        + x_idx
    ).to(tl.float32)
    any_nan = x_is_nan | y_is_nan | z_is_nan
    sample = tl.where(any_nan, 0.0, fetched)

    # Write the result to its (n, c, d_out, h_out, w_out) slot.
    output_offset = (
        n * C * D_out * H_out * W_out
        + c * D_out * H_out * W_out
        + d_out * H_out * W_out
        + h_out * W_out
        + w_out
    )
    tl.store(ptr_output + output_offset, sample)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Nearest-neighbour 3D grid sampling with reflection padding.

    Each program instance produces one output voxel. The normalized (x, y, z)
    location is folded back into [-1, 1] with a period-4 triangle wave
    (reflection in grid space), converted to input voxel coordinates, rounded
    to the nearest integer (ties go to the even integer), clamped into the
    volume as a safety net, and the selected voxel is copied. A NaN in any
    grid coordinate produces 0.

    Args:
        ptr_output: Output buffer, addressed as a contiguous
            (N, C, D_out, H_out, W_out) array.
        ptr_input: Input buffer, addressed as a contiguous
            (N, C, D_in, H_in, W_in) array.
        ptr_grid: Grid buffer, addressed as a contiguous
            (N, D_out, H_out, W_out, 3) array of normalized (x, y, z) triples.
        N: Batch size; only used as an autotune key, not referenced in the body.
        C: Number of channels.
        D_in, H_in, W_in: Input depth, height and width.
        D_out, H_out, W_out: Output depth, height and width.
        align_corners: If true, -1/+1 map to voxel centres 0 and size - 1;
            otherwise they map to -0.5 and size - 0.5.
        BLOCK_SIZE: Autotune meta-parameter; not referenced in the body.
    """
    # Decompose the flat program id into (batch, channel, depth, row, column).
    flat_id = tl.program_id(0)
    volume_id = flat_id // (D_out * H_out * W_out)
    voxel_id = flat_id % (D_out * H_out * W_out)

    n = volume_id // C
    c = volume_id % C
    d_out = voxel_id // (H_out * W_out)
    plane_rem = voxel_id % (H_out * W_out)
    h_out = plane_rem // W_out
    w_out = plane_rem % W_out

    # Read the normalized (x, y, z) sampling location.
    grid_idx = (
        n * D_out * H_out * W_out * 3
        + d_out * H_out * W_out * 3
        + h_out * W_out * 3
        + w_out * 3
    )
    gx = tl.load(ptr_grid + grid_idx).to(tl.float32)
    gy = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
    gz = tl.load(ptr_grid + grid_idx + 2).to(tl.float32)

    # NaN is the only value that differs from itself; swap in a finite
    # placeholder so the math below stays defined and zero the result later.
    x_is_nan = gx != gx
    y_is_nan = gy != gy
    z_is_nan = gz != gz
    gx = tl.where(x_is_nan, -2.0, gx)
    gy = tl.where(y_is_nan, -2.0, gy)
    gz = tl.where(z_is_nan, -2.0, gz)

    # Reflect in grid space with a period-4 triangle wave: shift by +1, take
    # the remainder modulo 4 (wrapping a negative remainder into [0, 4)),
    # fold the upper half back onto [0, 2], then shift back to [-1, 1].
    x_phase = (gx + 1.0) % 4.0
    x_phase = tl.where(x_phase < 0, x_phase + 4.0, x_phase)
    x_folded = tl.where(x_phase <= 2.0, x_phase, 4.0 - x_phase)
    gx_refl = x_folded - 1.0

    y_phase = (gy + 1.0) % 4.0
    y_phase = tl.where(y_phase < 0, y_phase + 4.0, y_phase)
    y_folded = tl.where(y_phase <= 2.0, y_phase, 4.0 - y_phase)
    gy_refl = y_folded - 1.0

    z_phase = (gz + 1.0) % 4.0
    z_phase = tl.where(z_phase < 0, z_phase + 4.0, z_phase)
    z_folded = tl.where(z_phase <= 2.0, z_phase, 4.0 - z_phase)
    gz_refl = z_folded - 1.0

    # Map the reflected coordinates to input voxel coordinates.
    if align_corners:
        x = (gx_refl + 1.0) * (W_in - 1) / 2.0
        y = (gy_refl + 1.0) * (H_in - 1) / 2.0
        z = (gz_refl + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx_refl + 1.0) * W_in / 2.0 - 0.5
        y = (gy_refl + 1.0) * H_in / 2.0 - 0.5
        z = (gz_refl + 1.0) * D_in / 2.0 - 0.5

    # Round half to even per axis, then clamp into the volume so the load
    # below is always in range.
    x_lo = tl.floor(x)
    x_hi = x_lo + 1
    x_rem = x - x_lo
    x_lo_even = tl.cast(x_lo, tl.int32) % 2 == 0
    x_tie = tl.where(x_lo_even, x_lo, x_hi)
    x_plain = tl.where(x_rem < 0.5, x_lo, x_hi)
    x_idx = tl.cast(tl.where(x_rem == 0.5, x_tie, x_plain), tl.int32)
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))

    y_lo = tl.floor(y)
    y_hi = y_lo + 1
    y_rem = y - y_lo
    y_lo_even = tl.cast(y_lo, tl.int32) % 2 == 0
    y_tie = tl.where(y_lo_even, y_lo, y_hi)
    y_plain = tl.where(y_rem < 0.5, y_lo, y_hi)
    y_idx = tl.cast(tl.where(y_rem == 0.5, y_tie, y_plain), tl.int32)
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))

    z_lo = tl.floor(z)
    z_hi = z_lo + 1
    z_rem = z - z_lo
    z_lo_even = tl.cast(z_lo, tl.int32) % 2 == 0
    z_tie = tl.where(z_lo_even, z_lo, z_hi)
    z_plain = tl.where(z_rem < 0.5, z_lo, z_hi)
    z_idx = tl.cast(tl.where(z_rem == 0.5, z_tie, z_plain), tl.int32)
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # Fetch the selected voxel; a NaN sampling location yields 0 instead.
    fetched = tl.load(
        ptr_input
        + n * C * D_in * H_in * W_in
        + c * D_in * H_in * W_in
        + z_idx * H_in * W_in
        + y_idx * W_in
        + x_idx
    ).to(tl.float32)
    any_nan = x_is_nan | y_is_nan | z_is_nan
    sample = tl.where(any_nan, 0.0, fetched)

    # Write the result to its (n, c, d_out, h_out, w_out) slot.
    output_offset = (
        n * C * D_out * H_out * W_out
        + c * D_out * H_out * W_out
        + d_out * H_out * W_out
        + h_out * W_out
        + w_out
    )
    tl.store(ptr_output + output_offset, sample)
1986@libentry()
1987@triton.autotune(
1988 configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
1989 key=["N", "C", "D_out", "H_out", "W_out"],
1990)
1991@triton.jit
1992def grid_sample_3d_trilinear_zeros_kernel(
1993 ptr_output,
1994 ptr_input,
1995 ptr_grid,
1996 N,
1997 C,
1998 D_in,
1999 H_in,
2000 W_in,
2001 D_out,
2002 H_out,
2003 W_out,
2004 align_corners: tl.constexpr,
2005 BLOCK_SIZE: tl.constexpr,
2006):
2007 """
2008 Grid sample kernel for 3D trilinear interpolation with zeros padding.
2009 Loads 8 corner pixels and performs trilinear interpolation.
2010 """
2011 pid = tl.program_id(0)
2012 ncd = pid // (D_out * H_out * W_out)
2013 dhw = pid % (D_out * H_out * W_out)
2015 n = ncd // C
2016 c = ncd % C
2017 d_out = dhw // (H_out * W_out)
2018 hw = dhw % (H_out * W_out)
2019 h_out = hw // W_out
2020 w_out = hw % W_out
2022 # Load 3D grid coordinates
2023 grid_idx = (
2024 n * D_out * H_out * W_out * 3
2025 + d_out * H_out * W_out * 3
2026 + h_out * W_out * 3
2027 + w_out * 3
2028 )
2029 grid_x = tl.load(ptr_grid + grid_idx).to(tl.float32)
2030 grid_y = tl.load(ptr_grid + grid_idx + 1).to(tl.float32)
2031 grid_z = tl.load(ptr_grid + grid_idx + 2).to(tl.float32)
2033 # Handle NaN
2034 grid_x_nan = grid_x != grid_x
2035 grid_y_nan = grid_y != grid_y
2036 grid_z_nan = grid_z != grid_z
2037 grid_x = tl.where(grid_x_nan, -2.0, grid_x)
2038 grid_y = tl.where(grid_y_nan, -2.0, grid_y)
2039 grid_z = tl.where(grid_z_nan, -2.0, grid_z)
2041 # Denormalize to pixel space
2042 if align_corners:
2043 x = (grid_x + 1.0) * (W_in - 1) / 2.0
2044 y = (grid_y + 1.0) * (H_in - 1) / 2.0
2045 z = (grid_z + 1.0) * (D_in - 1) / 2.0
2046 else:
2047 x = (grid_x + 1.0) * W_in / 2.0 - 0.5
2048 y = (grid_y + 1.0) * H_in / 2.0 - 0.5
2049 z = (grid_z + 1.0) * D_in / 2.0 - 0.5
2051 # Find 8 corner indices (2x2x2)
2052 x0 = tl.floor(x)
2053 y0 = tl.floor(y)
2054 z0 = tl.floor(z)
2055 x1 = x0 + 1
2056 y1 = y0 + 1
2057 z1 = z0 + 1
2059 # Compute interpolation weights
2060 wx = x - x0
2061 wy = y - y0
2062 wz = z - z0
2064 # Convert to int
2065 x0_int = tl.cast(x0, tl.int32)
2066 y0_int = tl.cast(y0, tl.int32)
2067 z0_int = tl.cast(z0, tl.int32)
2068 x1_int = tl.cast(x1, tl.int32)
2069 y1_int = tl.cast(y1, tl.int32)
2070 z1_int = tl.cast(z1, tl.int32)
2072 # Check bounds for each corner (zeros padding)
2073 x0_in = (x0_int >= 0) & (x0_int < W_in)
2074 x1_in = (x1_int >= 0) & (x1_int < W_in)
2075 y0_in = (y0_int >= 0) & (y0_int < H_in)
2076 y1_in = (y1_int >= 0) & (y1_int < H_in)
2077 z0_in = (z0_int >= 0) & (z0_int < D_in)
2078 z1_in = (z1_int >= 0) & (z1_int < D_in)
2080 # Load 8 corner pixels with zeros padding
2081 input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
2083 # z=y=x=0,0,0
2084 offset = input_base + z0_int * H_in * W_in + y0_int * W_in + x0_int
2085 p000 = tl.load(
2086 ptr_input + offset,
2087 mask=x0_in & y0_in & z0_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2088 other=0.0,
2089 ).to(tl.float32)
2091 # z=y=0, x=1
2092 offset = input_base + z0_int * H_in * W_in + y0_int * W_in + x1_int
2093 p001 = tl.load(
2094 ptr_input + offset,
2095 mask=x1_in & y0_in & z0_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2096 other=0.0,
2097 ).to(tl.float32)
2099 # z=0, y=1, x=0
2100 offset = input_base + z0_int * H_in * W_in + y1_int * W_in + x0_int
2101 p010 = tl.load(
2102 ptr_input + offset,
2103 mask=x0_in & y1_in & z0_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2104 other=0.0,
2105 ).to(tl.float32)
2107 # z=0, y=1, x=1
2108 offset = input_base + z0_int * H_in * W_in + y1_int * W_in + x1_int
2109 p011 = tl.load(
2110 ptr_input + offset,
2111 mask=x1_in & y1_in & z0_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2112 other=0.0,
2113 ).to(tl.float32)
2115 # z=1, y=x=0,0
2116 offset = input_base + z1_int * H_in * W_in + y0_int * W_in + x0_int
2117 p100 = tl.load(
2118 ptr_input + offset,
2119 mask=x0_in & y0_in & z1_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2120 other=0.0,
2121 ).to(tl.float32)
2123 # z=1, y=0, x=1
2124 offset = input_base + z1_int * H_in * W_in + y0_int * W_in + x1_int
2125 p101 = tl.load(
2126 ptr_input + offset,
2127 mask=x1_in & y0_in & z1_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2128 other=0.0,
2129 ).to(tl.float32)
2131 # z=1, y=1, x=0
2132 offset = input_base + z1_int * H_in * W_in + y1_int * W_in + x0_int
2133 p110 = tl.load(
2134 ptr_input + offset,
2135 mask=x0_in & y1_in & z1_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2136 other=0.0,
2137 ).to(tl.float32)
2139 # z=1, y=1, x=1
2140 offset = input_base + z1_int * H_in * W_in + y1_int * W_in + x1_int
2141 p111 = tl.load(
2142 ptr_input + offset,
2143 mask=x1_in & y1_in & z1_in & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan,
2144 other=0.0,
2145 ).to(tl.float32)
2147 # Trilinear interpolation
2148 # Interpolate along x first, then y, then z
2149 # Front face (z=0)
2150 c000 = p000 * (1.0 - wx) + p001 * wx
2151 c001 = p010 * (1.0 - wx) + p011 * wx
2152 front = c000 * (1.0 - wy) + c001 * wy
2154 # Back face (z=1)
2155 c100 = p100 * (1.0 - wx) + p101 * wx
2156 c101 = p110 * (1.0 - wx) + p111 * wx
2157 back = c100 * (1.0 - wy) + c101 * wy
2159 # Interpolate along z
2160 val = front * (1.0 - wz) + back * wz
2162 # Store output
2163 output_offset = (
2164 n * C * D_out * H_out * W_out
2165 + c * D_out * H_out * W_out
2166 + d_out * H_out * W_out
2167 + h_out * W_out
2168 + w_out
2169 )
2170 tl.store(ptr_output + output_offset, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_border_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Trilinear grid sampling of one output voxel with border padding.

    Every program instance decodes its flat program id into a single
    (n, c, d, h, w) output element. The eight corner indices are clamped into
    the input volume, so all loads are in range and sample points outside the
    volume reuse the nearest edge voxels. If any grid component is NaN the
    stored value is 0.

    Args:
        ptr_output: Output pointer, indexed as contiguous (N, C, D_out, H_out, W_out)
        ptr_input: Input pointer, indexed as contiguous (N, C, D_in, H_in, W_in)
        ptr_grid: Grid pointer, indexed as contiguous (N, D_out, H_out, W_out, 3)
            with components stored in (x, y, z) order
        N: Batch size (autotune key; not read in the body)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Selects the normalized-to-pixel coordinate mapping
        BLOCK_SIZE: Autotune meta-parameter (not read in the body)
    """
    pid = tl.program_id(0)

    # Flat id -> (batch*channel, flat spatial) -> individual coordinates.
    vol_out = D_out * H_out * W_out
    plane_out = H_out * W_out
    nc = pid // vol_out
    spatial = pid % vol_out
    n = nc // C
    c = nc % C
    d_out = spatial // plane_out
    in_plane = spatial % plane_out
    h_out = in_plane // W_out
    w_out = in_plane % W_out

    # The grid is shared across channels: three floats per output location.
    cell = d_out * plane_out + h_out * W_out + w_out
    grid_off = (n * vol_out + cell) * 3
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)
    gz = tl.load(ptr_grid + grid_off + 2).to(tl.float32)

    # NaN != NaN. Swap NaNs for a finite placeholder so index math stays
    # defined; the result is forced to 0 at the end.
    nan_x = gx != gx
    nan_y = gy != gy
    nan_z = gz != gz
    any_nan = nan_x | nan_y | nan_z
    gx = tl.where(nan_x, -2.0, gx)
    gy = tl.where(nan_y, -2.0, gy)
    gz = tl.where(nan_z, -2.0, gz)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        x = (gx + 1.0) * (W_in - 1) / 2.0
        y = (gy + 1.0) * (H_in - 1) / 2.0
        z = (gz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx + 1.0) * W_in / 2.0 - 0.5
        y = (gy + 1.0) * H_in / 2.0 - 0.5
        z = (gz + 1.0) * D_in / 2.0 - 0.5

    # Lower corner and fractional weights (taken before clamping).
    fx = tl.floor(x)
    fy = tl.floor(y)
    fz = tl.floor(z)
    wx = x - fx
    wy = y - fy
    wz = z - fz

    # Border padding: clamp both corners of each axis into [0, size - 1].
    xi0 = tl.maximum(0, tl.minimum(fx.to(tl.int32), W_in - 1))
    xi1 = tl.maximum(0, tl.minimum((fx + 1).to(tl.int32), W_in - 1))
    yi0 = tl.maximum(0, tl.minimum(fy.to(tl.int32), H_in - 1))
    yi1 = tl.maximum(0, tl.minimum((fy + 1).to(tl.int32), H_in - 1))
    zi0 = tl.maximum(0, tl.minimum(fz.to(tl.int32), D_in - 1))
    zi1 = tl.maximum(0, tl.minimum((fz + 1).to(tl.int32), D_in - 1))

    # Row offsets for the four (z, y) combinations; no masks are needed
    # because every index has been clamped.
    chan_off = (n * C + c) * (D_in * H_in * W_in)
    row00 = chan_off + zi0 * H_in * W_in + yi0 * W_in
    row01 = chan_off + zi0 * H_in * W_in + yi1 * W_in
    row10 = chan_off + zi1 * H_in * W_in + yi0 * W_in
    row11 = chan_off + zi1 * H_in * W_in + yi1 * W_in

    p000 = tl.load(ptr_input + row00 + xi0).to(tl.float32)
    p001 = tl.load(ptr_input + row00 + xi1).to(tl.float32)
    p010 = tl.load(ptr_input + row01 + xi0).to(tl.float32)
    p011 = tl.load(ptr_input + row01 + xi1).to(tl.float32)
    p100 = tl.load(ptr_input + row10 + xi0).to(tl.float32)
    p101 = tl.load(ptr_input + row10 + xi1).to(tl.float32)
    p110 = tl.load(ptr_input + row11 + xi0).to(tl.float32)
    p111 = tl.load(ptr_input + row11 + xi1).to(tl.float32)

    # Blend along x, then y, then z.
    ux = 1.0 - wx
    uy = 1.0 - wy
    near = (p000 * ux + p001 * wx) * uy + (p010 * ux + p011 * wx) * wy
    far = (p100 * ux + p101 * wx) * uy + (p110 * ux + p111 * wx) * wy
    val = tl.where(any_nan, 0.0, near * (1.0 - wz) + far * wz)

    out_off = (n * C + c) * vol_out + cell
    tl.store(ptr_output + out_off, val)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_reflection_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Trilinear grid sampling of one output voxel with reflection padding.

    Every program instance decodes its flat program id into a single
    (n, c, d, h, w) output element. Grid coordinates are first folded back
    into [-1, 1] with a period-4 triangle wave, then mapped to pixel space.
    The eight corner indices are clamped into the input volume before loading.
    If any grid component is NaN the stored value is 0.

    Args:
        ptr_output: Output pointer, indexed as contiguous (N, C, D_out, H_out, W_out)
        ptr_input: Input pointer, indexed as contiguous (N, C, D_in, H_in, W_in)
        ptr_grid: Grid pointer, indexed as contiguous (N, D_out, H_out, W_out, 3)
            with components stored in (x, y, z) order
        N: Batch size (autotune key; not read in the body)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Selects the normalized-to-pixel coordinate mapping
        BLOCK_SIZE: Autotune meta-parameter (not read in the body)
    """
    pid = tl.program_id(0)

    # Flat id -> (batch*channel, flat spatial) -> individual coordinates.
    vol_out = D_out * H_out * W_out
    plane_out = H_out * W_out
    nc = pid // vol_out
    spatial = pid % vol_out
    n = nc // C
    c = nc % C
    d_out = spatial // plane_out
    in_plane = spatial % plane_out
    h_out = in_plane // W_out
    w_out = in_plane % W_out

    # The grid is shared across channels: three floats per output location.
    cell = d_out * plane_out + h_out * W_out + w_out
    grid_off = (n * vol_out + cell) * 3
    gx = tl.load(ptr_grid + grid_off).to(tl.float32)
    gy = tl.load(ptr_grid + grid_off + 1).to(tl.float32)
    gz = tl.load(ptr_grid + grid_off + 2).to(tl.float32)

    # NaN != NaN. Swap NaNs for a finite placeholder so index math stays
    # defined; the result is forced to 0 at the end.
    nan_x = gx != gx
    nan_y = gy != gy
    nan_z = gz != gz
    any_nan = nan_x | nan_y | nan_z
    gx = tl.where(nan_x, -2.0, gx)
    gy = tl.where(nan_y, -2.0, gy)
    gz = tl.where(nan_z, -2.0, gz)

    # Triangle wave in grid space: shift to [0, 4), fold the upper half back,
    # shift back to [-1, 1]. `%` can return a negative remainder, hence the
    # explicit wrap.
    tx = (gx + 1.0) % 4.0
    tx = tl.where(tx < 0, tx + 4.0, tx)
    rx = tl.where(tx <= 2.0, tx, 4.0 - tx) - 1.0

    ty = (gy + 1.0) % 4.0
    ty = tl.where(ty < 0, ty + 4.0, ty)
    ry = tl.where(ty <= 2.0, ty, 4.0 - ty) - 1.0

    tz = (gz + 1.0) % 4.0
    tz = tl.where(tz < 0, tz + 4.0, tz)
    rz = tl.where(tz <= 2.0, tz, 4.0 - tz) - 1.0

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        x = (rx + 1.0) * (W_in - 1) / 2.0
        y = (ry + 1.0) * (H_in - 1) / 2.0
        z = (rz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (rx + 1.0) * W_in / 2.0 - 0.5
        y = (ry + 1.0) * H_in / 2.0 - 0.5
        z = (rz + 1.0) * D_in / 2.0 - 0.5

    # Lower corner and fractional weights (taken before clamping).
    fx = tl.floor(x)
    fy = tl.floor(y)
    fz = tl.floor(z)
    wx = x - fx
    wy = y - fy
    wz = z - fz

    # Clamp both corners of each axis into [0, size - 1].
    xi0 = tl.maximum(0, tl.minimum(fx.to(tl.int32), W_in - 1))
    xi1 = tl.maximum(0, tl.minimum((fx + 1).to(tl.int32), W_in - 1))
    yi0 = tl.maximum(0, tl.minimum(fy.to(tl.int32), H_in - 1))
    yi1 = tl.maximum(0, tl.minimum((fy + 1).to(tl.int32), H_in - 1))
    zi0 = tl.maximum(0, tl.minimum(fz.to(tl.int32), D_in - 1))
    zi1 = tl.maximum(0, tl.minimum((fz + 1).to(tl.int32), D_in - 1))

    # Row offsets for the four (z, y) combinations; all indices are in range.
    chan_off = (n * C + c) * (D_in * H_in * W_in)
    row00 = chan_off + zi0 * H_in * W_in + yi0 * W_in
    row01 = chan_off + zi0 * H_in * W_in + yi1 * W_in
    row10 = chan_off + zi1 * H_in * W_in + yi0 * W_in
    row11 = chan_off + zi1 * H_in * W_in + yi1 * W_in

    p000 = tl.load(ptr_input + row00 + xi0).to(tl.float32)
    p001 = tl.load(ptr_input + row00 + xi1).to(tl.float32)
    p010 = tl.load(ptr_input + row01 + xi0).to(tl.float32)
    p011 = tl.load(ptr_input + row01 + xi1).to(tl.float32)
    p100 = tl.load(ptr_input + row10 + xi0).to(tl.float32)
    p101 = tl.load(ptr_input + row10 + xi1).to(tl.float32)
    p110 = tl.load(ptr_input + row11 + xi0).to(tl.float32)
    p111 = tl.load(ptr_input + row11 + xi1).to(tl.float32)

    # Blend along x, then y, then z.
    ux = 1.0 - wx
    uy = 1.0 - wy
    near = (p000 * ux + p001 * wx) * uy + (p010 * ux + p011 * wx) * wy
    far = (p100 * ux + p101 * wx) * uy + (p110 * ux + p111 * wx) * wy
    val = tl.where(any_nan, 0.0, near * (1.0 - wz) + far * wz)

    out_off = (n * C + c) * vol_out + cell
    tl.store(ptr_output + out_off, val)
2466# ============================================================================
2467# 3D Tiled Kernels for Medium-to-Large 5D Inputs (3D Blocking: D×H×W)
2468# ============================================================================
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 3D nearest-neighbor grid sampling with zeros padding.

    Program axis 0 selects the (batch, channel) pair and axis 1 selects one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels. Sample points that
    round to an index outside the input volume, or whose grid entry contains
    a NaN, produce 0.

    Args:
        ptr_output: Pointer to output tensor (N, C, D_out, H_out, W_out)
        ptr_input: Pointer to input tensor (N, C, D_in, H_in, W_in)
        ptr_grid: Pointer to grid tensor (N, D_out, H_out, W_out, 3)
        N: Batch size (autotune key; not read in the body)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Selects the normalized-to-pixel coordinate mapping
        BLOCK_D: Tile extent along depth
        BLOCK_H: Tile extent along height
        BLOCK_W: Tile extent along width
    """
    pid_nc = tl.program_id(0)
    pid_tile = tl.program_id(1)

    n = pid_nc // C
    c = pid_nc % C

    # Unflatten the tile id into (depth, height, width) tile coordinates.
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_tile // tiles_hw
    tile_hw = pid_tile % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Per-axis voxel indices, shaped for broadcasting to (BLOCK_D, BLOCK_H, BLOCK_W).
    d3 = (tile_d * BLOCK_D + tl.arange(0, BLOCK_D))[:, None, None]
    h3 = (tile_h * BLOCK_H + tl.arange(0, BLOCK_H))[None, :, None]
    w3 = (tile_w * BLOCK_W + tl.arange(0, BLOCK_W))[None, None, :]

    # Lanes that fall past the output extent in a partial edge tile are disabled.
    tile_mask = (d3 < D_out) & (h3 < H_out) & (w3 < W_out)

    # Flat spatial index of every voxel in the tile; reused for grid and output.
    cell = d3 * H_out * W_out + h3 * W_out + w3

    # Grid stores (x, y, z) interleaved: three consecutive values per voxel.
    grid_offs = n * D_out * H_out * W_out * 3 + cell * 3
    gx = tl.load(ptr_grid + grid_offs, mask=tile_mask, other=0.0).to(tl.float32)
    gy = tl.load(ptr_grid + grid_offs + 1, mask=tile_mask, other=0.0).to(tl.float32)
    gz = tl.load(ptr_grid + grid_offs + 2, mask=tile_mask, other=0.0).to(tl.float32)

    # NaN != NaN. Replace NaNs with -2.0 so the index math stays defined;
    # those lanes are excluded from the input load below.
    nan_x = gx != gx
    nan_y = gy != gy
    nan_z = gz != gz
    gx = tl.where(nan_x, -2.0, gx)
    gy = tl.where(nan_y, -2.0, gy)
    gz = tl.where(nan_z, -2.0, gz)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        x = (gx + 1.0) * (W_in - 1) / 2.0
        y = (gy + 1.0) * (H_in - 1) / 2.0
        z = (gz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx + 1.0) * W_in / 2.0 - 0.5
        y = (gy + 1.0) * H_in / 2.0 - 0.5
        z = (gz + 1.0) * D_in / 2.0 - 0.5

    # Round half to even: keep the floor when the fraction is below one half,
    # or exactly one half with an even floor; otherwise step up by one.
    fx = tl.floor(x)
    fy = tl.floor(y)
    fz = tl.floor(z)
    rem_x = x - fx
    rem_y = y - fy
    rem_z = z - fz

    even_x = (fx.to(tl.int32) % 2) == 0
    even_y = (fy.to(tl.int32) % 2) == 0
    even_z = (fz.to(tl.int32) % 2) == 0

    keep_x = (rem_x < 0.5) | ((rem_x == 0.5) & even_x)
    keep_y = (rem_y < 0.5) | ((rem_y == 0.5) & even_y)
    keep_z = (rem_z < 0.5) | ((rem_z == 0.5) & even_z)

    x_idx = tl.where(keep_x, fx, fx + 1).to(tl.int32)
    y_idx = tl.where(keep_y, fy, fy + 1).to(tl.int32)
    z_idx = tl.where(keep_z, fz, fz + 1).to(tl.int32)

    # Zeros padding: only in-volume, non-NaN lanes read from the input.
    inside = (
        (x_idx >= 0)
        & (x_idx < W_in)
        & (y_idx >= 0)
        & (y_idx < H_in)
        & (z_idx >= 0)
        & (z_idx < D_in)
    )
    load_mask = tile_mask & inside & ~(nan_x | nan_y | nan_z)

    chan_in = (n * C + c) * (D_in * H_in * W_in)
    src_offs = chan_in + z_idx * H_in * W_in + y_idx * W_in + x_idx
    vals = tl.load(ptr_input + src_offs, mask=load_mask, other=0.0)

    chan_out = (n * C + c) * (D_out * H_out * W_out)
    tl.store(ptr_output + chan_out + cell, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Tiled 3D nearest-neighbor grid sampling with border padding.

    Program axis 0 selects the (batch, channel) pair and axis 1 selects one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels. Rounded indices are
    clamped to the input volume, so out-of-range sample points read the
    nearest edge voxel. Grid entries containing a NaN produce 0.

    Args:
        ptr_output: Pointer to output tensor (N, C, D_out, H_out, W_out)
        ptr_input: Pointer to input tensor (N, C, D_in, H_in, W_in)
        ptr_grid: Pointer to grid tensor (N, D_out, H_out, W_out, 3)
        N: Batch size (autotune key; not read in the body)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Selects the normalized-to-pixel coordinate mapping
        BLOCK_D: Tile extent along depth
        BLOCK_H: Tile extent along height
        BLOCK_W: Tile extent along width
    """
    pid_nc = tl.program_id(0)
    pid_tile = tl.program_id(1)

    n = pid_nc // C
    c = pid_nc % C

    # Unflatten the tile id into (depth, height, width) tile coordinates.
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_tile // tiles_hw
    tile_hw = pid_tile % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Per-axis voxel indices, shaped for broadcasting to (BLOCK_D, BLOCK_H, BLOCK_W).
    d3 = (tile_d * BLOCK_D + tl.arange(0, BLOCK_D))[:, None, None]
    h3 = (tile_h * BLOCK_H + tl.arange(0, BLOCK_H))[None, :, None]
    w3 = (tile_w * BLOCK_W + tl.arange(0, BLOCK_W))[None, None, :]

    # Lanes that fall past the output extent in a partial edge tile are disabled.
    tile_mask = (d3 < D_out) & (h3 < H_out) & (w3 < W_out)

    # Flat spatial index of every voxel in the tile; reused for grid and output.
    cell = d3 * H_out * W_out + h3 * W_out + w3

    # Grid stores (x, y, z) interleaved: three consecutive values per voxel.
    grid_offs = n * D_out * H_out * W_out * 3 + cell * 3
    gx = tl.load(ptr_grid + grid_offs, mask=tile_mask, other=0.0).to(tl.float32)
    gy = tl.load(ptr_grid + grid_offs + 1, mask=tile_mask, other=0.0).to(tl.float32)
    gz = tl.load(ptr_grid + grid_offs + 2, mask=tile_mask, other=0.0).to(tl.float32)

    # NaN != NaN. Replace NaNs with -2.0 so the index math stays defined.
    nan_x = gx != gx
    nan_y = gy != gy
    nan_z = gz != gz
    any_nan = nan_x | nan_y | nan_z
    gx = tl.where(nan_x, -2.0, gx)
    gy = tl.where(nan_y, -2.0, gy)
    gz = tl.where(nan_z, -2.0, gz)

    # Normalized [-1, 1] -> pixel coordinates.
    if align_corners:
        x = (gx + 1.0) * (W_in - 1) / 2.0
        y = (gy + 1.0) * (H_in - 1) / 2.0
        z = (gz + 1.0) * (D_in - 1) / 2.0
    else:
        x = (gx + 1.0) * W_in / 2.0 - 0.5
        y = (gy + 1.0) * H_in / 2.0 - 0.5
        z = (gz + 1.0) * D_in / 2.0 - 0.5

    # Round half to even: keep the floor when the fraction is below one half,
    # or exactly one half with an even floor; otherwise step up by one.
    fx = tl.floor(x)
    fy = tl.floor(y)
    fz = tl.floor(z)
    rem_x = x - fx
    rem_y = y - fy
    rem_z = z - fz

    even_x = (fx.to(tl.int32) % 2) == 0
    even_y = (fy.to(tl.int32) % 2) == 0
    even_z = (fz.to(tl.int32) % 2) == 0

    keep_x = (rem_x < 0.5) | ((rem_x == 0.5) & even_x)
    keep_y = (rem_y < 0.5) | ((rem_y == 0.5) & even_y)
    keep_z = (rem_z < 0.5) | ((rem_z == 0.5) & even_z)

    x_idx = tl.where(keep_x, fx, fx + 1).to(tl.int32)
    y_idx = tl.where(keep_y, fy, fy + 1).to(tl.int32)
    z_idx = tl.where(keep_z, fz, fz + 1).to(tl.int32)

    # Border padding: pull every index back into [0, size - 1].
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # After clamping only tile-edge lanes and NaN lanes must be masked.
    load_mask = tile_mask & ~any_nan

    chan_in = (n * C + c) * (D_in * H_in * W_in)
    src_offs = chan_in + z_idx * H_in * W_in + y_idx * W_in + x_idx
    fetched = tl.load(ptr_input + src_offs, mask=load_mask, other=0.0)

    # NaN lanes are written as 0 regardless of what was fetched.
    vals = tl.where(any_nan, 0.0, fetched)

    chan_out = (n * C + c) * (D_out * H_out * W_out)
    tl.store(ptr_output + chan_out + cell, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_nearest_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_nearest_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D nearest neighbor interpolation with reflection padding (tiled version).

    Program axis 0 selects the (batch, channel) pair and axis 1 selects one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels.

    Reflection padding: grid coordinates outside [-1, 1] are folded back into
    that range with a period-4 triangle wave, mapped to pixel space, rounded
    half-to-even, and finally clamped to [0, size - 1]. The clamp matters for
    align_corners=False, where a reflected coordinate can land exactly on
    size - 0.5 (e.g. grid == 1.0) and round up to index `size`; such samples
    must read the edge voxel rather than be treated as out of bounds.
    Grid entries containing a NaN produce 0.

    Args:
        ptr_output: Pointer to output tensor (N, C, D_out, H_out, W_out)
        ptr_input: Pointer to input tensor (N, C, D_in, H_in, W_in)
        ptr_grid: Pointer to grid tensor (N, D_out, H_out, W_out, 3)
        N: Batch size (autotune key; not read in the body)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Selects the normalized-to-pixel coordinate mapping
        BLOCK_D: Block depth for tiling
        BLOCK_H: Block height for tiling
        BLOCK_W: Block width for tiling
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_dhw for spatial tile
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    # Compute batch and channel
    n = pid_nc // C
    c = pid_nc % C

    # Decompose flattened 3D tile index to (d, h, w) block indices
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    num_h_blocks = tl.cdiv(H_out, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    d_block_idx = pid_dhw // num_hw_blocks
    hw_block_idx = pid_dhw % num_hw_blocks
    h_block_idx = hw_block_idx // num_w_blocks
    w_block_idx = hw_block_idx % num_w_blocks

    # Compute voxel offsets within tile (3D broadcasting)
    d_offsets = d_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Reshape for 3D broadcasting: (BLOCK_D, BLOCK_H, BLOCK_W)
    d_out_3d = d_offsets[:, None, None]
    h_out_3d = h_offsets[None, :, None]
    w_out_3d = w_offsets[None, None, :]

    # Flat spatial index of each voxel in the tile (shared by grid and output)
    spatial_offsets = d_out_3d * H_out * W_out + h_out_3d * W_out + w_out_3d

    # Load 3D grid coordinates for entire tile (vectorized); the grid stores
    # (x, y, z) interleaved, three consecutive values per voxel
    grid_base = n * D_out * H_out * W_out * 3
    grid_offsets = grid_base + spatial_offsets * 3
    grid_x = tl.load(ptr_grid + grid_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_offsets + 1, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_z = tl.load(ptr_grid + grid_offsets + 2, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN: remember which lanes are NaN and substitute a finite value
    # so the index math below stays defined
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    any_nan = grid_x_nan | grid_y_nan | grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Apply triangle wave reflection with period 4 (before denormalization).
    # `%` can return a negative remainder, hence the explicit wrap into [0, 4)
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    grid_z_mod = (grid_z + 1.0) % 4.0
    grid_z_mod = tl.where(grid_z_mod < 0, grid_z_mod + 4.0, grid_z_mod)
    grid_z = tl.where(grid_z_mod <= 2.0, grid_z_mod, 4.0 - grid_z_mod) - 1.0

    # Denormalize reflected coordinates to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Apply banker's rounding (vectorized across tile)
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    z_floor = tl.floor(z)
    x_frac = x - x_floor
    y_frac = y - y_floor
    z_frac = z - z_floor

    x_is_half = x_frac == 0.5
    y_is_half = y_frac == 0.5
    z_is_half = z_frac == 0.5
    x_floor_int = tl.cast(x_floor, tl.int32)
    y_floor_int = tl.cast(y_floor, tl.int32)
    z_floor_int = tl.cast(z_floor, tl.int32)

    x_is_even = x_floor_int % 2 == 0
    y_is_even = y_floor_int % 2 == 0
    z_is_even = z_floor_int % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)
    z_round = tl.where(z_frac < 0.5, z_floor, z_floor + 1)

    x_idx = tl.cast(
        tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round),
        tl.int32,
    )
    y_idx = tl.cast(
        tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round),
        tl.int32,
    )
    z_idx = tl.cast(
        tl.where(z_is_half, tl.where(z_is_even, z_floor, z_floor + 1), z_round),
        tl.int32,
    )

    # Clip to the valid index range after reflecting. With
    # align_corners=False the reflected coordinate spans [-0.5, size - 0.5],
    # so rounding (or float error in the fold) can yield `size`; without the
    # clamp those samples would be treated as out of bounds and read as 0.
    x_idx = tl.maximum(0, tl.minimum(x_idx, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx, H_in - 1))
    z_idx = tl.maximum(0, tl.minimum(z_idx, D_in - 1))

    # Valid mask: indices are in range after clamping, so only tile-edge
    # lanes and NaN lanes need to be excluded (NaN lanes then store 0)
    valid_mask = tile_mask & ~any_nan

    # Load input voxels for entire tile
    input_base = n * C * D_in * H_in * W_in + c * D_in * H_in * W_in
    input_offsets = input_base + z_idx * H_in * W_in + y_idx * W_in + x_idx

    vals = tl.load(ptr_input + input_offsets, mask=valid_mask, other=0.0)

    # Store to output
    output_base = n * C * D_out * H_out * W_out + c * D_out * H_out * W_out
    output_offsets = output_base + spatial_offsets

    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with zeros padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).
    Trilinear interpolation blends the 8 corner voxels (2x2x2 cube) around
    every sampling location; corners that fall outside the input volume
    contribute zero. Output voxels whose grid entry contains a NaN are
    written as zero because all of their corner loads are masked off.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, D_in, H_in, W_in), grid as (N, D_out, H_out, W_out, 3)
    with (x, y, z) in the last dimension, output as (N, C, D_out, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 map to voxel centers (True) or voxel corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)
    n = pid_nc // C

    # Decompose the flattened tile id into (d, h, w) tile indices
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_dhw // tiles_hw
    tile_hw = pid_dhw % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Output voxel coordinates covered by this tile
    d_offsets = tile_d * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output volume
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Flat position of every tile voxel inside one (D_out, H_out, W_out) volume
    out_spatial = (
        d_offsets[:, None, None] * H_out * W_out
        + h_offsets[None, :, None] * W_out
        + w_offsets[None, None, :]
    )

    # Load the (x, y, z) grid coordinates; the 64-bit batch base is folded
    # into the pointer so the per-voxel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * D_out * H_out * W_out * 3 + out_spatial * 3
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)
    grid_z = tl.load(grid_ptrs + 2, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN: remember which voxels are affected, then substitute a
    # finite sentinel so the arithmetic below stays well defined.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    live = tile_mask & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Lower/upper corner coordinates and the fractional blend weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1
    wx = x - x0
    wy = y - y0
    wz = z - z0

    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    z0_int = tl.cast(z0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y1_int = tl.cast(y1, tl.int32)
    z1_int = tl.cast(z1, tl.int32)

    # Per-axis in-bounds tests (zeros padding: out-of-range corners read as 0)
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)
    z0_in = (z0_int >= 0) & (z0_int < D_in)
    z1_in = (z1_int >= 0) & (z1_int < D_in)

    # Base pointer of this program's (n, c) input volume; n * C + c == pid_nc
    in_ptr = ptr_input + pid_nc * D_in * H_in * W_in

    # Plane / row offsets and masks shared between corners
    z0_off = z0_int * H_in * W_in
    z1_off = z1_int * H_in * W_in
    y0_off = y0_int * W_in
    y1_off = y1_int * W_in
    m_z0y0 = live & z0_in & y0_in
    m_z0y1 = live & z0_in & y1_in
    m_z1y0 = live & z1_in & y0_in
    m_z1y1 = live & z1_in & y1_in

    # Load the 8 corners; names are p<z><y><x> with 0 = lower, 1 = upper
    p000 = tl.load(
        in_ptr + (z0_off + y0_off + x0_int), mask=m_z0y0 & x0_in, other=0.0
    ).to(tl.float32)
    p001 = tl.load(
        in_ptr + (z0_off + y0_off + x1_int), mask=m_z0y0 & x1_in, other=0.0
    ).to(tl.float32)
    p010 = tl.load(
        in_ptr + (z0_off + y1_off + x0_int), mask=m_z0y1 & x0_in, other=0.0
    ).to(tl.float32)
    p011 = tl.load(
        in_ptr + (z0_off + y1_off + x1_int), mask=m_z0y1 & x1_in, other=0.0
    ).to(tl.float32)
    p100 = tl.load(
        in_ptr + (z1_off + y0_off + x0_int), mask=m_z1y0 & x0_in, other=0.0
    ).to(tl.float32)
    p101 = tl.load(
        in_ptr + (z1_off + y0_off + x1_int), mask=m_z1y0 & x1_in, other=0.0
    ).to(tl.float32)
    p110 = tl.load(
        in_ptr + (z1_off + y1_off + x0_int), mask=m_z1y1 & x0_in, other=0.0
    ).to(tl.float32)
    p111 = tl.load(
        in_ptr + (z1_off + y1_off + x1_int), mask=m_z1y1 & x1_in, other=0.0
    ).to(tl.float32)

    # 3-stage trilinear interpolation
    # Stage 1: interpolate along x
    line_z0y0 = p000 * (1.0 - wx) + p001 * wx
    line_z0y1 = p010 * (1.0 - wx) + p011 * wx
    line_z1y0 = p100 * (1.0 - wx) + p101 * wx
    line_z1y1 = p110 * (1.0 - wx) + p111 * wx

    # Stage 2: interpolate along y
    plane_z0 = line_z0y0 * (1.0 - wy) + line_z0y1 * wy
    plane_z1 = line_z1y0 * (1.0 - wy) + line_z1y1 * wy

    # Stage 3: interpolate along z (final)
    vals = plane_z0 * (1.0 - wz) + plane_z1 * wz

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * D_out * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with border padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).
    Corner indices are clamped into the input volume, so sampling locations
    outside the volume reuse the nearest border voxel. Output voxels whose
    grid entry contains a NaN are written as zero.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, D_in, H_in, W_in), grid as (N, D_out, H_out, W_out, 3)
    with (x, y, z) in the last dimension, output as (N, C, D_out, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 map to voxel centers (True) or voxel corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)
    n = pid_nc // C

    # Decompose the flattened tile id into (d, h, w) tile indices
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_dhw // tiles_hw
    tile_hw = pid_dhw % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Output voxel coordinates covered by this tile
    d_offsets = tile_d * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output volume
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Flat position of every tile voxel inside one (D_out, H_out, W_out) volume
    out_spatial = (
        d_offsets[:, None, None] * H_out * W_out
        + h_offsets[None, :, None] * W_out
        + w_offsets[None, None, :]
    )

    # Load the (x, y, z) grid coordinates; the 64-bit batch base is folded
    # into the pointer so the per-voxel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * D_out * H_out * W_out * 3 + out_spatial * 3
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)
    grid_z = tl.load(grid_ptrs + 2, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN: remember which voxels are affected, then substitute a
    # finite sentinel so the arithmetic below stays well defined.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    any_nan = grid_x_nan | grid_y_nan | grid_z_nan
    live = tile_mask & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Lower/upper corner coordinates and the fractional blend weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1
    wx = x - x0
    wy = y - y0
    wz = z - z0

    # Convert to integers and clamp for border padding. Once both corners of
    # an axis are clamped to the same voxel the weights blend it with itself.
    x0_int = tl.maximum(0, tl.minimum(tl.cast(x0, tl.int32), W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(tl.cast(x1, tl.int32), W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(tl.cast(y0, tl.int32), H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(tl.cast(y1, tl.int32), H_in - 1))
    z0_int = tl.maximum(0, tl.minimum(tl.cast(z0, tl.int32), D_in - 1))
    z1_int = tl.maximum(0, tl.minimum(tl.cast(z1, tl.int32), D_in - 1))

    # Base pointer of this program's (n, c) input volume; n * C + c == pid_nc
    in_ptr = ptr_input + pid_nc * D_in * H_in * W_in

    # Plane / row offsets shared between corners
    z0_off = z0_int * H_in * W_in
    z1_off = z1_int * H_in * W_in
    y0_off = y0_int * W_in
    y1_off = y1_int * W_in

    # Load the 8 corners (no bounds mask needed: indices are clamped);
    # names are p<z><y><x> with 0 = lower, 1 = upper
    p000 = tl.load(in_ptr + (z0_off + y0_off + x0_int), mask=live, other=0.0).to(
        tl.float32
    )
    p001 = tl.load(in_ptr + (z0_off + y0_off + x1_int), mask=live, other=0.0).to(
        tl.float32
    )
    p010 = tl.load(in_ptr + (z0_off + y1_off + x0_int), mask=live, other=0.0).to(
        tl.float32
    )
    p011 = tl.load(in_ptr + (z0_off + y1_off + x1_int), mask=live, other=0.0).to(
        tl.float32
    )
    p100 = tl.load(in_ptr + (z1_off + y0_off + x0_int), mask=live, other=0.0).to(
        tl.float32
    )
    p101 = tl.load(in_ptr + (z1_off + y0_off + x1_int), mask=live, other=0.0).to(
        tl.float32
    )
    p110 = tl.load(in_ptr + (z1_off + y1_off + x0_int), mask=live, other=0.0).to(
        tl.float32
    )
    p111 = tl.load(in_ptr + (z1_off + y1_off + x1_int), mask=live, other=0.0).to(
        tl.float32
    )

    # 3-stage trilinear interpolation: along x, then y, then z
    line_z0y0 = p000 * (1.0 - wx) + p001 * wx
    line_z0y1 = p010 * (1.0 - wx) + p011 * wx
    line_z1y0 = p100 * (1.0 - wx) + p101 * wx
    line_z1y1 = p110 * (1.0 - wx) + p111 * wx

    plane_z0 = line_z0y0 * (1.0 - wy) + line_z0y1 * wy
    plane_z1 = line_z1y0 * (1.0 - wy) + line_z1y1 * wy

    vals = plane_z0 * (1.0 - wz) + plane_z1 * wz

    # Handle NaN: voxels with a NaN grid coordinate produce zero
    vals = tl.where(any_nan, 0.0, vals)

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * D_out * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_3d_trilinear_tiled"),
    key=["N", "C", "D_out", "H_out", "W_out"],
)
@triton.jit
def grid_sample_3d_trilinear_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 3D trilinear interpolation with reflection padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_D x BLOCK_H x BLOCK_W tile of output voxels (program axis 1).
    Normalized grid coordinates are folded back into [-1, 1] with a triangle
    wave of period 4, denormalized, and then clipped to the valid voxel range
    [0, size - 1] before the 8 corner voxels are blended. Output voxels whose
    grid entry contains a NaN are written as zero because all of their corner
    loads are masked off.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, D_in, H_in, W_in), grid as (N, D_out, H_out, W_out, 3)
    with (x, y, z) in the last dimension, output as (N, C, D_out, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        D_in: Input depth
        H_in: Input height
        W_in: Input width
        D_out: Output depth
        H_out: Output height
        W_out: Output width
        align_corners: Whether -1/1 map to voxel centers (True) or voxel corners (False)
        BLOCK_D: Tile depth
        BLOCK_H: Tile height
        BLOCK_W: Tile width
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)
    n = pid_nc // C

    # Decompose the flattened tile id into (d, h, w) tile indices
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tiles_h = tl.cdiv(H_out, BLOCK_H)
    tiles_hw = tiles_h * tiles_w
    tile_d = pid_dhw // tiles_hw
    tile_hw = pid_dhw % tiles_hw
    tile_h = tile_hw // tiles_w
    tile_w = tile_hw % tiles_w

    # Output voxel coordinates covered by this tile
    d_offsets = tile_d * BLOCK_D + tl.arange(0, BLOCK_D)
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output volume
    d_mask = d_offsets < D_out
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = d_mask[:, None, None] & h_mask[None, :, None] & w_mask[None, None, :]

    # Flat position of every tile voxel inside one (D_out, H_out, W_out) volume
    out_spatial = (
        d_offsets[:, None, None] * H_out * W_out
        + h_offsets[None, :, None] * W_out
        + w_offsets[None, None, :]
    )

    # Load the (x, y, z) grid coordinates; the 64-bit batch base is folded
    # into the pointer so the per-voxel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * D_out * H_out * W_out * 3 + out_spatial * 3
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)
    grid_z = tl.load(grid_ptrs + 2, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN: remember which voxels are affected, then substitute a
    # finite sentinel so the arithmetic below stays well defined.
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_z_nan = grid_z != grid_z
    live = tile_mask & ~grid_x_nan & ~grid_y_nan & ~grid_z_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)
    grid_z = tl.where(grid_z_nan, -2.0, grid_z)

    # Apply triangle wave reflection with period 4: shift [-1, 1] to [0, 2],
    # wrap into [0, 4), mirror the upper half back down, shift back.
    # The explicit "+ 4.0" step makes the remainder non-negative.
    grid_x_mod = (grid_x + 1.0) % 4.0
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    grid_x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    grid_y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    grid_z_mod = (grid_z + 1.0) % 4.0
    grid_z_mod = tl.where(grid_z_mod < 0, grid_z_mod + 4.0, grid_z_mod)
    grid_z = tl.where(grid_z_mod <= 2.0, grid_z_mod, 4.0 - grid_z_mod) - 1.0

    # Denormalize reflected coordinates to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
        z = (grid_z + 1.0) * (D_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5
        z = (grid_z + 1.0) * D_in / 2.0 - 0.5

    # Clip to the valid voxel range. With align_corners=False the reflected
    # coordinate can still land in [-0.5, 0) or (size - 1, size - 0.5]; without
    # the clip those samples would be blended with masked (zero) corners.
    # NOTE(review): this mirrors the reflect-then-clip order of PyTorch's
    # reference grid sampler — confirm against the non-tiled kernels.
    x = tl.minimum(tl.maximum(x, 0.0), (W_in - 1) * 1.0)
    y = tl.minimum(tl.maximum(y, 0.0), (H_in - 1) * 1.0)
    z = tl.minimum(tl.maximum(z, 0.0), (D_in - 1) * 1.0)

    # Lower/upper corner coordinates and the fractional blend weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    z0 = tl.floor(z)
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1
    wx = x - x0
    wy = y - y0
    wz = z - z0

    x0_int = tl.cast(x0, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    z0_int = tl.cast(z0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y1_int = tl.cast(y1, tl.int32)
    z1_int = tl.cast(z1, tl.int32)

    # Boundary checks: the upper corner can still be one past the end when the
    # coordinate sits exactly on the last voxel (its weight is then zero).
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)
    z0_in = (z0_int >= 0) & (z0_int < D_in)
    z1_in = (z1_int >= 0) & (z1_int < D_in)

    # Base pointer of this program's (n, c) input volume; n * C + c == pid_nc
    in_ptr = ptr_input + pid_nc * D_in * H_in * W_in

    # Plane / row offsets and masks shared between corners
    z0_off = z0_int * H_in * W_in
    z1_off = z1_int * H_in * W_in
    y0_off = y0_int * W_in
    y1_off = y1_int * W_in
    m_z0y0 = live & z0_in & y0_in
    m_z0y1 = live & z0_in & y1_in
    m_z1y0 = live & z1_in & y0_in
    m_z1y1 = live & z1_in & y1_in

    # Load the 8 corners; names are p<z><y><x> with 0 = lower, 1 = upper
    p000 = tl.load(
        in_ptr + (z0_off + y0_off + x0_int), mask=m_z0y0 & x0_in, other=0.0
    ).to(tl.float32)
    p001 = tl.load(
        in_ptr + (z0_off + y0_off + x1_int), mask=m_z0y0 & x1_in, other=0.0
    ).to(tl.float32)
    p010 = tl.load(
        in_ptr + (z0_off + y1_off + x0_int), mask=m_z0y1 & x0_in, other=0.0
    ).to(tl.float32)
    p011 = tl.load(
        in_ptr + (z0_off + y1_off + x1_int), mask=m_z0y1 & x1_in, other=0.0
    ).to(tl.float32)
    p100 = tl.load(
        in_ptr + (z1_off + y0_off + x0_int), mask=m_z1y0 & x0_in, other=0.0
    ).to(tl.float32)
    p101 = tl.load(
        in_ptr + (z1_off + y0_off + x1_int), mask=m_z1y0 & x1_in, other=0.0
    ).to(tl.float32)
    p110 = tl.load(
        in_ptr + (z1_off + y1_off + x0_int), mask=m_z1y1 & x0_in, other=0.0
    ).to(tl.float32)
    p111 = tl.load(
        in_ptr + (z1_off + y1_off + x1_int), mask=m_z1y1 & x1_in, other=0.0
    ).to(tl.float32)

    # 3-stage trilinear interpolation: along x, then y, then z
    line_z0y0 = p000 * (1.0 - wx) + p001 * wx
    line_z0y1 = p010 * (1.0 - wx) + p011 * wx
    line_z1y0 = p100 * (1.0 - wx) + p101 * wx
    line_z1y1 = p110 * (1.0 - wx) + p111 * wx

    plane_z0 = line_z0y0 * (1.0 - wy) + line_z0y1 * wy
    plane_z1 = line_z1y0 * (1.0 - wy) + line_z1y1 * wy

    vals = plane_z0 * (1.0 - wz) + plane_z1 * wz

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * D_out * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
3776# ============================================================================
3777# Tiled Kernels for Medium-to-Large Inputs (Multi-dimensional Blocking)
3778# ============================================================================
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D nearest neighbor interpolation with zeros padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Every output
    pixel copies the input pixel nearest to its sampling location, with ties
    rounded to the even index. Locations outside the input, and grid entries
    containing a NaN, produce zero.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, H_in, W_in), grid as (N, H_out, W_out, 2) with (x, y) in
    the last dimension, output as (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Whether to align corners
        BLOCK_H: Block height for tiling
        BLOCK_W: Block width for tiling
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)
    n = pid_nc // C

    # Tile position in the output grid
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tile_h = pid_hw // tiles_w
    tile_w = pid_hw % tiles_w

    # Output pixel coordinates covered by this tile
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output image
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Flat position of every tile pixel inside one (H_out, W_out) image
    out_spatial = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Load the (x, y) grid coordinates; the 64-bit batch base is folded into
    # the pointer so the per-pixel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * H_out * W_out * 2 + out_spatial * 2
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN - use sentinel value -2.0 (outside valid grid range [-1, 1])
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Denormalize to pixel space
    if align_corners:
        # Pixel centers at -1 and 1
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        # Pixel corners at -1 and 1
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Banker's rounding: keep the floor when the fraction is below one half,
    # or exactly one half with an even floor; otherwise step up by one.
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor
    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0
    x_keep_floor = (x_frac < 0.5) | ((x_frac == 0.5) & x_floor_even)
    y_keep_floor = (y_frac < 0.5) | ((y_frac == 0.5) & y_floor_even)

    # The +1 is done in floating point before the cast, as in the float path
    # for the coordinates themselves.
    x_idx = tl.cast(tl.where(x_keep_floor, x_floor, x_floor + 1), tl.int32)
    y_idx = tl.cast(tl.where(y_keep_floor, y_floor, y_floor + 1), tl.int32)

    # Zeros padding: only in-bounds, non-NaN samples are read
    x_in_bounds = (x_idx >= 0) & (x_idx < W_in)
    y_in_bounds = (y_idx >= 0) & (y_idx < H_in)
    valid_mask = tile_mask & x_in_bounds & y_in_bounds & ~grid_x_nan & ~grid_y_nan

    # Base pointer of this program's (n, c) input image; n * C + c == pid_nc
    in_ptr = ptr_input + pid_nc * H_in * W_in
    vals = tl.load(in_ptr + (y_idx * W_in + x_idx), mask=valid_mask, other=0.0)

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_zeros_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D bilinear interpolation with zeros padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). Every output
    pixel blends the 4 input pixels around its sampling location; corners
    outside the input contribute zero. Output pixels whose grid entry
    contains a NaN are written as zero because all corner loads are masked.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, H_in, W_in), grid as (N, H_out, W_out, 2) with (x, y) in
    the last dimension, output as (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Whether to align corners
        BLOCK_H: Block height for tiling
        BLOCK_W: Block width for tiling
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)
    n = pid_nc // C

    # Tile position in the output grid
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tile_h = pid_hw // tiles_w
    tile_w = pid_hw % tiles_w

    # Output pixel coordinates covered by this tile
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output image
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Flat position of every tile pixel inside one (H_out, W_out) image
    out_spatial = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Load the (x, y) grid coordinates; the 64-bit batch base is folded into
    # the pointer so the per-pixel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * H_out * W_out * 2 + out_spatial * 2
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN - use sentinel value -2.0 and keep the affected lanes masked
    grid_x_nan = grid_x != grid_x
    grid_y_nan = grid_y != grid_y
    live = tile_mask & ~grid_x_nan & ~grid_y_nan
    grid_x = tl.where(grid_x_nan, -2.0, grid_x)
    grid_y = tl.where(grid_y_nan, -2.0, grid_y)

    # Denormalize to pixel space
    if align_corners:
        # Pixel centers at -1 and 1
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        # Pixel corners at -1 and 1
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Lower/upper corner coordinates and the fractional blend weights
    x0 = tl.floor(x)
    y0 = tl.floor(y)
    x1 = x0 + 1
    y1 = y0 + 1
    wx = x - x0
    wy = y - y0

    x0_int = tl.cast(x0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    y1_int = tl.cast(y1, tl.int32)

    # Per-axis in-bounds tests (zeros padding: out-of-range corners read as 0)
    x0_in = (x0_int >= 0) & (x0_int < W_in)
    x1_in = (x1_int >= 0) & (x1_int < W_in)
    y0_in = (y0_int >= 0) & (y0_int < H_in)
    y1_in = (y1_int >= 0) & (y1_int < H_in)

    # Base pointer of this program's (n, c) input image; n * C + c == pid_nc
    in_ptr = ptr_input + pid_nc * H_in * W_in
    row0 = y0_int * W_in
    row1 = y1_int * W_in
    m_row0 = live & y0_in
    m_row1 = live & y1_in

    # Load 4 corner pixels; names are p<y><x> with 0 = lower, 1 = upper
    p00 = tl.load(in_ptr + (row0 + x0_int), mask=m_row0 & x0_in, other=0.0)
    p01 = tl.load(in_ptr + (row0 + x1_int), mask=m_row0 & x1_in, other=0.0)
    p10 = tl.load(in_ptr + (row1 + x0_int), mask=m_row1 & x0_in, other=0.0)
    p11 = tl.load(in_ptr + (row1 + x1_int), mask=m_row1 & x1_in, other=0.0)

    # Bilinear interpolation: along x within each row, then along y
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D nearest neighbor interpolation with border padding (tiled version).

    Each program handles one (batch, channel) slice (program axis 0) and one
    BLOCK_H x BLOCK_W tile of output pixels (program axis 1). The rounded
    sample index is clamped to [0, W_in - 1] x [0, H_in - 1], so locations
    outside the input reuse the nearest border pixel. A NaN grid coordinate
    is replaced by -1.0 before denormalization and therefore ends up at
    index 0 on that axis after clamping.

    The index arithmetic treats the tensors as densely packed:
    input as (N, C, H_in, W_in), grid as (N, H_out, W_out, 2) with (x, y) in
    the last dimension, output as (N, C, H_out, W_out).

    Args:
        ptr_output: Pointer to output tensor
        ptr_input: Pointer to input tensor
        ptr_grid: Pointer to grid tensor
        N: Batch size (only used as an autotune key)
        C: Number of channels
        H_in: Input height
        W_in: Input width
        H_out: Output height
        W_out: Output width
        align_corners: Whether to align corners
        BLOCK_H: Block height for tiling
        BLOCK_W: Block width for tiling
    """
    # Widen the (batch, channel) id to 64 bits: the slice base offsets below
    # are products of several dimensions and can exceed the 32-bit range even
    # when each individual dimension is small.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_hw = tl.program_id(1)
    n = pid_nc // C

    # Tile position in the output grid
    tiles_w = tl.cdiv(W_out, BLOCK_W)
    tile_h = pid_hw // tiles_w
    tile_w = pid_hw % tiles_w

    # Output pixel coordinates covered by this tile
    h_offsets = tile_h * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = tile_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask off the part of a boundary tile that hangs over the output image
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Flat position of every tile pixel inside one (H_out, W_out) image
    out_spatial = h_offsets[:, None] * W_out + w_offsets[None, :]

    # Load the (x, y) grid coordinates; the 64-bit batch base is folded into
    # the pointer so the per-pixel offsets stay 32-bit.
    grid_ptrs = ptr_grid + n * H_out * W_out * 2 + out_spatial * 2
    grid_x = tl.load(grid_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    grid_y = tl.load(grid_ptrs + 1, mask=tile_mask, other=0.0).to(tl.float32)

    # Handle NaN - use sentinel -1.0 like original kernel
    grid_x = tl.where(grid_x != grid_x, -1.0, grid_x)
    grid_y = tl.where(grid_y != grid_y, -1.0, grid_y)

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Banker's rounding: keep the floor when the fraction is below one half,
    # or exactly one half with an even floor; otherwise step up by one.
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor
    x_floor_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_floor_even = tl.cast(y_floor, tl.int32) % 2 == 0
    x_keep_floor = (x_frac < 0.5) | ((x_frac == 0.5) & x_floor_even)
    y_keep_floor = (y_frac < 0.5) | ((y_frac == 0.5) & y_floor_even)

    x_idx_unclamped = tl.cast(tl.where(x_keep_floor, x_floor, x_floor + 1), tl.int32)
    y_idx_unclamped = tl.cast(tl.where(y_keep_floor, y_floor, y_floor + 1), tl.int32)

    # Clamp to valid range (border padding)
    x_idx = tl.maximum(0, tl.minimum(x_idx_unclamped, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx_unclamped, H_in - 1))

    # Base pointer of this program's (n, c) input image; n * C + c == pid_nc.
    # Only the tile mask is needed: clamping keeps every index in range.
    in_ptr = ptr_input + pid_nc * H_in * W_in
    vals = tl.load(in_ptr + (y_idx * W_in + x_idx), mask=tile_mask, other=0.0)

    # Store to output; same 64-bit slice base trick as for the input
    out_ptr = ptr_output + pid_nc * H_out * W_out
    tl.store(out_ptr + out_spatial, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_border_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D bilinear interpolation with border padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_H x BLOCK_W tile of the output (program axis 1). Offsets are computed
    for densely packed (N, C, H, W) input/output and (N, H_out, W_out, 2) grid
    buffers.

    Border padding: coordinates are clamped to valid range [0, W_in) x [0, H_in).
    The sampling coordinate itself is clamped before the corners and weights are
    derived from it. Clamping only the integer corner indices leaves the weight
    `x - floor(x)` equal to NaN for infinite grid values, which would turn the
    output into NaN instead of the border pixel. NaN grid values are replaced by
    the sentinel -1.0 and therefore sample pixel 0.
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_hw for spatial tile
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Compute batch and channel
    n = pid_nc // C
    c = pid_nc % C

    # Compute tile position in output grid
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    # Compute pixel offsets within tile
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Reshape for broadcasting: (BLOCK_H, BLOCK_W)
    h_out_flat = h_offsets[:, None]
    w_out_flat = w_offsets[None, :]

    # Load grid coordinates for entire tile (vectorized); x and y are interleaved
    grid_base = n * H_out * W_out * 2
    grid_x_offsets = grid_base + (h_out_flat * W_out + w_out_flat) * 2
    grid_y_offsets = grid_x_offsets + 1
    grid_x = tl.load(ptr_grid + grid_x_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_y_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Handle NaN - use sentinel -1.0 (NaN is the only value not equal to itself)
    grid_x = tl.where(grid_x != grid_x, -1.0, grid_x)
    grid_y = tl.where(grid_y != grid_y, -1.0, grid_y)

    # Denormalize to pixel space
    if align_corners:
        x = (grid_x + 1.0) * (W_in - 1) / 2.0
        y = (grid_y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (grid_x + 1.0) * W_in / 2.0 - 0.5
        y = (grid_y + 1.0) * H_in / 2.0 - 0.5

    # Border padding: clamp the coordinate in float space so the weights below
    # stay finite and the int32 casts never see out-of-range values.
    x_max = (W_in - 1) * 1.0
    y_max = (H_in - 1) * 1.0
    x = tl.minimum(tl.maximum(x, 0.0), x_max)
    y = tl.minimum(tl.maximum(y, 0.0), y_max)

    # Compute corner indices for entire tile (vectorized)
    x0 = tl.floor(x)
    x1 = x0 + 1
    y0 = tl.floor(y)
    y1 = y0 + 1

    # Cast to int for indexing
    x0_int = tl.cast(x0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    y1_int = tl.cast(y1, tl.int32)

    # Clamp corner indices; the "+1" corner can step one past the last pixel
    x0_int = tl.maximum(0, tl.minimum(x0_int, W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(x1_int, W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(y0_int, H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(y1_int, H_in - 1))

    # Compute interpolation weights (fractional part of the clamped coordinate)
    wx = x - x0
    wy = y - y0

    # Load 4 corner pixels (vectorized; indices are valid thanks to clamping,
    # the mask only disables lanes that fall outside the output tile)
    input_base = n * C * H_in * W_in + c * H_in * W_in
    row0 = input_base + y0_int * W_in
    row1 = input_base + y1_int * W_in

    p00 = tl.load(ptr_input + row0 + x0_int, mask=tile_mask, other=0.0)
    p01 = tl.load(ptr_input + row0 + x1_int, mask=tile_mask, other=0.0)
    p10 = tl.load(ptr_input + row1 + x0_int, mask=tile_mask, other=0.0)
    p11 = tl.load(ptr_input + row1 + x1_int, mask=tile_mask, other=0.0)

    # Bilinear interpolation (vectorized)
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    # Store to output
    output_base = n * C * H_out * W_out + c * H_out * W_out
    output_offsets = output_base + (h_out_flat * W_out + w_out_flat)
    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_nearest_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_nearest_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D nearest neighbor interpolation with reflection padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_H x BLOCK_W tile of the output (program axis 1). Offsets are computed
    for densely packed (N, C, H, W) input/output and (N, H_out, W_out, 2) grid
    buffers.

    Reflection padding: applies triangle wave reflection in grid space, folding
    every coordinate into [-1, 1] with period 4. NaN and infinite grid values
    (for which the modulo yields NaN) are replaced by the sentinel -1.0 after
    the reflection, matching the NaN convention of the border tiled kernels, so
    no NaN reaches the float-to-int32 conversion.
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_hw for spatial tile
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Compute batch and channel
    n = pid_nc // C
    c = pid_nc % C

    # Compute tile position in output grid
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    # Compute pixel offsets within tile
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Reshape for broadcasting: (BLOCK_H, BLOCK_W)
    h_out_flat = h_offsets[:, None]
    w_out_flat = w_offsets[None, :]

    # Load grid coordinates for entire tile (vectorized); x and y are interleaved
    grid_base = n * H_out * W_out * 2
    grid_x_offsets = grid_base + (h_out_flat * W_out + w_out_flat) * 2
    grid_y_offsets = grid_x_offsets + 1
    grid_x = tl.load(ptr_grid + grid_x_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_y_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Apply triangle wave reflection in grid space (vectorized)
    # Triangle wave pattern with period 4: shift to [0, 4), then mirror (2, 4)
    grid_x_mod = (grid_x + 1.0) % 4.0
    # The float remainder keeps the dividend's sign, so wrap negatives upwards
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    # NaN inputs stay NaN and +/-inf become NaN in the modulo above; map them to
    # the -1.0 sentinel so rounding and the int32 cast only see finite values.
    x = tl.where(x != x, -1.0, x)
    y = tl.where(y != y, -1.0, y)

    # Denormalize to pixel space
    if align_corners:
        x = (x + 1.0) * (W_in - 1) / 2.0
        y = (y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (x + 1.0) * W_in / 2.0 - 0.5
        y = (y + 1.0) * H_in / 2.0 - 0.5

    # Apply banker's rounding (round half to even, vectorized across tile)
    x_floor = tl.floor(x)
    y_floor = tl.floor(y)
    x_frac = x - x_floor
    y_frac = y - y_floor

    x_is_half = x_frac == 0.5
    y_is_half = y_frac == 0.5
    x_is_even = tl.cast(x_floor, tl.int32) % 2 == 0
    y_is_even = tl.cast(y_floor, tl.int32) % 2 == 0

    x_round = tl.where(x_frac < 0.5, x_floor, x_floor + 1)
    y_round = tl.where(y_frac < 0.5, y_floor, y_floor + 1)

    x_idx_unclamped = tl.cast(
        tl.where(x_is_half, tl.where(x_is_even, x_floor, x_floor + 1), x_round),
        tl.int32,
    )
    y_idx_unclamped = tl.cast(
        tl.where(y_is_half, tl.where(y_is_even, y_floor, y_floor + 1), y_round),
        tl.int32,
    )

    # Clamp to valid bounds: with align_corners=False the reflected coordinate
    # can still round to -1 or W_in / H_in at the very edge
    x_idx = tl.maximum(0, tl.minimum(x_idx_unclamped, W_in - 1))
    y_idx = tl.maximum(0, tl.minimum(y_idx_unclamped, H_in - 1))

    # Load input pixels for entire tile
    input_base = n * C * H_in * W_in + c * H_in * W_in
    input_offsets = input_base + y_idx * W_in + x_idx
    vals = tl.load(ptr_input + input_offsets, mask=tile_mask, other=0.0)

    # Store to output
    output_base = n * C * H_out * W_out + c * H_out * W_out
    output_offsets = output_base + (h_out_flat * W_out + w_out_flat)
    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("grid_sample_2d_bilinear_tiled"),
    key=["N", "C", "H_out", "W_out"],
)
@triton.jit
def grid_sample_2d_bilinear_reflection_tiled_kernel(
    ptr_output,
    ptr_input,
    ptr_grid,
    N,
    C,
    H_in,
    W_in,
    H_out,
    W_out,
    align_corners: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """
    Grid sample kernel for 2D bilinear interpolation with reflection padding (tiled version).

    Each program handles one (batch, channel) pair (program axis 0) and one
    BLOCK_H x BLOCK_W tile of the output (program axis 1). Offsets are computed
    for densely packed (N, C, H, W) input/output and (N, H_out, W_out, 2) grid
    buffers.

    Reflection padding: applies triangle wave reflection in grid space, folding
    every coordinate into [-1, 1] with period 4. NaN and infinite grid values
    (for which the modulo yields NaN) are replaced by the sentinel -1.0 after
    the reflection, matching the NaN convention of the border tiled kernels.
    Without this the interpolation weights would be NaN and poison the output.
    """
    # 2D program IDs: pid_nc for (batch, channel), pid_hw for spatial tile
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Compute batch and channel
    n = pid_nc // C
    c = pid_nc % C

    # Compute tile position in output grid
    num_w_blocks = tl.cdiv(W_out, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    # Compute pixel offsets within tile
    h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Mask for boundary tiles
    h_mask = h_offsets < H_out
    w_mask = w_offsets < W_out
    tile_mask = h_mask[:, None] & w_mask[None, :]

    # Reshape for broadcasting: (BLOCK_H, BLOCK_W)
    h_out_flat = h_offsets[:, None]
    w_out_flat = w_offsets[None, :]

    # Load grid coordinates for entire tile (vectorized); x and y are interleaved
    grid_base = n * H_out * W_out * 2
    grid_x_offsets = grid_base + (h_out_flat * W_out + w_out_flat) * 2
    grid_y_offsets = grid_x_offsets + 1
    grid_x = tl.load(ptr_grid + grid_x_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )
    grid_y = tl.load(ptr_grid + grid_y_offsets, mask=tile_mask, other=0.0).to(
        tl.float32
    )

    # Apply triangle wave reflection in grid space (vectorized)
    # Period 4: shift to [0, 4), then mirror the (2, 4) half back onto [0, 2]
    grid_x_mod = (grid_x + 1.0) % 4.0
    # The float remainder keeps the dividend's sign, so wrap negatives upwards
    grid_x_mod = tl.where(grid_x_mod < 0, grid_x_mod + 4.0, grid_x_mod)
    x = tl.where(grid_x_mod <= 2.0, grid_x_mod, 4.0 - grid_x_mod) - 1.0

    grid_y_mod = (grid_y + 1.0) % 4.0
    grid_y_mod = tl.where(grid_y_mod < 0, grid_y_mod + 4.0, grid_y_mod)
    y = tl.where(grid_y_mod <= 2.0, grid_y_mod, 4.0 - grid_y_mod) - 1.0

    # NaN inputs stay NaN and +/-inf become NaN in the modulo above; map them to
    # the -1.0 sentinel so the weights and int32 casts only see finite values.
    x = tl.where(x != x, -1.0, x)
    y = tl.where(y != y, -1.0, y)

    # Denormalize to pixel space
    if align_corners:
        x = (x + 1.0) * (W_in - 1) / 2.0
        y = (y + 1.0) * (H_in - 1) / 2.0
    else:
        x = (x + 1.0) * W_in / 2.0 - 0.5
        y = (y + 1.0) * H_in / 2.0 - 0.5

    # Compute corner indices for entire tile (vectorized)
    x0 = tl.floor(x)
    x1 = x0 + 1
    y0 = tl.floor(y)
    y1 = y0 + 1

    # Cast to int for indexing
    x0_int = tl.cast(x0, tl.int32)
    x1_int = tl.cast(x1, tl.int32)
    y0_int = tl.cast(y0, tl.int32)
    y1_int = tl.cast(y1, tl.int32)

    # Clamp to valid bounds: corners next to the edges can fall one pixel
    # outside the image even after reflection
    x0_int = tl.maximum(0, tl.minimum(x0_int, W_in - 1))
    x1_int = tl.maximum(0, tl.minimum(x1_int, W_in - 1))
    y0_int = tl.maximum(0, tl.minimum(y0_int, H_in - 1))
    y1_int = tl.maximum(0, tl.minimum(y1_int, H_in - 1))

    # Compute interpolation weights (fractional part of the coordinate)
    wx = x - x0
    wy = y - y0

    # Load 4 corner pixels (vectorized)
    input_base = n * C * H_in * W_in + c * H_in * W_in
    row0 = input_base + y0_int * W_in
    row1 = input_base + y1_int * W_in

    p00 = tl.load(ptr_input + row0 + x0_int, mask=tile_mask, other=0.0)
    p01 = tl.load(ptr_input + row0 + x1_int, mask=tile_mask, other=0.0)
    p10 = tl.load(ptr_input + row1 + x0_int, mask=tile_mask, other=0.0)
    p11 = tl.load(ptr_input + row1 + x1_int, mask=tile_mask, other=0.0)

    # Bilinear interpolation (vectorized)
    top = p00 * (1.0 - wx) + p01 * wx
    bottom = p10 * (1.0 - wx) + p11 * wx
    vals = top * (1.0 - wy) + bottom * wy

    # Store to output
    output_base = n * C * H_out * W_out + c * H_out * W_out
    output_offsets = output_base + (h_out_flat * W_out + w_out_flat)
    tl.store(ptr_output + output_offsets, vals, mask=tile_mask)
4607# ============================================================================
4608# Main Dispatch Function
4609# ============================================================================
def _select_kernel_by_padding(
    padding_mode, zeros_kernel, border_kernel, reflection_kernel
):
    """Return the kernel for `padding_mode`; any value other than 'zeros'/'border' selects reflection."""
    if padding_mode == "zeros":
        return zeros_kernel
    if padding_mode == "border":
        return border_kernel
    return reflection_kernel


def _tiled_launch_grid(nc_pairs, out_extents, host_blocks, meta_keys):
    """Build a launch-grid callable `(N*C, number_of_tiles)` for the tiled kernels.

    The visible tiled kernels receive their BLOCK_* tile sizes from the
    autotuner, not from the caller, so the tile count has to be derived from the
    values the launch really uses. Each entry of `meta_keys` is looked up in the
    launch meta-parameters; when a kernel does not expose that key the
    host-side heuristic from `host_blocks` is used instead.
    """

    def launch_grid(meta):
        num_tiles = 1
        for extent, host_block, key in zip(out_extents, host_blocks, meta_keys):
            block = meta.get(key, host_block)
            num_tiles *= (extent + block - 1) // block
        return (nc_pairs, num_tiles)

    return launch_grid


def _grid_sample_2d(input, grid, mode, padding_mode, align_corners):
    """Run grid sampling for 4D input (N, C, H_in, W_in) and grid (N, H_out, W_out, 2)."""
    N, C, H_in, W_in = input.shape
    _, H_out, W_out, _ = grid.shape

    output = torch.empty((N, C, H_out, W_out), dtype=input.dtype, device=input.device)

    # Adaptive kernel selection based on output size:
    # - tiled kernels for medium-to-large outputs (>= 32x32 = 1024 pixels)
    # - per-pixel kernels for small outputs and for very large outputs
    #   (> 512x512) to avoid grid size issues
    # The upper bound is part of the selection itself so that a tiled kernel is
    # never launched with the 1D per-pixel grid.
    output_pixels = H_out * W_out
    tiled_min_pixels = 1024
    tiled_max_pixels = 512 * 512
    use_tiled = (
        mode in ("nearest", "bilinear")
        and tiled_min_pixels <= output_pixels <= tiled_max_pixels
    )

    if mode == "nearest":
        if use_tiled:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_2d_nearest_zeros_tiled_kernel,
                grid_sample_2d_nearest_border_tiled_kernel,
                grid_sample_2d_nearest_reflection_tiled_kernel,
            )
        else:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_2d_nearest_zeros_kernel,
                grid_sample_2d_nearest_border_kernel,
                grid_sample_2d_nearest_reflection_kernel,
            )
    elif mode == "bilinear":
        if use_tiled:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_2d_bilinear_zeros_tiled_kernel,
                grid_sample_2d_bilinear_border_tiled_kernel,
                grid_sample_2d_bilinear_reflection_tiled_kernel,
            )
        else:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_2d_bilinear_zeros_kernel,
                grid_sample_2d_bilinear_border_kernel,
                grid_sample_2d_bilinear_reflection_kernel,
            )
    elif mode == "bicubic":
        # Bicubic is already competitive, use original kernels for all sizes
        kernel = _select_kernel_by_padding(
            padding_mode,
            grid_sample_2d_bicubic_zeros_kernel,
            grid_sample_2d_bicubic_border_kernel,
            grid_sample_2d_bicubic_reflection_kernel,
        )
    else:  # unsupported mode
        logger.info(f"grid_sample mode '{mode}' not supported")
        raise NotImplementedError(f"grid_sample mode '{mode}' not supported")

    # Nothing to compute; skip the launch (and any autotuning) entirely
    if output.numel() == 0:
        return output

    if use_tiled:
        # Host-side tile heuristic. Goal: ~100-500 blocks in total for good GPU
        # utilization. It only determines the grid when the kernel does not
        # publish BLOCK_H / BLOCK_W in its launch meta-parameters.
        target_total_blocks = 300
        min_blocks_per_nc = 50  # ensure enough parallelism
        max_blocks_per_nc = 1000  # avoid too many blocks
        nc_pairs = N * C

        target_blocks_per_nc = max(
            min_blocks_per_nc,
            min(max_blocks_per_nc, target_total_blocks // max(1, nc_pairs)),
        )

        # Start from square tiles that would give the target block count
        target_tile_pixels = output_pixels // target_blocks_per_nc
        target_tile_side = int(max(4, min(128, int(target_tile_pixels**0.5))))

        # Snap to power-of-2 for better alignment
        if target_tile_side >= 64:
            block_side = 64 if target_tile_side < 96 else 128
        elif target_tile_side >= 16:
            block_side = 32
        elif target_tile_side >= 8:
            block_side = 16
        else:
            block_side = 8

        # For bilinear, use smaller tiles due to higher memory footprint
        if mode == "bilinear":
            block_side = max(4, block_side // 2)

        grid_size = _tiled_launch_grid(
            nc_pairs,
            (H_out, W_out),
            (block_side, block_side),
            ("BLOCK_H", "BLOCK_W"),
        )
    else:
        # Original kernels use 1D grid: one program per output element
        grid_size = (N * C * H_out * W_out,)

    kernel[grid_size](
        output,
        input,
        grid,
        N,
        C,
        H_in,
        W_in,
        H_out,
        W_out,
        align_corners,
    )

    return output


def _grid_sample_3d(input, grid, mode, padding_mode, align_corners):
    """Run grid sampling for 5D input (N, C, D_in, H_in, W_in) and grid (N, D_out, H_out, W_out, 3)."""
    N, C, D_in, H_in, W_in = input.shape
    _, D_out, H_out, W_out, _ = grid.shape

    output = torch.empty(
        (N, C, D_out, H_out, W_out), dtype=input.dtype, device=input.device
    )

    # Adaptive kernel selection based on output size:
    # - tiled kernels for medium-to-large outputs (>= 16x16x16 = 4096 voxels)
    # - per-voxel kernels for small outputs and for outputs above
    #   MAX_TILED_VOXELS
    # The upper bound is part of the selection itself so that a tiled kernel is
    # never launched with the 1D per-voxel grid.
    output_voxels = D_out * H_out * W_out
    tiled_min_voxels = 4096  # 16x16x16
    use_tiled = (
        mode in ("nearest", "bilinear")
        and tiled_min_voxels <= output_voxels <= MAX_TILED_VOXELS
    )

    if mode == "nearest":
        if use_tiled:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_3d_nearest_zeros_tiled_kernel,
                grid_sample_3d_nearest_border_tiled_kernel,
                grid_sample_3d_nearest_reflection_tiled_kernel,
            )
        else:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_3d_nearest_zeros_kernel,
                grid_sample_3d_nearest_border_kernel,
                grid_sample_3d_nearest_reflection_kernel,
            )
    elif mode == "bilinear":  # For 5D, bilinear means trilinear
        if use_tiled:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_3d_trilinear_zeros_tiled_kernel,
                grid_sample_3d_trilinear_border_tiled_kernel,
                grid_sample_3d_trilinear_reflection_tiled_kernel,
            )
        else:
            kernel = _select_kernel_by_padding(
                padding_mode,
                grid_sample_3d_trilinear_zeros_kernel,
                grid_sample_3d_trilinear_border_kernel,
                grid_sample_3d_trilinear_reflection_kernel,
            )
    else:  # unsupported mode for 5D
        logger.info(f"grid_sample mode '{mode}' not supported for 5D input")
        raise NotImplementedError("Unsupported mode for 5D input")

    # Nothing to compute; skip the launch (and any autotuning) entirely
    if output.numel() == 0:
        return output

    if use_tiled:
        nc_pairs = N * C

        # Block-count targets per output-size bucket (see module constants)
        if output_voxels < VOXEL_THRESHOLD_SMALL:
            target_total_blocks = TARGET_BLOCKS_SMALL
            min_blocks_per_nc = MIN_BLOCKS_NC_SMALL
            max_blocks_per_nc = MAX_BLOCKS_NC_SMALL
        elif output_voxels < VOXEL_THRESHOLD_MEDIUM:
            target_total_blocks = TARGET_BLOCKS_MEDIUM
            min_blocks_per_nc = MIN_BLOCKS_NC_MEDIUM
            max_blocks_per_nc = MAX_BLOCKS_NC_MEDIUM
        elif output_voxels < VOXEL_THRESHOLD_LARGE:
            target_total_blocks = TARGET_BLOCKS_LARGE
            min_blocks_per_nc = MIN_BLOCKS_NC_LARGE
            max_blocks_per_nc = MAX_BLOCKS_NC_LARGE
        elif output_voxels < VOXEL_THRESHOLD_VERY_LARGE:
            target_total_blocks = TARGET_BLOCKS_VERY_LARGE
            min_blocks_per_nc = MIN_BLOCKS_NC_VERY_LARGE
            max_blocks_per_nc = MAX_BLOCKS_NC_VERY_LARGE
        else:
            target_total_blocks = TARGET_BLOCKS_EXTRA_LARGE
            min_blocks_per_nc = MIN_BLOCKS_NC_EXTRA_LARGE
            max_blocks_per_nc = MAX_BLOCKS_NC_EXTRA_LARGE

        # Channel-aware tiling: shrink the targets for high channel counts so
        # the total number of blocks does not explode
        if C > CHANNEL_COUNT_THRESHOLD:
            channel_scale = (CHANNEL_COUNT_THRESHOLD / C) ** CHANNEL_SCALING_EXPONENT
            target_total_blocks = max(
                MIN_TARGET_TOTAL_BLOCKS, int(target_total_blocks * channel_scale)
            )
            min_blocks_per_nc = max(
                MIN_BLOCKS_PER_NC, int(min_blocks_per_nc * channel_scale)
            )
            # max_blocks_per_nc is intentionally left unchanged

        target_blocks_per_nc = max(
            min_blocks_per_nc,
            min(max_blocks_per_nc, target_total_blocks // max(1, nc_pairs)),
        )

        # Start from cubic tiles that would give the target block count
        target_tile_voxels = output_voxels // target_blocks_per_nc
        target_tile_side = int(
            max(
                MIN_TILE_SIDE,
                min(MAX_TILE_SIDE, int(target_tile_voxels ** (1.0 / 3.0))),
            )
        )

        # Snap to one of the predefined tile sizes
        if target_tile_side >= LARGE_TILE_THRESHOLD:
            block_side = (
                LARGE_TILE_THRESHOLD
                if target_tile_side < VERY_LARGE_TILE_THRESHOLD
                else MAX_TILE_SIDE
            )
        elif target_tile_side >= MEDIUM_TILE_THRESHOLD:
            block_side = MEDIUM_TILE_THRESHOLD
        elif target_tile_side >= SMALL_TILE_THRESHOLD:
            block_side = SMALL_TILE_THRESHOLD
        else:
            block_side = MIN_TILE_SIDE

        # For trilinear, use smaller tiles due to higher memory footprint (8x loads)
        if mode == "bilinear":  # actually trilinear in 5D
            block_side = max(MIN_BLOCK_DIMENSION, block_side // 2)

        # NOTE(review): the 3D tiled kernels are assumed to name their tile
        # meta-parameters BLOCK_D / BLOCK_H / BLOCK_W - confirm against their
        # signatures. If they use other names the heuristic above decides the
        # grid, exactly as before.
        grid_size = _tiled_launch_grid(
            nc_pairs,
            (D_out, H_out, W_out),
            (block_side, block_side, block_side),
            ("BLOCK_D", "BLOCK_H", "BLOCK_W"),
        )
    else:
        # Original kernels use 1D grid: one program per output element
        grid_size = (N * C * D_out * H_out * W_out,)

    kernel[grid_size](
        output,
        input,
        grid,
        N,
        C,
        D_in,
        H_in,
        W_in,
        D_out,
        H_out,
        W_out,
        align_corners,
    )

    return output


def grid_sample(
    input: torch.Tensor,
    grid: torch.Tensor,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
    align_corners: bool = False,
) -> torch.Tensor:
    """
    Grid sample operation with spatial interpolation.

    Computes the output using input values and pixel locations from grid.
    Grid specifies sampling pixel locations normalized by input spatial dimensions.

    Args:
        input: Input tensor of shape (N, C, H_in, W_in) or (N, C, D_in, H_in, W_in)
        grid: Grid tensor of shape (N, H_out, W_out, 2) or (N, D_out, H_out, W_out, 3)
            Values should be in range [-1, 1], normalized by input spatial dimensions
        mode: Interpolation mode - 'bilinear', 'nearest', or 'bicubic' (4D only)
        padding_mode: Padding mode for out-of-bound grid locations
            - 'zeros': use 0 for out-of-bound locations
            - 'border': use border values
            - 'reflection': reflect by border
        align_corners: If True, extrema (-1, 1) refer to center points of corner pixels
            If False, extrema refer to corner points of corner pixels

    Returns:
        Output tensor of shape (N, C, H_out, W_out) or (N, C, D_out, H_out, W_out)

    Raises:
        NotImplementedError: If `mode` is not supported for the input rank.

    Examples:
        >>> input = torch.randn(1, 3, 32, 32).cuda()
        >>> grid = torch.randn(1, 64, 64, 2).cuda()
        >>> output = grid_sample(input, grid, mode='bilinear')
        >>> print(output.shape)
        torch.Size([1, 3, 64, 64])
    """
    # Validate inputs
    _validate_grid_sample_input(input, grid, mode, padding_mode)

    # The kernels compute flat offsets for densely packed buffers, so strided
    # views (permuted grids, channels-last inputs, slices) must be materialized.
    input = input.contiguous()
    grid = grid.contiguous()

    if input.dim() == 5:
        return _grid_sample_3d(input, grid, mode, padding_mode, align_corners)
    return _grid_sample_2d(input, grid, mode, padding_mode, align_corners)