Coverage for src/flag_gems/ops/avg_pool3d.py: 34%
183 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
9logger = logging.getLogger(__name__)
def pool3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the pooled length of a single spatial axis of a 3D pooling op.

    Args:
        in_size: Length of the input along this axis.
        kernel_size: Number of taps of the pooling window along this axis.
        stride: Step between consecutive windows.
        padding: Implicit padding added on each side of the axis.
        dilation: Spacing between neighbouring window taps.
        ceil_mode: Round the window count up instead of down.

    Returns:
        The number of pooling windows placed along the axis.
    """
    # Extent of one window once the dilation gaps are taken into account.
    window_span = dilation * (kernel_size - 1) + 1
    # Distance the window can travel across the padded axis.
    travel = in_size + 2 * padding - window_span

    if not ceil_mode:
        return travel // stride + 1

    # Ceiling division written with floor division.
    window_count = (travel + stride - 1) // stride + 1
    # Drop the last window when it would start at or beyond the end of the
    # input plus the leading padding.
    last_window_start = (window_count - 1) * stride
    if last_window_start >= in_size + padding:
        window_count -= 1
    return window_count
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_d", "out_h", "out_w", "kernel_d", "kernel_h", "kernel_w"],
)
@triton.jit
def avg_pool3d_forward_kernel(
    input_ptr,
    output_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Forward average pooling: each program produces one (BLOCK_H, BLOCK_W)
    # tile of a single output depth slice for one (n, c) pair.
    #
    # Grid: (N*C, out_d * cdiv(out_h, BLOCK_H) * cdiv(out_w, BLOCK_W))
    #
    # The input is addressed through the strides passed in; the output is
    # addressed as a densely packed (N*C, out_d, out_h, out_w) buffer.
    # divisor_override == 0 means "no override".

    # Promote the (n, c) index to 64-bit so the base offsets below cannot wrap
    # around for tensors holding 2**31 or more elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)

    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    num_h_blocks = tl.cdiv(out_h, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    # Decompose pid_dhw into d_idx, h_block_idx, w_block_idx
    d_idx = pid_dhw // num_hw_blocks
    hw_remainder = pid_dhw % num_hw_blocks
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)

    # 64-bit base offset of the (n, c) volume. Offsets inside one volume stay
    # in the default integer width.
    # NOTE(review): assumes a single (D, H, W) volume spans < 2**31 elements.
    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    for kd in range(0, kernel_d):
        d_in = d_idx * stride_d - padding_d + kd * dilation_d
        d_valid = (d_in >= 0) & (d_in < in_d)
        for kh in range(0, kernel_h):
            # Row index, row mask and row offset do not depend on kw.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            dh_mask = (h_in >= 0) & (h_in < in_h) & d_valid
            dh_offset = d_in * in_stride_d + h_in * in_stride_h
            for kw in range(0, kernel_w):
                w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                in_mask = dh_mask & (w_in >= 0) & (w_in < in_w)

                # Masked-off lanes load 0.0, so they add nothing to the sum.
                current_val = tl.load(
                    input_base_ptr + dh_offset + w_in * in_stride_w,
                    mask=in_mask,
                    other=0.0,
                )

                sum_acc += current_val.to(tl.float32)
                count_acc += in_mask.to(tl.int32)

    if divisor_override != 0:
        divisor = tl.full((BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32)
    elif COUNT_INCLUDE_PAD:
        # Count positions within padded boundary (correct for ceil_mode edges)
        d_start_fwd = d_idx * stride_d - padding_d
        d_padded_count = tl.minimum(d_start_fwd + kernel_d, in_d + padding_d) - (
            tl.maximum(d_start_fwd, -padding_d)
        )
        d_padded_count = tl.maximum(d_padded_count, 0)

        h_start_fwd = h_out_offsets[:, None] * stride_h - padding_h
        h_padded_count = tl.minimum(h_start_fwd + kernel_h, in_h + padding_h) - (
            tl.maximum(h_start_fwd, -padding_h)
        )
        h_padded_count = tl.maximum(h_padded_count, 0)

        w_start_fwd = w_out_offsets[None, :] * stride_w - padding_w
        w_padded_count = tl.minimum(w_start_fwd + kernel_w, in_w + padding_w) - (
            tl.maximum(w_start_fwd, -padding_w)
        )
        w_padded_count = tl.maximum(w_padded_count, 0)

        divisor = (d_padded_count * h_padded_count * w_padded_count).to(tl.float32)
    else:
        # Only the taps that landed on real input elements are counted.
        divisor = count_acc.to(tl.float32)

    # A zero divisor (no counted taps) yields 0 instead of a division by zero.
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)

    # pid_nc is 64-bit, so the whole base offset is evaluated in 64-bit.
    out_base_ptr = output_ptr + (pid_nc * out_d + d_idx) * out_h * out_w
    output_block_ptr = (
        out_base_ptr + h_out_offsets[:, None] * out_w + w_out_offsets[None, :]
    )

    # Tiles on the bottom/right edge may extend past the output plane.
    out_mask = (h_out_offsets[:, None] < out_h) & (w_out_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_d", "kernel_h", "kernel_w"],
)
@triton.jit
def avg_pool3d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Strides for grad_input
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Strides for grad_output
    out_stride_n,
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Input-centric backward: iterate over input positions, gather from output.
    # Uses tl.store (not atomic_add), safe with autotune.
    # Grid: (N*C, in_d * cdiv(in_h, BLOCK_H) * cdiv(in_w, BLOCK_W))
    #
    # For every kernel tap (kd, kh, kw) the output window that would have read
    # this input position through that tap is located; its gradient, divided
    # by that window's divisor, is accumulated. divisor_override == 0 means
    # "no override".

    # Promote the (n, c) index to 64-bit so the base offsets below cannot wrap
    # around for tensors holding 2**31 or more elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)

    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
    num_h_blocks = tl.cdiv(in_h, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    # Decompose pid_dhw into the input depth index and the (h, w) tile indices.
    d_in_idx = pid_dhw // num_hw_blocks
    hw_remainder = pid_dhw % num_hw_blocks
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    # 64-bit base offsets of the (n, c) volumes.
    # NOTE(review): assumes a single (D, H, W) volume spans < 2**31 elements.
    grad_input_base = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c

    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    for kd in range(kernel_d):
        # in = out * stride - padding + k  =>  out = (in + padding - k) / stride
        # The tap contributes only when that quotient is a non-negative exact
        # integer inside the output extent.
        d_out_num = d_in_idx + padding_d - kd
        d_out = d_out_num // stride_d
        d_out_valid = (
            (d_out_num >= 0)
            & ((d_out_num % stride_d) == 0)
            & (d_out >= 0)
            & (d_out < out_d)
        )

        for kh in range(kernel_h):
            # Height quantities do not depend on kw, so compute them once here.
            h_out_num = h_in_offsets[:, None] + padding_h - kh
            h_out = h_out_num // stride_h
            h_valid = (
                (h_out_num >= 0) & ((h_out_num % stride_h) == 0) & (h_out < out_h)
            )

            for kw in range(kernel_w):
                w_out_num = w_in_offsets[None, :] + padding_w - kw
                w_out = w_out_num // stride_w
                w_valid = (
                    (w_out_num >= 0) & ((w_out_num % stride_w) == 0) & (w_out < out_w)
                )

                out_mask = h_valid & w_valid & d_out_valid

                if divisor_override != 0:
                    divisor = tl.full(
                        (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                    )
                elif COUNT_INCLUDE_PAD:
                    # Count positions within padded boundary (ceil_mode)
                    d_start_bwd = d_out * stride_d - padding_d
                    d_pc = tl.minimum(
                        d_start_bwd + kernel_d, in_d + padding_d
                    ) - tl.maximum(d_start_bwd, -padding_d)
                    d_pc = tl.maximum(d_pc, 0)

                    h_start_bwd = h_out * stride_h - padding_h
                    h_pc = tl.minimum(
                        h_start_bwd + kernel_h, in_h + padding_h
                    ) - tl.maximum(h_start_bwd, -padding_h)
                    h_pc = tl.maximum(h_pc, 0)

                    w_start_bwd = w_out * stride_w - padding_w
                    w_pc = tl.minimum(
                        w_start_bwd + kernel_w, in_w + padding_w
                    ) - tl.maximum(w_start_bwd, -padding_w)
                    w_pc = tl.maximum(w_pc, 0)

                    divisor = (d_pc * h_pc * w_pc).to(tl.float32)
                else:
                    # Count only the window positions that overlap real input.
                    d_start = d_out * stride_d - padding_d
                    d_count = tl.minimum(d_start + kernel_d, in_d) - tl.maximum(
                        d_start, 0
                    )
                    d_count = tl.maximum(d_count, 0)

                    h_start = h_out * stride_h - padding_h
                    h_count = tl.minimum(h_start + kernel_h, in_h) - tl.maximum(
                        h_start, 0
                    )
                    h_count = tl.maximum(h_count, 0)

                    w_start = w_out * stride_w - padding_w
                    w_count = tl.minimum(w_start + kernel_w, in_w) - tl.maximum(
                        w_start, 0
                    )
                    w_count = tl.maximum(w_count, 0)

                    divisor = (d_count * h_count * w_count).to(tl.float32)

                # Guard the division; lanes with a zero divisor are replaced by
                # 1.0 so no division by zero is evaluated.
                divisor = tl.where(divisor == 0, 1.0, divisor)

                grad_out_ptr = (
                    grad_output_base
                    + d_out * out_stride_d
                    + h_out * out_stride_h
                    + w_out * out_stride_w
                )
                grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
                grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)

    grad_input_store_ptr = (
        grad_input_base
        + d_in_idx * in_stride_d
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    # Tiles on the bottom/right edge may extend past the input plane.
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )
343def _parse_pool3d_params(kernel_size, stride, padding):
344 """Parse and validate 3D pooling parameters."""
345 if isinstance(kernel_size, int):
346 kernel_d = kernel_h = kernel_w = kernel_size
347 else:
348 kernel_d, kernel_h, kernel_w = kernel_size
350 if stride is None or (isinstance(stride, (list, tuple)) and not stride):
351 stride_d, stride_h, stride_w = kernel_d, kernel_h, kernel_w
352 elif isinstance(stride, int):
353 stride_d = stride_h = stride_w = stride
354 else:
355 stride_d, stride_h, stride_w = stride
357 if isinstance(padding, int):
358 padding_d = padding_h = padding_w = padding
359 else:
360 padding_d, padding_h, padding_w = padding
362 if stride_d <= 0 or stride_h <= 0 or stride_w <= 0:
363 raise ValueError("stride must be greater than zero")
365 if padding_d < 0 or padding_h < 0 or padding_w < 0:
366 raise ValueError("padding must be non-negative")
368 if (
369 padding_d > kernel_d // 2
370 or padding_h > kernel_h // 2
371 or padding_w > kernel_w // 2
372 ):
373 raise ValueError("pad should be smaller than or equal to half of kernel size")
375 return (
376 kernel_d,
377 kernel_h,
378 kernel_w,
379 stride_d,
380 stride_h,
381 stride_w,
382 padding_d,
383 padding_h,
384 padding_w,
385 )
def avg_pool3d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Compute 3D average pooling over an input signal composed of several input
    planes.

    Args:
        input: 5D tensor of shape (N, C, D, H, W), or an unbatched 4D tensor
            of shape (C, D, H, W).
        kernel_size: Size of the pooling window. Can be int or (kD, kH, kW).
        stride: Stride of the pooling window. Default: kernel_size.
        padding: Implicit zero padding on both sides. Default: 0.
        ceil_mode: Use ceil instead of floor to compute output shape. Default: False.
        count_include_pad: Include zero-padding in the averaging calculation.
            Default: True.
        divisor_override: If specified, use this as the divisor instead of the
            pool size. Default: None.

    Returns:
        Tensor of shape (N, C, D_out, H_out, W_out) for a 5D input, or
        (C, D_out, H_out, W_out) for a 4D input.

    Raises:
        ValueError: If ``divisor_override`` is zero, if ``input`` is neither 4D
            nor 5D, or if the pooling parameters are rejected by
            ``_parse_pool3d_params``.
    """
    logger.debug("GEMS AVG_POOL3D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    if input.dim() not in (4, 5):
        raise ValueError(
            f"avg_pool3d expects a 4D or 5D input, got {input.dim()}D"
        )

    # Treat an unbatched (C, D, H, W) input as a batch of one; the batch axis
    # is removed again before returning.
    is_unbatched = input.dim() == 4
    if is_unbatched:
        input = input.unsqueeze(0)
    input = input.contiguous()

    (
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
    ) = _parse_pool3d_params(kernel_size, stride, padding)
    # Average pooling has no dilation argument; the kernel still takes one.
    dilation_d, dilation_h, dilation_w = 1, 1, 1

    in_n, in_c, in_d, in_h, in_w = input.shape

    out_d = pool3d_output_size(
        in_d, kernel_d, stride_d, padding_d, dilation_d, ceil_mode
    )
    out_h = pool3d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool3d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_d, out_h, out_w), device=input.device, dtype=input.dtype
    )

    # Nothing to compute for an empty output, so skip the launch.
    if output.numel() != 0:
        # One program per (n, c) pair on axis 0 and per (depth, h-tile, w-tile)
        # on axis 1.
        # NOTE(review): axis 1 grows with out_d * tile count; confirm it stays
        # within the device's grid-dimension limit for very large outputs.
        def grid(meta):
            return (
                in_n * in_c,
                out_d
                * triton.cdiv(out_h, meta["BLOCK_H"])
                * triton.cdiv(out_w, meta["BLOCK_W"]),
            )

        avg_pool3d_forward_kernel[grid](
            input,
            output,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            input.stride(4),
            in_c,
            in_d,
            in_h,
            in_w,
            out_d,
            out_h,
            out_w,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            padding_d,
            padding_h,
            padding_w,
            dilation_d,
            dilation_h,
            dilation_w,
            COUNT_INCLUDE_PAD=count_include_pad,
            # The kernel treats 0 as "no override".
            divisor_override=divisor_override if divisor_override is not None else 0.0,
        )

    return output.squeeze(0) if is_unbatched else output
def avg_pool3d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Compute the gradient of avg_pool3d.

    Args:
        grad_output: Gradient of the output tensor, 5D (N, C, D_out, H_out,
            W_out) or unbatched 4D (C, D_out, H_out, W_out).
        input: Original input tensor (used for shape information), 5D
            (N, C, D, H, W) or unbatched 4D (C, D, H, W).
        kernel_size: Size of the pooling window.
        stride: Stride of the pooling window.
        padding: Implicit zero padding.
        ceil_mode: Whether ceil was used for output shape. Not read here: the
            output extents are taken from ``grad_output``.
        count_include_pad: Whether padding was included in averaging.
        divisor_override: Custom divisor override.

    Returns:
        Gradient with respect to the input tensor, with the same shape as
        ``input``.

    Raises:
        ValueError: If ``divisor_override`` is zero, if ``input`` is neither 4D
            nor 5D, or if the pooling parameters are rejected by
            ``_parse_pool3d_params``.
    """
    logger.debug("GEMS AVG_POOL3D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    if input.dim() not in (4, 5):
        raise ValueError(
            f"avg_pool3d_backward expects a 4D or 5D input, got {input.dim()}D"
        )

    # Treat unbatched tensors as a batch of one; the batch axis is removed
    # again before returning.
    is_unbatched = input.dim() == 4
    if is_unbatched:
        input = input.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)

    grad_output = grad_output.contiguous()

    (
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
    ) = _parse_pool3d_params(kernel_size, stride, padding)

    in_n, in_c, in_d, in_h, in_w = input.shape
    out_d, out_h, out_w = (
        grad_output.shape[2],
        grad_output.shape[3],
        grad_output.shape[4],
    )

    grad_input = torch.empty_like(input)

    # With no incoming gradient the result is all zeros; with an empty
    # grad_input there is nothing to write and the launch grid would be
    # zero-sized. Either way the kernel is skipped.
    if grad_output.numel() == 0 or grad_input.numel() == 0:
        grad_input.zero_()
        return grad_input.squeeze(0) if is_unbatched else grad_input

    # Input-centric grid: iterate over input positions
    def grid(meta):
        return (
            in_n * in_c,
            in_d
            * triton.cdiv(in_h, meta["BLOCK_H"])
            * triton.cdiv(in_w, meta["BLOCK_W"]),
        )

    avg_pool3d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_d,
        in_h,
        in_w,
        out_d,
        out_h,
        out_w,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_input.stride(4),
        grad_output.stride(0),
        grad_output.stride(1),
        grad_output.stride(2),
        grad_output.stride(3),
        grad_output.stride(4),
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0 as "no override".
        divisor_override=divisor_override if divisor_override is not None else 0.0,
    )

    return grad_input.squeeze(0) if is_unbatched else grad_input