Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max_pool2d_with_indices.py: 0%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry, libtuner
8from flag_gems.utils.limits import get_dtype_min
10from ..utils import MAX_GRID_SIZE_X, MAX_GRID_SIZE_Y
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Return the pooled extent along one spatial axis.

    Args:
        in_size: Input extent along the axis.
        kernel_size: Pooling window size along the axis.
        stride: Step between consecutive windows.
        padding: Implicit padding applied on each side of the axis.
        dilation: Spacing between window elements.
        ceil_mode: Round the window count up instead of down.

    Returns:
        The number of pooling windows along the axis.
    """
    # Extent covered by a single dilated window.
    window_span = dilation * (kernel_size - 1) + 1
    # Room left for sliding the window after the first placement.
    slack = in_size + 2 * padding - window_span

    if not ceil_mode:
        return slack // stride + 1

    windows = (slack + stride - 1) // stride + 1
    # PyTorch-compatible adjustment: drop the last window when it would
    # start at or beyond the end of the input plus its leading padding.
    if stride * (windows - 1) >= in_size + padding:
        windows -= 1
    return windows
def limit_grid(grid_0, grid_1):
    """Clamp a requested 2D launch grid to the configured upper bounds.

    Axis 0 is capped at ``MAX_GRID_SIZE_X // 4`` and axis 1 at
    ``MAX_GRID_SIZE_Y``.
    NOTE(review): the reason for the quarter factor on axis 0 is not visible
    in this file - confirm against the backend limits.

    Returns:
        A ``(grid_0, grid_1)`` tuple with each entry no larger than its bound.
    """
    capped_0 = min(grid_0, MAX_GRID_SIZE_X // 4)
    capped_1 = min(grid_1, MAX_GRID_SIZE_Y)
    return capped_0, capped_1
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=1),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=4),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
    strategy=["align32", "align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def max_pool2d_forward_kernel(
    input_ptr,
    output_ptr,
    indices_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Input/Output shapes
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
    # Total number of tasks on axis 0
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Meta-parameters for tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    """Write max-pooled values and flat argmax indices, one output tile at a time.

    Axis 0 of the launch grid walks plane indices in ``[0, task_num_0)``; each
    plane index is split into a batch and channel index using ``in_c``.
    Axis 1 walks ``BLOCK_H x BLOCK_W`` tiles of the output plane. On both axes
    a program advances by the number of launched programs until the task
    range is exhausted, so the grid may be smaller than the task count.

    Output and indices are addressed as ``out_h * out_w`` consecutive elements
    per plane. Each stored index is ``h_in * in_w + w_in`` within the input
    plane; a lane keeps ``-1`` when no in-bounds value exceeds the initial
    dtype minimum.
    """
    # Loop-invariant tile counts.
    tiles_along_w = tl.cdiv(out_w, BLOCK_W)
    tiles_per_plane = tl.cdiv(out_h, BLOCK_H) * tiles_along_w
    plane_step = tl.num_programs(0)
    tile_step = tl.num_programs(1)

    plane = tl.program_id(0)
    while plane < task_num_0:
        tile = tl.program_id(1)
        while tile < tiles_per_plane:
            tile_row = tile // tiles_along_w
            tile_col = tile % tiles_along_w
            batch = plane // in_c
            channel = plane % in_c

            # Output coordinates covered by this tile; reused for the stores.
            rows_out = tile_row * BLOCK_H + tl.arange(0, BLOCK_H)
            cols_out = tile_col * BLOCK_W + tl.arange(0, BLOCK_W)

            elem_ty = input_ptr.type.element_ty
            lowest = get_dtype_min(elem_ty)
            best_val = tl.full((BLOCK_H, BLOCK_W), lowest, dtype=elem_ty)
            best_idx = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)

            plane_ptr = input_ptr + batch * in_stride_n + channel * in_stride_c

            for kh in tl.static_range(0, kernel_h):
                # Input rows depend only on kh, so compute them once per row tap.
                rows_in = rows_out[:, None] * stride_h - padding_h + kh * dilation_h
                for kw in tl.static_range(0, kernel_w):
                    cols_in = (
                        cols_out[None, :] * stride_w - padding_w + kw * dilation_w
                    )
                    inside = (
                        (rows_in >= 0)
                        & (rows_in < in_h)
                        & (cols_in >= 0)
                        & (cols_in < in_w)
                    )
                    tap_offset = rows_in * in_stride_h + cols_in * in_stride_w
                    # Padded taps read as the dtype minimum so they never win.
                    candidate = tl.load(
                        plane_ptr + tap_offset, mask=inside, other=lowest
                    )
                    flat_idx = rows_in * in_w + cols_in

                    improved = candidate > best_val
                    best_val = tl.where(improved, candidate, best_val)
                    best_idx = tl.where(improved & inside, flat_idx, best_idx)

            value_plane_ptr = output_ptr + plane * out_h * out_w
            index_plane_ptr = indices_ptr + plane * out_h * out_w
            value_tile_ptr = (
                value_plane_ptr + rows_out[:, None] * out_w + cols_out[None, :]
            )
            index_tile_ptr = (
                index_plane_ptr + rows_out[:, None] * out_w + cols_out[None, :]
            )

            # Tiles on the bottom/right edge may extend past the output plane.
            keep = (rows_out[:, None] < out_h) & (cols_out[None, :] < out_w)
            tl.store(value_tile_ptr, best_val, mask=keep)
            tl.store(index_tile_ptr, best_idx, mask=keep)
            tile += tile_step
        plane += plane_step
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 32}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 64}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 64}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 16}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 16}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 8}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 8}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=1, num_stages=5),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=1, num_stages=0),
        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=1, num_stages=5),
    ],
    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
    strategy=["align32", "align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Shape info
    in_h,
    in_w,
    out_h,
    out_w,
    # Strides for grad_output/indices
    out_stride_nc,
    out_stride_h,
    out_stride_w,
    # Total number of tasks on axis 0
    task_num_0,
    # Pooling parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling parameters
    BLOCK_IN_H: tl.constexpr,
    BLOCK_IN_W: tl.constexpr,
):
    """Gather the max-pool gradient for each input position, tile by tile.

    Axis 0 of the launch grid walks plane indices in ``[0, task_num_0)`` and
    axis 1 walks ``BLOCK_IN_H x BLOCK_IN_W`` tiles of the input plane; on both
    axes a program advances by the number of launched programs until the task
    range is exhausted.

    For every kernel tap ``(kh, kw)`` the kernel derives the output position
    whose window would touch the input position through that tap. It reads
    the stored argmax index there and adds the matching ``grad_output`` value
    when the index equals this input position's flat index
    (``h * in_w + w``). The float32 sum is written to ``grad_input``, which
    is addressed as ``in_h * in_w`` consecutive elements per plane.
    """
    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
    task_num_1 = tl.cdiv(in_h, BLOCK_IN_H) * num_w_blocks
    grid_0 = tl.num_programs(0)
    grid_1 = tl.num_programs(1)
    nc_idx = tl.program_id(0)
    while nc_idx < task_num_0:
        pid_hw = tl.program_id(1)
        while pid_hw < task_num_1:
            h_block_idx = pid_hw // num_w_blocks
            w_block_idx = pid_hw % num_w_blocks

            h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
            w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)

            # Flat index of each input position, comparable to stored indices.
            current_input_flat_idx = (
                h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            )
            grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)

            indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
            grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc

            for kh in tl.static_range(0, kernel_h):
                for kw in tl.static_range(0, kernel_w):
                    numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
                    numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w

                    # Only positions that land exactly on the stride lattice
                    # correspond to an output window using this tap.
                    valid_map_mask = (numerator_h % stride_h == 0) & (
                        numerator_w % stride_w == 0
                    )
                    h_out = numerator_h // stride_h
                    w_out = numerator_w // stride_w
                    out_bounds_mask = (
                        (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
                    )
                    load_mask = valid_map_mask & out_bounds_mask

                    # Clamp masked-off lanes to offset 0 so addresses stay valid.
                    safe_h_out = tl.where(load_mask, h_out, 0)
                    safe_w_out = tl.where(load_mask, w_out, 0)
                    # Apply the W stride as well; it was previously ignored,
                    # which is only correct when out_stride_w == 1.
                    out_offsets = (
                        safe_h_out * out_stride_h + safe_w_out * out_stride_w
                    )

                    # Masked lanes read -1, which never equals a flat index.
                    indices_block = tl.load(
                        indices_base_ptr + out_offsets, mask=load_mask, other=-1
                    )
                    match_mask = indices_block == current_input_flat_idx

                    grad_block = tl.load(
                        grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
                    )
                    grad_acc += grad_block

            grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
            grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
            # Tiles on the bottom/right edge may extend past the input plane.
            store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
            tl.store(
                grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask
            )
            pid_hw += grid_1
        nc_idx += grid_0
265def _parse_pool_params(kernel_size, stride, padding, dilation):
266 def _parse_param(param, name, default=None):
267 if param is None:
268 return default
269 if isinstance(param, int):
270 return param, param
271 if isinstance(param, (list, tuple)) and len(param) == 2:
272 return param
273 raise ValueError(f"Invalid {name}: {param}")
275 kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size")
276 stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w))
277 padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0))
278 dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1))
280 if stride_h <= 0 or stride_w <= 0:
281 raise ValueError(
282 f"stride must be positive, but got stride=({stride_h}, {stride_w})"
283 )
284 if padding_h < 0 or padding_w < 0:
285 raise ValueError(
286 f"padding must be non-negative, but got padding=({padding_h}, {padding_w})"
287 )
288 if dilation_h <= 0 or dilation_w <= 0:
289 raise ValueError(
290 f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})"
291 )
293 return (
294 kernel_h,
295 kernel_w,
296 stride_h,
297 stride_w,
298 padding_h,
299 padding_w,
300 dilation_h,
301 dilation_w,
302 )
def max_pool2d_with_indices(
    input: torch.Tensor,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Run 2D max pooling and return both pooled values and argmax indices.

    Args:
        input: 4D tensor unpacked as ``(N, C, H, W)``; a contiguous copy is
            made when needed.
        kernel_size: Int or pair giving the pooling window.
        stride: Int or pair; defaults to the kernel size when ``None``.
        padding: Int or pair of implicit padding.
        dilation: Int or pair of window dilation.
        ceil_mode: Round the output size up instead of down.

    Returns:
        ``(output, indices)``, both shaped ``(N, C, out_h, out_w)``;
        ``output`` has the input dtype and ``indices`` is int64. Empty
        outputs are returned without launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES FORWARD")
    input = input.contiguous()

    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = _parse_pool_params(kernel_size, stride, padding, dilation)

    in_n, in_c, in_h, in_w = input.shape
    out_h = max_pool2d_output_size(
        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
    )
    out_w = max_pool2d_output_size(
        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
    )

    out_shape = (in_n, in_c, out_h, out_w)
    output = torch.empty(out_shape, device=input.device, dtype=input.dtype)
    indices = torch.empty(out_shape, device=input.device, dtype=torch.int64)

    # Nothing to compute for an empty result.
    if output.numel() == 0:
        return output, indices

    # One axis-0 task per (batch, channel) plane.
    num_planes = in_n * in_c

    def grid(meta):
        # Axis 1 covers the output tiles; both axes are clamped to the limits.
        tiles_h = triton.cdiv(out_h, meta["BLOCK_H"])
        tiles_w = triton.cdiv(out_w, meta["BLOCK_W"])
        return limit_grid(num_planes, tiles_h * tiles_w)

    max_pool2d_forward_kernel[grid](
        input,
        output,
        indices,
        *input.stride(),
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        num_planes,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        is_linear=True,
    )

    return output, indices
def max_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    indices: torch.Tensor,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Compute the gradient of 2D max pooling with respect to its input.

    Args:
        grad_output: Gradient of the pooled output, shaped
            ``(N, C, out_h, out_w)``; made contiguous before use.
        input: Forward input; only its shape, device and layout metadata are
            used here.
        indices: Argmax indices produced by the forward pass; made
            contiguous before use.
        kernel_size: Int or pair giving the pooling window.
        stride: Int or pair; the kernel size is used when ``None``.
        padding: Int or pair of implicit padding.
        dilation: Int or pair of window dilation.
        ceil_mode: Unused here; the output size is read from ``grad_output``.

    Returns:
        A contiguous tensor with the shape of ``input`` and the dtype of
        ``grad_output``.
    """
    logger.debug("GEMS_CAMBRICON MAX_POOL2D_WITH_INDICES BACKWARD")
    grad_output = grad_output.contiguous()
    indices = indices.contiguous()

    params = _parse_pool_params(kernel_size, stride, padding, dilation)
    (
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    ) = params

    in_n, in_c, in_h, in_w = input.shape
    out_h, out_w = grad_output.shape[2], grad_output.shape[3]

    # The kernel writes grad_input as a dense NCHW buffer (plane * in_h * in_w
    # + h * in_w + w). zeros_like would otherwise preserve a channels_last or
    # permuted layout of `input` and the result would be scrambled, so force
    # the contiguous format. Accumulation is done in float32.
    grad_input = torch.zeros_like(
        input, dtype=torch.float32, memory_format=torch.contiguous_format
    )

    if grad_input.numel() == 0:
        return grad_input.to(grad_output.dtype)

    def grid(meta):
        # Axis 0: one task per (batch, channel) plane; axis 1: input tiles.
        grid_0 = in_n * in_c
        grid_1 = triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(
            in_w, meta["BLOCK_IN_W"]
        )
        return limit_grid(grid_0, grid_1)

    task_num_0 = in_n * in_c

    # Strides of the contiguous grad_output/indices tensors, per plane.
    out_stride_nc = out_h * out_w
    out_stride_h = out_w
    out_stride_w = 1

    max_pool2d_backward_kernel[grid](
        grad_output,
        indices,
        grad_input,
        in_h,
        in_w,
        out_h,
        out_w,
        out_stride_nc,
        out_stride_h,
        out_stride_w,
        task_num_0,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        is_linear=True,
    )

    return grad_input.to(grad_output.dtype)