Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max_pool2d_with_indices.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry, libtuner 

8from flag_gems.utils.limits import get_dtype_min 

9 

10from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y 

11 

# Module logger: a child of the "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the pooled length along one spatial dimension.

    The input is padded by ``padding`` on both sides and scanned by a window
    whose dilated extent is ``(kernel_size - 1) * dilation + 1``.

    Args:
        in_size: Length of the input along this dimension.
        kernel_size: Window size (before dilation).
        stride: Step between consecutive windows.
        padding: Implicit padding added to each side.
        dilation: Spacing between window taps.
        ceil_mode: Round the window count up instead of down. The last window
            is then dropped if it would start inside the right padding
            (PyTorch-compatible rule).

    Returns:
        Number of output positions along this dimension.
    """
    span = in_size + 2 * padding - ((kernel_size - 1) * dilation + 1)
    if not ceil_mode:
        return span // stride + 1

    count = (span + stride - 1) // stride + 1
    # A window that starts at or beyond the end of input + left padding would
    # only cover right padding, so PyTorch discards it.
    last_start = (count - 1) * stride
    return count - 1 if last_start >= in_size + padding else count

34 

35 

def limit_grid(grid_0, grid_1):
    """Clamp a 2-D launch grid to the backend's grid-size limits.

    Axis 0 is capped at a quarter of ``MAX_GRID_SIZE_X`` and axis 1 at
    ``MAX_GRID_SIZE_Y``. The kernels launched with this grid loop over any
    tasks beyond the capped size.

    Returns:
        ``(capped_grid_0, capped_grid_1)``.
    """
    axis0_cap = MAX_GRID_SIZE_X // 4
    axis1_cap = MAX_GRID_SIZE_Y
    capped_0 = min(grid_0, axis0_cap)
    capped_1 = min(grid_1, axis1_cap)
    return capped_0, capped_1

40 

41 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=4),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
    strategy=["align32", "align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Forward 2-D max pooling.
    #
    # Axis 0 enumerates the (n, c) planes (pid_nc = n * in_c + c). Axis 1
    # enumerates BLOCK_H x BLOCK_W tiles of the output plane. Each program
    # walks its axis by steps of the launched grid size, so any grid (capped
    # on the host) still covers all task_num_0 x task_num_1 tasks.
    #
    # For each output element this writes the window maximum and, to
    # indices_ptr, its flat position h * in_w + w inside the input plane.
    # Selection follows PyTorch: the first in-bounds tap seeds the result, a
    # strictly larger value replaces it, and a NaN always wins, so NaN
    # propagates.
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    task_num_1 = tl.cdiv(out_h, BLOCK_H) * num_w_blocks
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    dtype = input_ptr.type.element_ty
    min_val = get_dtype_min(dtype)

    pid_nc = tl.program_id(0)
    while pid_nc < task_num_0:
        n_idx = pid_nc // in_c
        c_idx = pid_nc % in_c
        input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
        # Output and indices planes are addressed as contiguous out_h x out_w.
        out_base_ptr = output_ptr + pid_nc * out_h * out_w
        indices_base_ptr = indices_ptr + pid_nc * out_h * out_w

        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks
            h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
            w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

            max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
            # -1 means "no in-bounds element has been seen yet".
            max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    h_in = (
                        h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
                    )
                    w_in = (
                        w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
                    input_offset = h_in * in_stride_h + w_in * in_stride_w
                    current_val = tl.load(
                        input_base_ptr + input_offset, mask=in_mask, other=min_val
                    )
                    current_idx = h_in * in_w + w_in

                    # Take the tap when it is in bounds and either seeds the
                    # window, beats the running max, or is NaN (x != x).
                    is_new_max = in_mask & (
                        (max_idx_acc == -1)
                        | (current_val > max_val_acc)
                        | (current_val != current_val)
                    )
                    max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
                    max_idx_acc = tl.where(is_new_max, current_idx, max_idx_acc)

            out_offsets = h_out_offsets[:, None] * out_w + w_out_offsets[None, :]
            out_mask = (h_out_offsets[:, None] < out_h) & (
                w_out_offsets[None, :] < out_w
            )
            tl.store(out_base_ptr + out_offsets, max_val_acc, mask=out_mask)
            tl.store(indices_base_ptr + out_offsets, max_idx_acc, mask=out_mask)
            pid_hw += grid_1
        pid_nc += grid_0

152 

153 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 32}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 64}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 64}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 16}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 16}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 8}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 8}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=1, num_stages=5),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
    strategy=["align32", "align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices
    out_stride_nc,
    out_stride_h,
    # NOTE(review): out_stride_w is not used below; columns are addressed
    # with unit stride (the host wrapper passes 1).
    out_stride_w,
    # Total number of tasks on axis 0
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    # Gather-style max-pool backward.
    #
    # Axis 0 enumerates (n, c) planes. Axis 1 enumerates BLOCK_IN_H x
    # BLOCK_IN_W tiles of the grad_input plane. Each program walks its axis in
    # steps of the launched grid size. For every input element in its tile,
    # a task sums grad_output over all output positions whose stored argmax
    # index equals that element's flat index h * in_w + w. Every grad_input
    # element is written by exactly one task with a plain store, so no atomic
    # accumulation is needed.
    task_num_1 = tl.cdiv(in_h, BLOCK_IN_H) * tl.cdiv(in_w, BLOCK_IN_W)
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    nc_idx = tl.program_id(0)
    while nc_idx < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks

            h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
            w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

            # Flat in-plane index of each tile element, in the same encoding
            # that the forward kernel stores into `indices`.
            current_input_flat_idx = (
                h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            )
            # Accumulate in float32 regardless of grad_output's dtype.
            grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

            indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
            grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    # Tap (kh, kw) of output window h_out reads input row
                    # h_out * stride_h - padding_h + kh * dilation_h. Invert
                    # that to find the candidate output position per element.
                    numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
                    numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

                    # Only exact multiples of the stride map back to an output
                    # position. Divisibility is tested first, so for the lanes
                    # that survive, truncating and flooring division agree.
                    valid_map_mask = (numerator_h % stride_h == 0) & (
                        numerator_w % stride_w == 0
                    )
                    h_out = numerator_h // stride_h
                    w_out = numerator_w // stride_w
                    out_bounds_mask = (
                        (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
                    )
                    load_mask = valid_map_mask & out_bounds_mask

                    # Zero out masked lanes so their (unused) addresses stay
                    # within the plane.
                    safe_h_out = tl.where(load_mask, h_out, 0)
                    safe_w_out = tl.where(load_mask, w_out, 0)
                    out_offsets = safe_h_out * out_stride_h + safe_w_out

                    indices_block = tl.load(
                        indices_base_ptr + out_offsets, mask=load_mask, other=-1
                    )
                    # Lanes that failed load_mask read -1 and never match.
                    # Tile lanes outside in_h/in_w can alias another element's
                    # flat index; store_mask below discards those lanes.
                    match_mask = indices_block == current_input_flat_idx

                    grad_block = tl.load(
                        grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
                    )
                    grad_acc += grad_block

            # grad_input planes are addressed as contiguous in_h x in_w.
            grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
            grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
            tl.store(
                grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask
            )
            pid_hw += grid_1
        nc_idx += grid_0

263 

264 

265def _parse_pool_params(kernel_size, stride, padding, dilation): 

266 def _parse_param(param, name, default=None): 

267 if param is None: 

268 return default 

269 if isinstance(param, int): 

270 return param, param 

271 if isinstance(param, (list, tuple)) and len(param) == 2: 

272 return param 

273 raise ValueError(f"Invalid {name}: {param}") 

274 

275 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size") 

276 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w)) 

277 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0)) 

278 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1)) 

279 

280 if stride_h <= 0 or stride_w <= 0: 

281 raise ValueError( 

282 f"stride must be positive, but got stride=({stride_h}, {stride_w})" 

283 ) 

284 if padding_h < 0 or padding_w < 0: 

285 raise ValueError( 

286 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})" 

287 ) 

288 if dilation_h <= 0 or dilation_w <= 0: 

289 raise ValueError( 

290 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})" 

291 ) 

292 

293 return ( 

294 kernel_h, 

295 kernel_w, 

296 stride_h, 

297 stride_w, 

298 padding_h, 

299 padding_w, 

300 dilation_h, 

301 dilation_w, 

302 ) 

303 

304 

def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """2-D max pooling that also returns the positions of the maxima.

    Args:
        input: ``(N, C, H, W)`` tensor, or an unbatched ``(C, H, W)`` tensor.
        kernel_size, stride, padding, dilation: pooling hyper-parameters,
            normalized by ``_parse_pool_params``.
        ceil_mode: use ceil instead of floor when computing the output size.

    Returns:
        ``(output, indices)``, both shaped like ``input`` with ``H, W``
        replaced by ``out_h, out_w``. ``output`` has the input dtype.
        ``indices`` is int64 and holds each maximum's flat offset
        ``h * W + w`` within its own ``H x W`` plane.

    Raises:
        ValueError: if ``input`` is not 3-D or 4-D, or the pooling
            parameters are invalid.
    """
    if input.dim() == 3:
        # Unbatched input: run as a batch of one and drop the batch dim again.
        output, indices = max_pool2d_with_indices(
            input.unsqueeze(0), kernel_size, stride, padding, dilation, ceil_mode
        )
        return output.squeeze(0), indices.squeeze(0)
    if input.dim() != 4:
        raise ValueError(
            f"max_pool2d_with_indices: expected 3D or 4D input, but got {input.dim()}D"
        )

    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES FORWARD")
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    output = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
    )
    indices = torch.empty(
        (in_n, in_c, out_h, out_w), device=input.device, dtype=torch.int64
    )

    if output.numel() == 0:
        return output, indices

    task_num_0 = in_n * in_c

    def grid(meta):
        # One axis-0 task per (n, c) plane and one axis-1 task per output
        # tile; limit_grid caps the launch and the kernel loops over the rest.
        grid_1 = triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(
            out_w, meta["BLOCK_W"]
        )
        return limit_grid(task_num_0, grid_1)

    max_pool2d_forward_kernel[grid](
        input,
        output,
        indices,
        input.stride(0),
        input.stride(1),
        input.stride(2),
        input.stride(3),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        task_num_0,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        is_linear=True,
    )

    return output, indices

380 

381 

def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Gradient of ``max_pool2d_with_indices`` with respect to its input.

    Args:
        grad_output: gradient of the pooled output, ``(N, C, out_h, out_w)``
            or unbatched ``(C, out_h, out_w)``.
        input: the forward input; only its shape, device and rank are used.
        indices: the int64 indices returned by the forward pass.
        kernel_size, stride, padding, dilation: the forward's pooling
            parameters, normalized by ``_parse_pool_params``.
        ceil_mode: accepted for signature compatibility; the output size is
            taken from ``grad_output`` instead.

    Returns:
        ``grad_input`` with the shape of ``input`` and the dtype of
        ``grad_output``, laid out contiguously.

    Raises:
        ValueError: if ``input`` is not 3-D or 4-D, or the pooling
            parameters are invalid.
    """
    if input.dim() == 3:
        # Unbatched input: run as a batch of one and drop the batch dim again.
        grad_input = max_pool2d_backward(
            grad_output.unsqueeze(0),
            input.unsqueeze(0),
            indices.unsqueeze(0),
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
        return grad_input.squeeze(0)
    if input.dim() != 4:
        raise ValueError(
            f"max_pool2d_backward: expected 3D or 4D input, but got {input.dim()}D"
        )

    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES BACKWARD")
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # The kernel writes NCHW-contiguous offsets. zeros_like defaults to
    # preserve_format, which would inherit e.g. channels_last from `input`
    # and scramble the result, so force a contiguous float32 buffer.
    grad_input = torch.zeros_like(
        input, dtype=torch.float32, memory_format=torch.contiguous_format
    )

    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    task_num_0 = in_n * in_c

    def grid(meta):
        # One axis-0 task per (n, c) plane and one axis-1 task per input
        # tile; limit_grid caps the launch and the kernel loops over the rest.
        grid_1 = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return limit_grid(task_num_0, grid_1)

    # grad_output and indices were made contiguous above.
    out_stride_nc = out_h * out_w
    out_stride_h = out_w
    out_stride_w = 1

    max_pool2d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_h,
        in_w,
        out_h,
        out_w,
        out_stride_nc,
        out_stride_h,
        out_stride_w,
        task_num_0,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        is_linear=True,
    )

    return grad_input.to(grad_output.dtype)