Coverage for src/flag_gems/ops/max_pool3d_with_indices.py: 12%

164 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils.limits import get_dtype_min 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def pool3d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Compute one spatial dimension of the 3-D max-pool output.

    Args:
        in_size: Input extent along this dimension.
        kernel_size: Pooling window extent along this dimension.
        stride: Step between consecutive windows; must be positive.
        padding: Implicit padding applied on both sides of the input.
        dilation: Spacing between elements inside the window.
        ceil_mode: Round the window count up instead of down.

    Returns:
        The number of pooling windows along this dimension. The value is
        not clamped, so it can be zero or negative when the dilated
        window is larger than the padded input.

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    if stride <= 0:
        # Guard here as well so a direct caller gets a clear error rather
        # than a ZeroDivisionError (stride == 0) or a meaningless size.
        raise ValueError(f"stride must be positive, but got stride={stride}")

    # Extent actually spanned by a dilated window.
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size

    if not ceil_mode:
        return numerator // stride + 1

    # Ceiling division written with floor division.
    output_size = (numerator + stride - 1) // stride + 1
    # PyTorch-compatible adjustment for ceil_mode: drop the last window if
    # it would start beyond the input plus its leading padding.
    if (output_size - 1) * stride >= in_size + padding:
        output_size -= 1
    return output_size

32 

33 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=[
        "out_d",
        "out_h",
        "out_w",
        "kernel_d",
        "kernel_h",
        "kernel_w",
        "stride_d",
        "stride_h",
        "stride_w",
    ],
)
@triton.jit
def max_pool3d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_d,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Forward kernel for 3-D max pooling.

    Expected grid: (N * C, out_d * num_h_blocks * num_w_blocks)
    where num_h_blocks = cdiv(out_h, BLOCK_H) and
    num_w_blocks = cdiv(out_w, BLOCK_W).

    Each program handles one (n, c) pair, one output depth index and one
    BLOCK_H x BLOCK_W tile of the output plane; the pooling window is
    walked with fully unrolled (static) loops over kernel_d/h/w.

    The input is addressed through the explicit strides. ``output_ptr``
    and ``indices_ptr`` are addressed as contiguous
    (N * C, out_d, out_h, out_w) buffers. Stored indices are flat offsets
    into the (in_d, in_h, in_w) volume of the corresponding (n, c) slice;
    a tile element that never selects a value keeps index -1.

    NaN inputs are treated as the maximum (the ``val > max || isnan(val)``
    rule used by PyTorch pooling), so a window containing NaN yields NaN.
    """
    # Widen the (n, c) program id so the per-slice base offsets below are
    # computed in 64 bits and do not wrap for very large tensors.
    # Offsets inside a single (n, c) slice remain 32-bit.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)

    num_h_blocks = tl.cdiv(out_h, BLOCK_H)
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    tiles_per_plane = num_h_blocks * num_w_blocks

    # Decode (output depth, tile row, tile column) from the flat id.
    d_out = pid_dhw // tiles_per_plane
    hw_remainder = pid_dhw % tiles_per_plane
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    n_idx = pid_nc // in_c
    c_idx = pid_nc % in_c

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    dtype = input_ptr.type.element_ty
    # Sentinel used both as the initial maximum and as the fill value for
    # masked (padding / out-of-range) loads; expected to compare no
    # greater than any real value of ``dtype``.
    min_val = get_dtype_min(dtype)
    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c

    for kd in tl.static_range(0, kernel_d):
        d_in = d_out * stride_d - padding_d + kd * dilation_d
        d_valid = (d_in >= 0) & (d_in < in_d)
        for kh in tl.static_range(0, kernel_h):
            # Row coordinates depend only on kh, so compute them once per row.
            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
            dh_valid = d_valid & (h_in >= 0) & (h_in < in_h)
            for kw in tl.static_range(0, kernel_w):
                w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                in_mask = dh_valid & (w_in >= 0) & (w_in < in_w)
                input_offset = (
                    d_in * in_stride_d + h_in * in_stride_h + w_in * in_stride_w
                )
                current_val = tl.load(
                    input_base_ptr + input_offset, mask=in_mask, other=min_val
                )
                # Flat index in (D, H, W) space
                current_idx = d_in * in_h * in_w + h_in * in_w + w_in

                # ``x != x`` is true only for NaN. Plain ``>`` is always
                # false for NaN and would silently drop it.
                is_new_max = (current_val > max_val_acc) | (
                    current_val != current_val
                )
                max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
                max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)

    out_spatial = out_h * out_w
    out_base_offset = pid_nc * out_d * out_spatial + d_out * out_spatial
    tile_offsets = h_out_offsets[:, None] * out_w + w_out_offsets[None, :]
    output_block_ptr = output_ptr + out_base_offset + tile_offsets
    indices_block_ptr = indices_ptr + out_base_offset + tile_offsets

    # Tiles on the bottom/right edge may extend past the output plane.
    out_mask = (h_out_offsets[:, None] < out_h) & (w_out_offsets[None, :] < out_w)
    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)

171 

172 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
    ],
    key=[
        "in_d",
        "in_h",
        "in_w",
        "kernel_d",
        "kernel_h",
        "kernel_w",
        "stride_d",
        "stride_h",
        "stride_w",
    ],
)
@triton.jit
def max_pool3d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_d,
    in_h,
    in_w,
    out_d,
    out_h,
    out_w,
    # Strides for grad_output/indices (contiguous layout: NC, D, H, W)
    out_stride_nc,
    out_stride_d,
    out_stride_h,
    out_stride_w,
    # Pooling parameters
    kernel_d: tl.constexpr,
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_d: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_d: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_d: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    """Backward kernel for 3-D max pooling.

    Expected grid: (N * C, in_d * num_h_blocks * num_w_blocks)
    where num_h_blocks = cdiv(in_h, BLOCK_IN_H) and
    num_w_blocks = cdiv(in_w, BLOCK_IN_W).

    Each program owns one (n, c) pair, one input depth index and one
    BLOCK_IN_H x BLOCK_IN_W tile of the input plane. For every kernel
    offset it derives the single output position whose window would
    cover the input element at that offset, and adds that output's
    gradient when the stored index equals the element's flat
    (in_d, in_h, in_w) index. Because every input element is written by
    exactly one program, no atomics are used.

    ``grad_output_ptr`` and ``indices_ptr`` share the ``out_stride_*``
    strides. ``grad_input_ptr`` is addressed as a contiguous
    (N * C, in_d, in_h, in_w) buffer. Gradients are accumulated in
    float32 before being stored.
    """
    # Widen the (n, c) program id so per-slice base offsets are 64-bit.
    nc_idx = tl.program_id(0).to(tl.int64)
    pid_dhw = tl.program_id(1)

    num_h_blocks = tl.cdiv(in_h, BLOCK_IN_H)
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    tiles_per_plane = num_h_blocks * num_w_blocks

    d_in_idx = pid_dhw // tiles_per_plane
    hw_remainder = pid_dhw % tiles_per_plane
    h_block_idx = hw_remainder // num_w_blocks
    w_block_idx = hw_remainder % num_w_blocks

    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

    # Flat index of current input position in (D, H, W) space
    current_input_flat_idx = (
        d_in_idx * in_h * in_w + h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    )
    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

    for kd in tl.static_range(0, kernel_d):
        # Invert d_in = d_out * stride - padding + kd * dilation for d_out.
        numerator_d = d_in_idx + padding_d - kd * dilation_d
        d_out = numerator_d // stride_d
        d_valid = (numerator_d % stride_d == 0) & (d_out >= 0) & (d_out < out_d)

        for kh in tl.static_range(0, kernel_h):
            # Row mapping depends only on kh, so compute it once per row.
            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
            h_out = numerator_h // stride_h
            dh_valid = (
                d_valid
                & (numerator_h % stride_h == 0)
                & (h_out >= 0)
                & (h_out < out_h)
            )

            for kw in tl.static_range(0, kernel_w):
                numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w
                w_out = numerator_w // stride_w
                load_mask = (
                    dh_valid
                    & (numerator_w % stride_w == 0)
                    & (w_out >= 0)
                    & (w_out < out_w)
                )

                # Clamp masked-out coordinates to 0 so the computed
                # addresses stay inside the buffer.
                safe_d_out = tl.where(load_mask, d_out, 0)
                safe_h_out = tl.where(load_mask, h_out, 0)
                safe_w_out = tl.where(load_mask, w_out, 0)
                out_offsets = (
                    safe_d_out * out_stride_d
                    + safe_h_out * out_stride_h
                    + safe_w_out * out_stride_w
                )

                # -1 never equals a real flat index, so masked lanes
                # cannot produce a match.
                indices_block = tl.load(
                    indices_base_ptr + out_offsets, mask=load_mask, other=-1
                )
                match_mask = indices_block == current_input_flat_idx

                grad_block = tl.load(
                    grad_output_base_ptr + out_offsets,
                    mask=match_mask,
                    other=0.0,
                )
                grad_acc += grad_block

    in_spatial = in_h * in_w
    grad_input_base_ptr = grad_input_ptr + nc_idx * in_d * in_spatial
    grad_input_offsets = (
        d_in_idx * in_spatial + h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
    )
    # Tiles on the bottom/right edge may extend past the input plane.
    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)

309 

310 

311def _parse_pool3d_params(kernel_size, stride, padding, dilation): 

312 """Parse and validate 3-D pooling parameters. 

313 

314 Each parameter can be an int (applied to all 3 spatial dims) or a 

315 3-element tuple/list (D, H, W). 

316 """ 

317 

318 def _parse_param(param, name, default=None): 

319 if param is None: 

320 return default 

321 if isinstance(param, int): 

322 return param, param, param 

323 if isinstance(param, (list, tuple)) and len(param) == 3: 

324 return tuple(param) 

325 raise ValueError(f"Invalid {name}: {param}") 

326 

327 kd, kh, kw = _parse_param(kernel_size, "kernel_size") 

328 sd, sh, sw = _parse_param(stride, "stride", default=(kd, kh, kw)) 

329 pd, ph, pw = _parse_param(padding, "padding", default=(0, 0, 0)) 

330 dd, dh, dw = _parse_param(dilation, "dilation", default=(1, 1, 1)) 

331 

332 if sd <= 0 or sh <= 0 or sw <= 0: 

333 raise ValueError(f"stride must be positive, but got stride=({sd}, {sh}, {sw})") 

334 if pd < 0 or ph < 0 or pw < 0: 

335 raise ValueError( 

336 f"padding must be non-negative, but got padding=({pd}, {ph}, {pw})" 

337 ) 

338 if dd <= 0 or dh <= 0 or dw <= 0: 

339 raise ValueError( 

340 f"dilation must be positive, but got dilation=({dd}, {dh}, {dw})" 

341 ) 

342 

343 return kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw 

344 

345 

def max_pool3d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Compute 3-D max pooling, returning (output, indices).

    Args:
        input: Tensor of shape (N, C, D, H, W), or (C, D, H, W) for a
            single unbatched sample. It is made contiguous before the
            kernel launch.
        kernel_size: Window size; int or (D, H, W) sequence.
        stride: Window step; defaults to ``kernel_size`` when omitted.
        padding: Implicit padding on both sides of each spatial dim.
        dilation: Spacing between window elements.
        ceil_mode: Use ceiling instead of floor for the output size.

    Returns:
        ``(output, indices)`` with the same number of dimensions as
        ``input``. ``output`` has the input dtype; ``indices`` is int64.
        Indices are flat offsets into the (D, H, W) spatial volume of the
        input.

    Raises:
        ValueError: If ``input`` is not 4-D or 5-D, or if a pooling
            parameter is invalid.
    """
    logger.debug("GEMS MAX_POOL3D_WITH_INDICES")

    if input.dim() not in (4, 5):
        raise ValueError(
            "max_pool3d_with_indices expects a 4-D (C, D, H, W) or 5-D "
            f"(N, C, D, H, W) input, but got a {input.dim()}-D tensor"
        )
    # Treat an unbatched sample as a batch of one and undo it on return.
    is_batched = input.dim() == 5
    if not is_batched:
        input = input.unsqueeze(0)
    input = input.contiguous()

    params = _parse_pool3d_params(kernel_size, stride, padding, dilation)
    kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw = params

    in_n, in_c, in_d, in_h, in_w = input.shape
    out_d = pool3d_output_size(in_d, kd, sd, pd, dd, ceil_mode)
    out_h = pool3d_output_size(in_h, kh, sh, ph, dh, ceil_mode)
    out_w = pool3d_output_size(in_w, kw, sw, pw, dw, ceil_mode)

    out_shape = (in_n, in_c, out_d, out_h, out_w)
    output = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(out_shape, device=input.device, dtype=torch.int64)

    # Nothing to launch for an empty result.
    if output.numel() != 0:
        # One program per (n, c) on axis 0; one per (output depth, tile)
        # on axis 1.
        # NOTE(review): axis 1 grows with out_d * tiles-per-plane; some
        # backends cap that grid axis (CUDA: 65535) — confirm very large
        # volumes are covered.
        grid = lambda meta: (
            in_n * in_c,
            out_d
            * triton.cdiv(out_h, meta["BLOCK_H"])
            * triton.cdiv(out_w, meta["BLOCK_W"]),
        )

        max_pool3d_forward_kernel[grid](
            input,
            output,
            indices,
            input.stride(0),
            input.stride(1),
            input.stride(2),
            input.stride(3),
            input.stride(4),
            in_c,
            in_d,
            in_h,
            in_w,
            out_d,
            out_h,
            out_w,
            kd,
            kh,
            kw,
            sd,
            sh,
            sw,
            pd,
            ph,
            pw,
            dd,
            dh,
            dw,
        )

    if not is_batched:
        return output.squeeze(0), indices.squeeze(0)
    return output, indices

417 

418 

def max_pool3d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Backward pass for 3-D max pooling.

    Args:
        grad_output: Gradient w.r.t. the pooled output, shaped like the
            forward output.
        input: The forward input; only its shape, device and rank are
            used here. Shape (N, C, D, H, W) or (C, D, H, W).
        indices: int64 flat (D, H, W) indices produced by the forward
            pass, shaped like ``grad_output``.
        kernel_size, stride, padding, dilation: Pooling parameters, in
            the same forms accepted by the forward pass.
        ceil_mode: Unused here; the output extents are read from
            ``grad_output``.

    Returns:
        Gradient w.r.t. ``input``: contiguous, shaped like ``input`` and
        with ``grad_output``'s dtype. Accumulation is done in float32.

    Raises:
        ValueError: If ``input`` is not 4-D or 5-D, or if a pooling
            parameter is invalid.
    """
    logger.debug("GEMS MAX_POOL3D BACKWARD")

    if input.dim() not in (4, 5):
        raise ValueError(
            "max_pool3d_backward expects a 4-D (C, D, H, W) or 5-D "
            f"(N, C, D, H, W) input, but got a {input.dim()}-D tensor"
        )
    # Treat an unbatched sample as a batch of one and undo it on return.
    is_batched = input.dim() == 5
    if not is_batched:
        input = input.unsqueeze(0)
        grad_output = grad_output.unsqueeze(0)
        indices = indices.unsqueeze(0)

    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    params = _parse_pool3d_params(kernel_size, stride, padding, dilation)
    kd, kh, kw, sd, sh, sw, pd, ph, pw, dd, dh, dw = params

    in_n, in_c, in_d, in_h, in_w = input.shape
    out_d, out_h, out_w = (
        grad_output.shape[2],
        grad_output.shape[3],
        grad_output.shape[4],
    )

    # The kernel writes grad_input as a contiguous (N, C, D, H, W) buffer.
    # zeros_like would otherwise copy the strides of a dense
    # non-contiguous input (e.g. channels_last_3d) and the result would be
    # interpreted with the wrong layout, so force the contiguous format.
    grad_input = torch.zeros_like(
        input, dtype=torch.float32, memory_format=torch.contiguous_format
    )

    if grad_input.numel() != 0:
        # Strides of the contiguous grad_output / indices buffers with N
        # and C folded into one leading axis.
        out_spatial = out_h * out_w
        out_stride_nc = out_d * out_spatial
        out_stride_d = out_spatial
        out_stride_h = out_w
        out_stride_w = 1

        # One program per (n, c) on axis 0; one per (input depth, tile)
        # on axis 1.
        grid = lambda meta: (
            in_n * in_c,
            in_d
            * triton.cdiv(in_h, meta["BLOCK_IN_H"])
            * triton.cdiv(in_w, meta["BLOCK_IN_W"]),
        )

        max_pool3d_backward_kernel[grid](
            grad_output,
            indices,
            grad_input,
            in_d,
            in_h,
            in_w,
            out_d,
            out_h,
            out_w,
            out_stride_nc,
            out_stride_d,
            out_stride_h,
            out_stride_w,
            kd,
            kh,
            kw,
            sd,
            sh,
            sw,
            pd,
            ph,
            pw,
            dd,
            dh,
            dw,
        )

    grad_input = grad_input.to(grad_output.dtype)
    if not is_batched:
        grad_input = grad_input.squeeze(0)
    return grad_input