Coverage for src/flag_gems/ops/max_pool3d_with_indices.py: 12%
164 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils.limits import get_dtype_min
10logger = logging.getLogger(__name__)
def pool3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the output extent of a max-pool along a single spatial axis.

    Args:
        in_size: Input extent along the axis.
        kernel_size: Number of taps of the pooling window along the axis.
        stride: Step between consecutive windows.
        padding: Implicit padding applied on both sides of the axis.
        dilation: Spacing between neighbouring taps of the window.
        ceil_mode: Round the window count up instead of down.

    Returns:
        The number of pooling windows along the axis.
    """
    # Extent covered by one dilated window.
    window_span = dilation * (kernel_size - 1) + 1
    # Room left for sliding the window once it sits at the first position.
    slack = in_size + 2 * padding - window_span

    if not ceil_mode:
        return slack // stride + 1

    count = (slack + stride - 1) // stride + 1
    # Drop the last window when it would start at or beyond the end of the
    # input plus leading padding (same rule as the original implementation).
    last_window_start = (count - 1) * stride
    if last_window_start >= in_size + padding:
        count -= 1
    return count
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=[
        "out_d",
        "out_h",
        "out_w",
        "kernel_d",
        "kernel_h",
        "kernel_w",
        "stride_d",
        "stride_h",
        "stride_w",
    ],
)
@triton.jit
def max_pool3d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Forward kernel for 3-D max pooling.

    Grid: (N * C, out_d * num_h_blocks * num_w_blocks)
        where num_h_blocks = cdiv(out_h, BLOCK_H),
        num_w_blocks = cdiv(out_w, BLOCK_W).

    Each program handles one (n, c) pair, exactly one output depth index
    and one BLOCK_H x BLOCK_W tile of the output (H, W) plane. The depth
    index is decoded from program_id(1); there is no loop over output depth
    inside the kernel. The kernel window (kd, kh, kw) is fully unrolled via
    tl.static_range.

    The input is addressed through the explicit in_stride_* arguments, while
    output and indices are addressed as if laid out contiguously as
    (N * C, out_d, out_h, out_w). Each stored index is the flat offset
    d * in_h * in_w + h * in_w + w of the selected input element inside its
    (D, H, W) volume.
    """
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)
    num_h_blocks = tl.cdiv(out_h, BLOCK_H)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    # Decode program_id(1) as (depth, h-tile, w-tile), w-tile varying fastest.
    d_block_idx = pid_dhw // (num_h_blocks * num_w_blocks)
    hw_remainder = pid_dhw % (num_h_blocks * num_w_blocks)
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks
    # Decode program_id(0) as (batch, channel).
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c
    # One output depth position per program.
    d_out = d_block_idx
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    dtype = input_ptr.type.element_ty
    # Identity element for the running max; also used as the fill value for
    # masked-out (padding / out-of-range) loads below.
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    # -1 marks "no element selected yet".
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    for kd in tl.static_range(0, kernel_d):
        # Input depth sampled by tap kd of the window anchored at d_out.
        d_in = d_out * stride_d - padding_d + kd * dilation_d
        d_valid = (d_in >= 0) & (d_in < in_d)
        for kh in tl.static_range(0, kernel_h):
            for kw in tl.static_range(0, kernel_w):
                h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                # A tap is readable only if it lies inside the input volume.
                in_mask = (
                    d_valid & (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                )
                input_offset = (
                    d_in * in_stride_d + h_in * in_stride_h + w_in * in_stride_w
                )
                current_val = tl.load(
                    input_base_ptr + input_offset, mask=in_mask, other=min_val
                )
                # Flat index in (D, H, W) space
                current_idx = d_in * in_h * in_w + h_in * in_w + w_in
                # Strict '>' keeps the first maximum met in (kd, kh, kw) scan
                # order when several taps tie.
                # NOTE(review): a NaN input never satisfies a strict '>'
                # comparison, so NaNs do not appear to propagate to the
                # output here -- confirm this is the intended semantics.
                # NOTE(review): if no tap is strictly greater than min_val
                # (e.g. a window holding only min_val), the stored index
                # stays -1 -- confirm against callers.
                is_new_max = current_val > max_val_acc
                max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
                # '& in_mask' guarantees a masked-out tap never writes an index.
                max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)
    # Output / indices are addressed as contiguous (N * C, out_d, out_h, out_w).
    out_spatial = out_h * out_w
    out_base_offset = pid_nc * out_d * out_spatial + d_out * out_spatial
    out_base_ptr = output_ptr + out_base_offset
    indices_base_ptr = indices_ptr + out_base_offset
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    indices_block_ptr = (
        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )
    # Tiles on the bottom / right edge may extend past the output plane.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=[
        "in_d",
        "in_h",
        "in_w",
        "kernel_d",
        "kernel_h",
        "kernel_w",
        "stride_d",
        "stride_h",
        "stride_w",
    ],
)
@triton.jit
def max_pool3d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Strides for grad_output/indices (contiguous layout: NC, D, H, W)
    out_stride_nc,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    """Backward kernel for 3-D max pooling.

    Grid: (N * C, in_d * num_h_blocks * num_w_blocks)
        where num_h_blocks = cdiv(in_h, BLOCK_IN_H),
        num_w_blocks = cdiv(in_w, BLOCK_IN_W).

    Each program owns one (n, c) pair, one input depth index and one
    BLOCK_IN_H x BLOCK_IN_W tile of the input (H, W) plane. For every kernel
    offset (kd, kh, kw) it inverts the forward mapping
    in = out * stride - padding + k * dilation to find the single output
    position that could have read the input element through that tap, then
    adds that output's gradient if the saved index equals the element's flat
    (D, H, W) index. Because each program only writes its own tile, the
    result is produced with plain stores.

    Gradients are accumulated in float32. grad_input is addressed as a
    contiguous (N * C, in_d, in_h, in_w) buffer; out_stride_w is accepted but
    not used -- the W offset is added with an implicit stride of 1.
    """
    nc_idx = tl.program_id(0)
    pid_dhw = tl.program_id(1)
    num_h_blocks = tl.cdiv(in_h, BLOCK_IN_H)
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    # Decode program_id(1) as (input depth, h-tile, w-tile), w-tile fastest.
    d_in_idx = pid_dhw // (num_h_blocks * num_w_blocks)
    hw_remainder = pid_dhw % (num_h_blocks * num_w_blocks)
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks
    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)
    # Flat index of current input position in (D, H, W) space
    current_input_flat_idx = (
        d_in_idx * in_h * in_w + h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    )
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)
    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc
    for kd in tl.static_range(0, kernel_d):
        # Solve d_in = d_out * stride_d - padding_d + kd * dilation_d for d_out;
        # a solution exists only when the numerator is a multiple of stride_d.
        numerator_d = d_in_idx + padding_d - kd * dilation_d
        valid_d = numerator_d % stride_d == 0
        d_out = numerator_d // stride_d
        d_bounds = (d_out >= 0) & (d_out < out_d)
        d_valid = valid_d & d_bounds
        for kh in tl.static_range(0, kernel_h):
            for kw in tl.static_range(0, kernel_w):
                # Same inversion along H and W, evaluated for the whole tile.
                numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
                numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w
                valid_map_mask = (
                    d_valid
                    & (numerator_h % stride_h == 0)
                    & (numerator_w % stride_w == 0)
                )
                h_out = numerator_h // stride_h
                w_out = numerator_w // stride_w
                out_bounds_mask = (
                    (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
                )
                load_mask = valid_map_mask & out_bounds_mask
                # Replace coordinates of masked-out lanes by 0 so the offsets
                # computed below stay inside the output buffer.
                safe_h_out = tl.where(load_mask, h_out, 0)
                safe_w_out = tl.where(load_mask, w_out, 0)
                safe_d_out = tl.where(load_mask, d_out, 0)
                # W is added with an implicit unit stride (out_stride_w unused).
                out_offsets = (
                    safe_d_out * out_stride_d + safe_h_out * out_stride_h + safe_w_out
                )
                # Masked-out lanes read -1, which never equals a (non-negative)
                # flat input index.
                indices_block = tl.load(
                    indices_base_ptr + out_offsets, mask=load_mask, other=-1
                )
                match_mask = indices_block == current_input_flat_idx
                # Only outputs that actually selected this input contribute.
                grad_block = tl.load(
                    grad_output_base_ptr + out_offsets,
                    mask=match_mask,
                    other=0.0,
                )
                grad_acc += grad_block
    # grad_input is addressed as contiguous (N * C, in_d, in_h, in_w).
    in_spatial = in_h * in_w
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_d * in_spatial
    grad_input_offsets = (
        d_in_idx * in_spatial + h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    )
    # Discard lanes of edge tiles that fall outside the input plane.
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)
def _parse_pool3d_params(kernel_size, stride, padding, dilation):
    """Parse and validate 3-D pooling parameters.

    Each parameter can be an int (applied to all 3 spatial dims), a
    1-element tuple/list (also applied to all 3 dims) or a 3-element
    tuple/list ordered (D, H, W).

    Args:
        kernel_size: Pooling window size; required.
        stride: Window step. ``None`` or an empty tuple/list means "same as
            kernel_size" (an empty list is what the ATen operator schema
            passes when the caller did not specify a stride).
        padding: Implicit padding; ``None`` means 0.
        dilation: Spacing between window taps; ``None`` means 1.

    Returns:
        A 12-tuple ``(kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw)``.

    Raises:
        ValueError: If a parameter has an unsupported form, if kernel_size
            is missing, or if a value is out of range (kernel_size, stride
            and dilation must be positive, padding non-negative).
    """

    def _parse_param(param, name, default=None):
        if param is None:
            if default is None:
                # A required parameter is missing: fail with a clear message
                # instead of an opaque tuple-unpacking error in the caller.
                raise ValueError(f"Invalid {name}: {param}")
            return default
        if isinstance(param, int):
            return param, param, param
        if isinstance(param, (list, tuple)):
            if len(param) == 1:
                # A single value is broadcast to all three dimensions.
                return param[0], param[0], param[0]
            if len(param) == 3:
                return tuple(param)
        raise ValueError(f"Invalid {name}: {param}")

    # An empty stride sequence is the operator-level spelling of "default".
    if isinstance(stride, (list, tuple)) and len(stride) == 0:
        stride = None

    kd, kh, kw = _parse_param(kernel_size, "kernel_size")
    sd, sh, sw = _parse_param(stride, "stride", default=(kd, kh, kw))
    pd, ph, pw = _parse_param(padding, "padding", default=(0, 0, 0))
    dd, dh, dw = _parse_param(dilation, "dilation", default=(1, 1, 1))

    if kd <= 0 or kh <= 0 or kw <= 0:
        raise ValueError(
            f"kernel_size must be positive, but got kernel_size=({kd}, {kh}, {kw})"
        )
    if sd <= 0 or sh <= 0 or sw <= 0:
        raise ValueError(f"stride must be positive, but got stride=({sd}, {sh}, {sw})")
    if pd < 0 or ph < 0 or pw < 0:
        raise ValueError(
            f"padding must be non-negative, but got padding=({pd}, {ph}, {pw})"
        )
    if dd <= 0 or dh <= 0 or dw <= 0:
        raise ValueError(
            f"dilation must be positive, but got dilation=({dd}, {dh}, {dw})"
        )

    return kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw
def max_pool3d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Run 3-D max pooling and return ``(output, indices)``.

    The input is made contiguous and unpacked as a 5-D
    ``(N, C, D, H, W)`` tensor. ``indices`` is an int64 tensor with the same
    shape as ``output``; each entry is the flat offset of the selected
    element inside the (D, H, W) volume of the input.

    Args:
        input: 5-D tensor to pool.
        kernel_size: Pooling window, int or 3-element sequence.
        stride: Window step; ``None`` falls back to ``kernel_size``.
        padding: Implicit padding, int or 3-element sequence.
        dilation: Tap spacing, int or 3-element sequence.
        ceil_mode: Round the output size up instead of down.

    Returns:
        Tuple ``(output, indices)``, both shaped (N, C, out_d, out_h, out_w).
    """
    logger.debug("GEMS MAX_POOL3D_WITH_INDICES")
    input = input.contiguous()

    (kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw) = _parse_pool3d_params(
        kernel_size, stride, padding, dilation
    )

    in_n, in_c, in_d, in_h, in_w = input.shape
    # One output extent per spatial axis, evaluated in D, H, W order.
    per_axis = (
        (in_d, kd, sd, pd, dd),
        (in_h, kh, sh, ph, dh),
        (in_w, kw, sw, pw, dw),
    )
    out_d, out_h, out_w = (
        pool3d_output_size(size, k, s, p, d, ceil_mode) for size, k, s, p, d in per_axis
    )

    result_shape = (in_n, in_c, out_d, out_h, out_w)
    output = torch.empty(result_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(result_shape, device=input.device, dtype=torch.int64)

    # Nothing to compute for an empty result; skip the kernel launch.
    if output.numel() == 0:
        return output, indices

    def grid(meta):
        # Axis 0: one program per (n, c); axis 1: depth x H-tiles x W-tiles.
        h_tiles = triton.cdiv(out_h, meta["BLOCK_H"])
        w_tiles = triton.cdiv(out_w, meta["BLOCK_W"])
        return (in_n * in_c, out_d * h_tiles * w_tiles)

    stride_n, stride_c, stride_dd, stride_hh, stride_ww = (
        input.stride(axis) for axis in range(5)
    )

    max_pool3d_forward_kernel[grid](
        input,
        output,
        indices,
        stride_n,
        stride_c,
        stride_dd,
        stride_hh,
        stride_ww,
        in_c,
        in_d,
        in_h,
        in_w,
        out_d,
        out_h,
        out_w,
        kd,
        kh,
        kw,
        sd,
        sh,
        sw,
        pd,
        ph,
        pw,
        dd,
        dh,
        dw,
    )

    return output, indices
def max_pool3d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Backward pass for 3-D max pooling.

    Args:
        grad_output: Gradient w.r.t. the pooling output,
            shaped (N, C, out_d, out_h, out_w).
        input: The forward input; only its shape, device and dtype-independent
            metadata are used here. It is unpacked as 5-D (N, C, D, H, W).
        indices: int64 indices saved by the forward pass (flat offsets into
            the (D, H, W) volume of the input).
        kernel_size: Pooling window, as given to the forward pass.
        stride: Window step, as given to the forward pass.
        padding: Implicit padding, as given to the forward pass.
        dilation: Tap spacing, as given to the forward pass.
        ceil_mode: Accepted for signature compatibility; not needed because
            the output extents are read from ``grad_output``.

    Returns:
        Gradient w.r.t. ``input``: a contiguous tensor with ``input``'s shape
        and ``grad_output``'s dtype.
    """
    logger.debug("GEMS MAX_POOL3D BACKWARD")
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    params = _parse_pool3d_params(kernel_size, stride, padding, dilation)
    kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw = params

    in_n, in_c, in_d, in_h, in_w = input.shape
    out_d, out_h, out_w = (
        grad_output.shape[2],
        grad_output.shape[3],
        grad_output.shape[4],
    )

    # The kernel writes grad_input with contiguous (N * C, D, H, W) offsets.
    # zeros_like defaults to preserving the strides of a dense, non-contiguous
    # `input` (e.g. channels_last_3d), which would scatter the gradients to
    # the wrong elements, so the contiguous layout is requested explicitly.
    # Accumulation happens in float32; the cast to the final dtype is below.
    grad_input = torch.zeros_like(
        input, dtype=torch.float32, memory_format=torch.contiguous_format
    )

    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    # Element strides of the contiguous grad_output / indices buffers, with
    # N and C folded into a single leading axis.
    out_spatial = out_h * out_w
    out_stride_nc = out_d * out_spatial
    out_stride_d = out_spatial
    out_stride_h = out_w
    out_stride_w = 1

    # Axis 0: one program per (n, c); axis 1: input depth x H-tiles x W-tiles.
    grid = lambda meta: (
        in_n * in_c,
        in_d
        * triton.cdiv(in_h, meta["BLOCK_IN_H"])
        * triton.cdiv(in_w, meta["BLOCK_IN_W"]),
    )

    max_pool3d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_d,
        in_h,
        in_w,
        out_d,
        out_h,
        out_w,
        out_stride_nc,
        out_stride_d,
        out_stride_h,
        out_stride_w,
        kd,
        kh,
        kw,
        sd,
        sh,
        sw,
        pd,
        ph,
        pw,
        dd,
        dh,
        dw,
    )

    return grad_input.to(grad_output.dtype)