Coverage for src/flag_gems/ops/avg_pool3d.py: 34%

183 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

# Module-level logger; the pooling entry points in this file emit debug messages through it.
logger = logging.getLogger(__name__)

10 

11 

def pool3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the pooled length of a single spatial axis.

    Args:
        in_size: Length of the axis before pooling.
        kernel_size: Number of taps in the pooling window along this axis.
        stride: Step between consecutive window starts.
        padding: Implicit padding applied to each side of the axis.
        dilation: Spacing between window taps.
        ceil_mode: Round the window count up instead of down.

    Returns:
        The number of pooling windows placed along the axis.
    """
    # Extent covered by one dilated window.
    window_span = dilation * (kernel_size - 1) + 1
    # Room left for sliding after the first window is placed on the padded axis.
    slack = in_size + 2 * padding - window_span

    if not ceil_mode:
        return slack // stride + 1

    # Ceil division of the slack by the stride, plus the first window.
    n_out = (slack + stride - 1) // stride + 1
    # A trailing window that would start at or beyond `in_size + padding`
    # is dropped.
    last_start = (n_out - 1) * stride
    if last_start >= in_size + padding:
        n_out -= 1
    return n_out

31 

32 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_d", "out_h", "out_w", "kernel_d", "kernel_h", "kernel_w"],
)
@triton.jit
def avg_pool3d_forward_kernel(
    input_ptr,
    output_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Forward average pooling. Each program produces one BLOCK_H x BLOCK_W tile
    # of the output for a single (n, c) plane and a single output depth index.
    # The input is addressed through the stride arguments; the output is
    # addressed as a dense (N*C, out_d, out_h, out_w) array (see out_base_ptr).
    # A divisor_override of 0 means "no override" and selects the other divisor
    # branches below.
    # NOTE(review): offsets are built from program ids and scalar arguments;
    # whether that arithmetic is 32- or 64-bit is not visible here -- confirm
    # that very large tensors cannot overflow it.
    # Grid: (N*C, out_d * cdiv(out_h, BLOCK_H) * cdiv(out_w, BLOCK_W))
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    num_h_blocks = tl.cdiv(out_h, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    # Decompose pid_dhw into d_idx, h_block_idx, w_block_idx
    d_idx = pid_dhw // num_hw_blocks
    hw_remainder = pid_dhw % num_hw_blocks
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    # Split the fused plane index into batch and channel indices.
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    # Output row/column indices covered by this tile (may run past the edge;
    # out_mask at the end discards those lanes on store).
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # The running sum is kept in float32 regardless of the input element type;
    # count_acc tallies how many in-bounds input elements each lane has read.
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    for kd in range(0, kernel_d):
        # Input depth touched by this tap; a scalar shared by the whole tile.
        d_in = d_idx * stride_d - padding_d + kd * dilation_d
        d_valid = (d_in >= 0) & (d_in < in_d)
        for kh in range(0, kernel_h):
            for kw in range(0, kernel_w):
                # Input coordinates for this tap, broadcast to (BLOCK_H, 1) and
                # (1, BLOCK_W) respectively.
                h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                hw_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                # Lanes whose tap lands in the padding region are masked out.
                in_mask = hw_mask & d_valid

                input_offset = (
                    d_in * in_stride_d + h_in * in_stride_h + w_in * in_stride_w
                )
                current_val = tl.load(
                    input_base_ptr + input_offset, mask=in_mask, other=0.0
                )

                # Masked lanes already load 0.0; the where() keeps them out of
                # the sum explicitly.
                sum_acc += tl.where(in_mask, current_val, 0.0)
                count_acc += in_mask.to(tl.int32)

    if divisor_override != 0:
        # Explicit divisor: the same value for every output element.
        divisor = tl.full((BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32)
    elif COUNT_INCLUDE_PAD:
        # Count positions within padded boundary (correct for ceil_mode edges)
        # Per axis: window [start, start + kernel) clipped to [-padding, in + padding).
        # NOTE(review): these counts use kernel_* directly and ignore dilation_*;
        # the launcher in this file always passes dilation 1.
        d_start_fwd = d_idx * stride_d - padding_d
        d_padded_count = tl.minimum(d_start_fwd + kernel_d, in_d + padding_d) - (
            tl.maximum(d_start_fwd, -padding_d)
        )
        d_padded_count = tl.maximum(d_padded_count, 0)

        h_start_fwd = h_out_offsets[:, None] * stride_h - padding_h
        h_padded_count = tl.minimum(h_start_fwd + kernel_h, in_h + padding_h) - (
            tl.maximum(h_start_fwd, -padding_h)
        )
        h_padded_count = tl.maximum(h_padded_count, 0)

        w_start_fwd = w_out_offsets[None, :] * stride_w - padding_w
        w_padded_count = tl.minimum(w_start_fwd + kernel_w, in_w + padding_w) - (
            tl.maximum(w_start_fwd, -padding_w)
        )
        w_padded_count = tl.maximum(w_padded_count, 0)

        divisor = (d_padded_count * h_padded_count * w_padded_count).to(tl.float32)
    else:
        # Padding excluded: divide by the number of in-bounds elements read.
        divisor = count_acc.to(tl.float32)

    # Lanes with a zero divisor produce 0 instead of dividing by zero.
    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)

    # Output addressing assumes a dense layout: plane pid_nc, then depth d_idx,
    # then row-major (h, w).
    out_base_ptr = output_ptr + pid_nc * out_d * out_h * out_w + d_idx * out_h * out_w
    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
    output_block_ptr = (
        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
    )

    # Only lanes inside the output extent are written; the float32 result is
    # cast to the output pointer's element type.
    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
    tl.store(
        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
    )

170 

171 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["in_h", "in_w", "kernel_d", "kernel_h", "kernel_w"],
)
@triton.jit
def avg_pool3d_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Strides for grad_input
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Strides for grad_output
    out_stride_n,
    out_stride_c,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override,
    # Tiling meta-parameters
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Backward average pooling. Each program produces one BLOCK_H x BLOCK_W tile
    # of grad_input for a single (n, c) plane and a single input depth index.
    # Both tensors are addressed through their stride arguments. A
    # divisor_override of 0 means "no override".
    # Input-centric backward: iterate over input positions, gather from output.
    # Uses tl.store (not atomic_add), safe with autotune.
    # Grid: (N*C, in_d * cdiv(in_h, BLOCK_H) * cdiv(in_w, BLOCK_W))
    pid_nc = tl.program_id(0)
    pid_dhw = tl.program_id(1)

    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
    num_h_blocks = tl.cdiv(in_h, BLOCK_H)
    num_hw_blocks = num_h_blocks * num_w_blocks

    # Decompose pid_dhw into the input depth index and the (h, w) tile indices.
    d_in_idx = pid_dhw // num_hw_blocks
    hw_remainder = pid_dhw % num_hw_blocks
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    # Split the fused plane index into batch and channel indices.
    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    grad_input_base = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
    grad_output_base = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c

    # Input row/column indices covered by this tile (may run past the edge;
    # in_write_mask at the end discards those lanes on store).
    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Gradient is accumulated in float32 and cast on store.
    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    for kd in range(kernel_d):
        # If this input depth is tap kd of some output window, that window's
        # index is (d_in + padding - kd) / stride; it only exists when the
        # numerator is non-negative, divisible by the stride, and in range.
        d_out_num = d_in_idx + padding_d - kd
        d_out_valid = (d_out_num >= 0) & ((d_out_num % stride_d) == 0)
        d_out = d_out_num // stride_d
        d_out_valid = d_out_valid & (d_out >= 0) & (d_out < out_d)

        for kh in range(kernel_h):
            for kw in range(kernel_w):
                # Same inversion for the height and width axes, per lane.
                h_out_num = h_in_offsets[:, None] + padding_h - kh
                w_out_num = w_in_offsets[None, :] + padding_w - kw

                h_valid = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
                w_valid = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)

                h_out = h_out_num // stride_h
                w_out = w_out_num // stride_w

                # Lanes for which a contributing output element actually exists.
                out_mask = (
                    d_out_valid & h_valid & w_valid & (h_out < out_h) & (w_out < out_w)
                )

                # Rebuild the divisor that applies to the output element at
                # (d_out, h_out, w_out), using the same three cases as the
                # forward pass.
                if divisor_override != 0:
                    divisor = tl.full(
                        (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
                    )
                elif COUNT_INCLUDE_PAD:
                    # Count positions within padded boundary (ceil_mode)
                    # Per axis: window clipped to [-padding, in + padding).
                    d_start_bwd = d_out * stride_d - padding_d
                    d_pc = tl.minimum(
                        d_start_bwd + kernel_d, in_d + padding_d
                    ) - tl.maximum(d_start_bwd, -padding_d)
                    d_pc = tl.maximum(d_pc, 0)

                    h_start_bwd = h_out * stride_h - padding_h
                    h_pc = tl.minimum(
                        h_start_bwd + kernel_h, in_h + padding_h
                    ) - tl.maximum(h_start_bwd, -padding_h)
                    h_pc = tl.maximum(h_pc, 0)

                    w_start_bwd = w_out * stride_w - padding_w
                    w_pc = tl.minimum(
                        w_start_bwd + kernel_w, in_w + padding_w
                    ) - tl.maximum(w_start_bwd, -padding_w)
                    w_pc = tl.maximum(w_pc, 0)

                    divisor = (d_pc * h_pc * w_pc).to(tl.float32)
                else:
                    # Padding excluded: per axis, window clipped to [0, in).
                    d_start = d_out * stride_d - padding_d
                    d_count = tl.minimum(d_start + kernel_d, in_d) - tl.maximum(
                        d_start, 0
                    )
                    d_count = tl.maximum(d_count, 0)

                    h_start = h_out * stride_h - padding_h
                    h_count = tl.minimum(h_start + kernel_h, in_h) - tl.maximum(
                        h_start, 0
                    )
                    h_count = tl.maximum(h_count, 0)

                    w_start = w_out * stride_w - padding_w
                    w_count = tl.minimum(w_start + kernel_w, in_w) - tl.maximum(
                        w_start, 0
                    )
                    w_count = tl.maximum(w_count, 0)

                    divisor = (d_count * h_count * w_count).to(tl.float32)

                # Replace zero divisors with 1 so the division below is always
                # defined.
                divisor = tl.where(divisor == 0, 1.0, divisor)

                grad_out_ptr = (
                    grad_output_base
                    + d_out * out_stride_d
                    + h_out * out_stride_h
                    + w_out * out_stride_w
                )
                # Masked lanes load 0.0 and are also excluded by the where().
                grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
                grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)

    grad_input_store_ptr = (
        grad_input_base
        + d_in_idx * in_stride_d
        + h_in_offsets[:, None] * in_stride_h
        + w_in_offsets[None, :] * in_stride_w
    )
    # Every in-range lane is stored, including lanes that received no
    # contribution (they write 0).
    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(
        grad_input_store_ptr,
        grad_acc.to(grad_input_ptr.type.element_ty),
        mask=in_write_mask,
    )

341 

342 

343def _parse_pool3d_params(kernel_size, stride, padding): 

344 """Parse and validate 3D pooling parameters.""" 

345 if isinstance(kernel_size, int): 

346 kernel_d = kernel_h = kernel_w = kernel_size 

347 else: 

348 kernel_d, kernel_h, kernel_w = kernel_size 

349 

350 if stride is None or (isinstance(stride, (list, tuple)) and not stride): 

351 stride_d, stride_h, stride_w = kernel_d, kernel_h, kernel_w 

352 elif isinstance(stride, int): 

353 stride_d = stride_h = stride_w = stride 

354 else: 

355 stride_d, stride_h, stride_w = stride 

356 

357 if isinstance(padding, int): 

358 padding_d = padding_h = padding_w = padding 

359 else: 

360 padding_d, padding_h, padding_w = padding 

361 

362 if stride_d <= 0 or stride_h <= 0 or stride_w <= 0: 

363 raise ValueError("stride must be greater than zero") 

364 

365 if padding_d < 0 or padding_h < 0 or padding_w < 0: 

366 raise ValueError("padding must be non-negative") 

367 

368 if ( 

369 padding_d > kernel_d // 2 

370 or padding_h > kernel_h // 2 

371 or padding_w > kernel_w // 2 

372 ): 

373 raise ValueError("pad should be smaller than or equal to half of kernel size") 

374 

375 return ( 

376 kernel_d, 

377 kernel_h, 

378 kernel_w, 

379 stride_d, 

380 stride_h, 

381 stride_w, 

382 padding_d, 

383 padding_h, 

384 padding_w, 

385 ) 

386 

387 

def avg_pool3d(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Compute 3D average pooling over an input signal composed of several input
    planes.

    Args:
        input: 5D tensor of shape (N, C, D, H, W), or an unbatched 4D tensor of
            shape (C, D, H, W).
        kernel_size: Size of the pooling window. Can be int or (kD, kH, kW).
        stride: Stride of the pooling window. Default: kernel_size.
        padding: Implicit zero padding on both sides. Default: 0.
        ceil_mode: Use ceil instead of floor to compute output shape. Default: False.
        count_include_pad: Include zero-padding in the averaging calculation.
            Default: True.
        divisor_override: If specified, use this as the divisor instead of the
            pool size. Default: None.

    Returns:
        Tensor of shape (N, C, D_out, H_out, W_out), or (C, D_out, H_out, W_out)
        when the input was unbatched.

    Raises:
        ValueError: If ``divisor_override`` is zero, the input is not 4D/5D, or
            the pooling parameters are invalid.
        RuntimeError: If any computed output dimension is smaller than one.
    """
    logger.debug("GEMS AVG_POOL3D FORWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    if input.dim() not in (4, 5):
        raise ValueError(
            f"avg_pool3d expects a 4D or 5D input, got a {input.dim()}D tensor"
        )

    # Treat an unbatched (C, D, H, W) input as a batch of one; the extra axis
    # is removed again before returning.
    is_batched = input.dim() == 5
    if not is_batched:
        input = input.unsqueeze(0)

    input = input.contiguous()

    (
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
    ) = _parse_pool3d_params(kernel_size, stride, padding)
    # Average pooling has no dilation parameter; the kernel still takes one.
    dilation_d, dilation_h, dilation_w = 1, 1, 1

    in_n, in_c, in_d, in_h, in_w = input.shape

    out_d = pool3d_output_size(
        in_d, kernel_d, stride_d, padding_d, dilation_d, ceil_mode
    )
    out_h = pool3d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = pool3d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    # A window larger than the padded input yields a non-positive size; fail
    # loudly instead of silently returning an empty tensor.
    if out_d < 1 or out_h < 1 or out_w < 1:
        raise RuntimeError(
            f"avg_pool3d(): given input size ({in_c}x{in_d}x{in_h}x{in_w}), "
            f"calculated output size ({in_c}x{out_d}x{out_h}x{out_w}) is too small"
        )

    output = torch.empty(
        (in_n, in_c, out_d, out_h, out_w), device=input.device, dtype=input.dtype
    )

    # Nothing to launch for an empty batch/channel dimension.
    if output.numel() == 0:
        return output if is_batched else output.squeeze(0)

    # NOTE(review): axis 1 of the grid scales with out_d * number of tiles.
    # CUDA caps grid axes 1 and 2 at 65535 blocks, so very large outputs may
    # exceed the launch limit -- confirm.
    grid = lambda meta: (
        in_n * in_c,
        out_d
        * triton.cdiv(out_h, meta["BLOCK_H"])
        * triton.cdiv(out_w, meta["BLOCK_W"]),
    )

    avg_pool3d_forward_kernel[grid](
        input,
        output,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        input.stride(4),
        in_c,
        in_d,
        in_h,
        in_w,
        out_d,
        out_h,
        out_w,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
        dilation_d,
        dilation_h,
        dilation_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0 as "no override".
        divisor_override=divisor_override if divisor_override is not None else 0.0,
    )

    return output if is_batched else output.squeeze(0)

492 

493 

def avg_pool3d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Compute the gradient of avg_pool3d.

    Args:
        grad_output: Gradient of the output tensor; same rank as ``input``.
        input: Original input tensor (used for shape information). Either 5D
            (N, C, D, H, W) or unbatched 4D (C, D, H, W).
        kernel_size: Size of the pooling window.
        stride: Stride of the pooling window.
        padding: Implicit zero padding.
        ceil_mode: Whether ceil was used for output shape. Not used here; the
            output extent is read from ``grad_output``.
        count_include_pad: Whether padding was included in averaging.
        divisor_override: Custom divisor override.

    Returns:
        Gradient with respect to the input tensor, with the same shape as
        ``input``.

    Raises:
        ValueError: If ``divisor_override`` is zero, ``input`` is not 4D/5D,
            ``grad_output`` has a different rank than ``input``, or the
            pooling parameters are invalid.
    """
    logger.debug("GEMS AVG_POOL3D BACKWARD")

    if divisor_override is not None and divisor_override == 0:
        raise ValueError("divisor_override cannot be zero")

    if input.dim() not in (4, 5):
        raise ValueError(
            f"avg_pool3d_backward expects a 4D or 5D input, got a {input.dim()}D tensor"
        )
    if grad_output.dim() != input.dim():
        raise ValueError(
            "grad_output and input must have the same number of dimensions, got "
            f"{grad_output.dim()} and {input.dim()}"
        )

    # Treat unbatched (C, D, H, W) tensors as a batch of one; the extra axis is
    # removed again before returning.
    is_batched = input.dim() == 5
    if not is_batched:
        input = input.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)

    grad_output = grad_output.contiguous()

    (
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
    ) = _parse_pool3d_params(kernel_size, stride, padding)

    in_n, in_c, in_d, in_h, in_w = input.shape
    out_d, out_h, out_w = (
        grad_output.shape[2],
        grad_output.shape[3],
        grad_output.shape[4],
    )

    grad_input = torch.empty_like(input)

    # No upstream gradient means every input gradient is zero; skip the launch.
    if grad_output.numel() == 0:
        grad_input.zero_()
        return grad_input if is_batched else grad_input.squeeze(0)

    # Input-centric grid: iterate over input positions
    # NOTE(review): axis 1 of the grid scales with in_d * number of tiles.
    # CUDA caps grid axes 1 and 2 at 65535 blocks, so very large inputs may
    # exceed the launch limit -- confirm.
    grid = lambda meta: (
        in_n * in_c,
        in_d * triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(in_w, meta["BLOCK_W"]),
    )

    avg_pool3d_backward_kernel[grid](
        grad_output,
        grad_input,
        in_c,
        in_d,
        in_h,
        in_w,
        out_d,
        out_h,
        out_w,
        grad_input.stride(0),
        grad_input.stride(1),
        grad_input.stride(2),
        grad_input.stride(3),
        grad_input.stride(4),
        grad_output.stride(0),
        grad_output.stride(1),
        grad_output.stride(2),
        grad_output.stride(3),
        grad_output.stride(4),
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        padding_d,
        padding_h,
        padding_w,
        COUNT_INCLUDE_PAD=count_include_pad,
        # The kernel treats 0 as "no override".
        divisor_override=divisor_override if divisor_override is not None else 0.0,
    )

    return grad_input if is_batched else grad_input.squeeze(0)