Coverage for src/flag_gems/runtime/backend/_sunrise/ops/conv2d.py: 0%
179 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8import flag_gems
9from flag_gems import runtime
10from flag_gems.utils import libentry
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module. The parent is fetched first so it exists as a real Logger.
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))
def conv2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """
    Compute the length of one spatial axis of a 2D convolution's output.

    Args:
        in_size: Length of the input along this axis.
        kernel_size: Kernel length along this axis.
        stride: Step between consecutive kernel applications.
        padding: Zero padding added on each side of the input.
        dilation: Spacing between kernel taps.

    Returns:
        Number of output positions along this axis.
    """
    # Input length after symmetric zero padding.
    padded_extent = in_size + 2 * padding
    # Span covered by one dilated kernel window.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (padded_extent - effective_kernel) // stride + 1
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_forward"),
    key=[
        "in_n",
        "weight_c",
        "input_height",
        "input_width",
        "out_c",
        "out_height",
        "out_width",
        "weight_height",
        "weight_width",
        "stride_height",
        "stride_width",
        "padding_height",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv2d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    in_n,
    input_height,
    input_width,
    out_c,
    out_height,
    out_width,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    weight_c: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_NI_HO_WO: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Grouped 2D convolution plus bias, computed tile by tile.

    Each program instance produces a [BLOCK_NI_HO_WO, BLOCK_CO] output tile.
    - Rows are flattened (n, ho, wo) output positions.
    - Columns are output channels of group ``tl.program_id(2)``.

    The reduction loops over every kernel tap (h, w) and over the ``weight_c``
    input channels in BLOCK_CI-sized chunks. Partial ``tl.dot`` products are
    accumulated in float32; the im2col matrix is never materialized.

    Sizes:
    - ``weight_c`` is the per-group input channel count.
    - ``out_c`` is the total output channel count; each group owns
      ``out_c // groups`` of them.

    ``bias_pointer`` must address ``out_c`` values. Results are stored through
    ``output_pointer`` using the given output strides.
    """
    pid_ni_ho_wo = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Decompose the flattened row index into (n, ho, wo) output coordinates.
    ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO)
    ni_ho_offset = ni_ho_wo_offset // out_width
    in_n_point_value = ni_ho_offset // out_height
    output_height_point_value = ni_ho_offset % out_height
    output_width_point_value = ni_ho_wo_offset % out_width

    # Move the input and weight pointers to this group's slice. Conceptually:
    # - input is viewed as [in_n, groups, in_c, input_height, input_width];
    # - weight is viewed as [groups, out_c, in_c, weight_height, weight_width].
    out_per_group_c = out_c // groups
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    input_pointer += (
        input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c
    )[:, None]
    weight_pointer += (
        weight_n_stride * output_c_offset
        + weight_n_stride * pid_group * out_per_group_c
    )[None, :]

    accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32)
    BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI
    # A single flat loop over (kernel tap, input-channel chunk) pairs.
    for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT):
        c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI
        hw = hwc // BLOCK_CI_COUNT
        h = hw // weight_width
        w = hw % weight_width

        input_c_offset = c + tl.arange(0, BLOCK_CI)
        # Input coordinates read by kernel tap (h, w) for each output position.
        # These may fall outside the input; the mask below handles that.
        input_height_offset = (
            h * dilation_height
            - padding_height
            + stride_height * output_height_point_value
        )
        input_width_offset = (
            w * dilation_width - padding_width + stride_width * output_width_point_value
        )

        curr_input_pointer = (
            input_pointer
            + (input_c_stride * input_c_offset)[None, :]
            + (input_height_stride * input_height_offset)[:, None]
            + (input_width_stride * input_width_offset)[:, None]
        )
        curr_weight_pointer = (
            weight_pointer
            + (weight_c_stride * input_c_offset)[:, None]
            + (weight_height_stride * h)
            + (weight_width_stride * w)
        )

        input_mask = (
            (in_n_point_value < in_n)[:, None]
            & (input_c_offset < weight_c)[None, :]
            & (0 <= input_height_offset)[:, None]
            & (input_height_offset < input_height)[:, None]
            & (0 <= input_width_offset)[:, None]
            & (input_width_offset < input_width)[:, None]
        )
        weight_mask = (input_c_offset < weight_c)[:, None] & (
            output_c_offset < out_per_group_c
        )[None, :]

        # other=0.0 makes masked lanes (zero padding, tail rows/channels)
        # contribute exactly zero to the dot product instead of unspecified
        # values.
        input_block = tl.load(curr_input_pointer, mask=input_mask, other=0.0)
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask, other=0.0)

        accum += tl.dot(input_block, weight_block, allow_tf32=False)
    bias_pointer += (pid_group[None] * out_per_group_c)[None, :] + output_c_offset[
        None, :
    ]
    mask_bias = (output_c_offset < out_per_group_c)[None, :]
    bias = tl.load(bias_pointer, mask_bias).to(tl.float32)
    accum += bias
    output_pointer += (
        (output_n_stride * in_n_point_value)[:, None]
        + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :]
        + (output_height_stride * output_height_point_value)[:, None]
        + (output_width_stride * output_width_point_value)[:, None]
    )
    output_mask = (
        (in_n_point_value < in_n)[:, None]
        & (output_c_offset < out_per_group_c)[None, :]
        & (output_height_point_value < out_height)[:, None]
        & (output_width_point_value < out_width)[:, None]
    )

    tl.store(output_pointer, accum, mask=output_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv2d_backward_weight"),
    key=[
        "in_n",
        "input_height",
        "input_width",
        "weight_height",
        "weight_width",
        "input_c",
        "stride_height",
        "stride_width",
        "out_height",
        "out_width",
        "out_c",
        "padding_height",
        "padding_width",
    ],
)
@triton.jit
def conv2d_backward_kernel_weight(
    input_pointer,
    out_grad_pointer,
    weight_pointer,
    input_n_stride,
    input_c_stride,
    input_height_stride,
    input_width_stride,
    weight_n_stride,
    weight_c_stride,
    weight_height_stride,
    weight_width_stride,
    output_n_stride,
    output_c_stride,
    output_height_stride,
    output_width_stride,
    input_height,
    input_width,
    weight_height,
    weight_width,
    input_c,
    in_n,
    stride_height,
    stride_width,
    out_height,
    out_width,
    out_c,
    padding_height,
    padding_width,
    dilation_height,
    dilation_width,
    BLOCK_NO: tl.constexpr,
    BLOCK_CI_HK_WK: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Weight gradient of a grouped 2D convolution.

    Each program instance writes a [BLOCK_CI_HK_WK, BLOCK_CO] tile of the
    gradient for group ``tl.program_id(1)``.
    - Rows are flattened (ci, kh, kw) weight positions.
    - Columns are output channels within the group.

    For every output position (h, w), the program sums
    ``input[n, ci, hi, wi] * out_grad[n, co, h, w]`` over the batch, taking
    the batch in BLOCK_NO chunks. Here (hi, wi) is the input coordinate read
    by tap (kh, kw).

    Arguments:
    - ``input_c`` and ``out_c`` are per-group channel counts.
    - The ``output_*`` strides are the strides of ``out_grad``.
    - The ``weight_*`` strides index the gradient buffer that
      ``weight_pointer`` addresses.
    """
    # Logical layouts:
    #   out_grad: n, (groups out_c), ho, wo
    #   weight:   (groups out_c), ci, h, w
    #   input:    n, (groups ci), hi, wi

    # Program ids: axis 0 -> ci*hk*wk tile, axis 1 -> group, axis 2 -> co tile.
    pid_ci_hk_wk = tl.program_id(0)
    pid_groups = tl.program_id(1)
    pid_co = tl.program_id(2)

    # Decompose the flattened row index into (ci, kh, kw) weight coordinates.
    ci_hk_wk_offset = pid_ci_hk_wk * BLOCK_CI_HK_WK + tl.arange(0, BLOCK_CI_HK_WK)
    ci_hk_offset = ci_hk_wk_offset // weight_width
    ci_point_value = ci_hk_offset // weight_height
    weight_height_point_value = ci_hk_offset % weight_height
    weight_width_point_value = ci_hk_wk_offset % weight_width

    # Base pointers for this group / output-channel tile.
    output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    out_grad_pointer += (output_c_offset * output_c_stride)[None, :] + (
        pid_groups[None] * output_c_stride * out_c
    )[:, None]

    weight_pointer += (
        pid_groups * weight_n_stride * out_c + output_c_offset * weight_n_stride
    )[None, :] + (
        ci_point_value * weight_c_stride
        + weight_height_point_value * weight_height_stride
        + weight_width_point_value * weight_width_stride
    )[
        :, None
    ]

    # No subscript on the scalar stride argument: a stride equal to 1 may be
    # specialized by Triton, and the scalar broadcasts against ci_point_value
    # anyway.
    input_pointer += (ci_point_value * input_c_stride)[:, None] + (
        pid_groups[None] * input_c_stride * input_c
    )[None, :]

    # Reduction dimension: N * out_height * out_width.
    # Output height/width are iterated serially; the batch is iterated in
    # BLOCK_NO chunks.
    accum = tl.zeros((BLOCK_CI_HK_WK, BLOCK_CO), dtype=tl.float32)
    for h in range(0, out_height):
        for w in range(0, out_width):
            for n in range(0, in_n, BLOCK_NO):
                output_n_offset = n + tl.arange(0, BLOCK_NO)

                # out_grad tile: [BLOCK_NO, BLOCK_CO].
                curr_out_grad_pointer = (
                    out_grad_pointer
                    + (
                        output_n_offset * output_n_stride
                        + h * output_height_stride
                        + w * output_width_stride
                    )[:, None]
                )
                out_grad_mask = (output_n_offset < in_n)[:, None] & (
                    output_c_offset < out_c
                )[None, :]

                # other=0.0 so masked lanes contribute exactly zero to the dot.
                curr_out_grad = tl.load(
                    curr_out_grad_pointer, mask=out_grad_mask, other=0.0
                )

                # Input coordinate read by tap (kh, kw) for output position (h, w).
                input_height_offset = (
                    weight_height_point_value * dilation_height
                    - padding_height
                    + stride_height * h
                )

                input_width_offset = (
                    weight_width_point_value * dilation_width
                    - padding_width
                    + stride_width * w
                )

                # input tile: [BLOCK_CI_HK_WK, BLOCK_NO]; positions outside the
                # input (zero padding) are masked out.
                curr_input_pointer = (
                    input_pointer
                    + (input_n_stride * output_n_offset)[None, :]
                    + (input_height_stride * input_height_offset)[:, None]
                    + (input_width_stride * input_width_offset)[:, None]
                )
                input_mask = (
                    (output_n_offset < in_n)[None, :]
                    & (ci_point_value < input_c)[:, None]
                    & (0 <= input_height_offset)[:, None]
                    & (input_height_offset < input_height)[:, None]
                    & (0 <= input_width_offset)[:, None]
                    & (input_width_offset < input_width)[:, None]
                )

                curr_input = tl.load(curr_input_pointer, mask=input_mask, other=0.0)
                accum += tl.dot(curr_input, curr_out_grad, allow_tf32=False)

    weight_mask = (
        (ci_point_value < input_c)[:, None]
        & (output_c_offset < out_c)[None, :]
        & (weight_height_point_value < weight_height)[:, None]
        & (weight_width_point_value < weight_width)[:, None]
    )
    tl.store(weight_pointer, accum, weight_mask)
class Conv2d(torch.autograd.Function):
    """Autograd function for 2D convolution backed by Triton kernels.

    ``forward`` launches ``conv2d_forward_kernel``. ``backward`` computes:
    - the input gradient, by reusing the same forward kernel as a transposed
      convolution (the stride-dilated ``out_grad`` convolved with the
      spatially flipped, group-wise channel-transposed weight);
    - the weight gradient, with ``conv2d_backward_kernel_weight``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        """Compute a grouped 2D convolution with optional bias.

        Args:
            ctx: Autograd context; tensors and geometry are saved for backward.
            input: Tensor of shape (N, C_in, H, W).
            weight: Tensor of shape (C_out, C_in // groups, kH, kW).
            bias: Tensor of shape (C_out,), or None.
            stride: An int or an (h, w) pair.
            padding: An int or an (h, w) pair (symmetric zero padding).
            dilation: An int or an (h, w) pair.
            groups: Number of channel groups; must divide C_out.

        Returns:
            Tensor of shape (N, C_out, H_out, W_out) with ``input.dtype``.

        Raises:
            AssertionError: If the tensor shapes are inconsistent.
        """
        logger.debug("GEMS CONV2D")
        assert weight.ndim == 4, f"Weights must be 4D, received shape {weight.shape}"
        assert (
            bias is None or bias.ndim == 1
        ), f"Bias must be 1D, received shape {bias.shape}"

        assert (
            input.shape[1] == groups * weight.shape[1]
        ), f"Incompatible input ({input.shape}) and weights ({weight.shape}) shape with {groups} groups"
        assert (
            bias is None or weight.shape[0] == bias.shape[0]
        ), f"Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"
        # The kernel gives every group out_c // groups channels. Any remainder
        # channels of the (uninitialized) output would never be written.
        assert (
            weight.shape[0] % groups == 0
        ), f"Output channels ({weight.shape[0]}) must be divisible by groups ({groups})"

        def _pair(value):
            # Accept either a single int or an (h, w) sequence.
            if isinstance(value, (list, tuple)):
                first, second = value
                return first, second
            return value, value

        stride_height, stride_width = _pair(stride)
        padding_height, padding_width = _pair(padding)
        dilation_height, dilation_width = _pair(dilation)

        in_n, _, input_height, input_width = input.shape
        out_c, weight_c, weight_height, weight_width = weight.shape
        out_per_group_c = out_c // groups
        out_height = conv2d_output_size(
            input_height, weight_height, stride_height, padding_height, dilation_height
        )
        out_width = conv2d_output_size(
            input_width, weight_width, stride_width, padding_width, dilation_width
        )

        output_dtype = input.dtype
        output = torch.empty(
            (in_n, out_c, out_height, out_width),
            device=input.device,
            dtype=output_dtype,
        )

        # Grid axes:
        # - 0: BLOCK_NI_HO_WO tiles over the flattened (in_n, out_height,
        #      out_width) positions;
        # - 1: BLOCK_CO tiles over one group's output channels;
        # - 2: one program column per group.
        grid = lambda META: (
            triton.cdiv(in_n * out_height * out_width, META["BLOCK_NI_HO_WO"]),
            triton.cdiv(out_per_group_c, META["BLOCK_CO"]),
            groups,
        )

        # The kernel always adds a bias, so pass zeros when there is none.
        if bias is None:
            bias_pointer = torch.zeros(out_c, device=input.device, dtype=output_dtype)
        else:
            bias_pointer = bias
        conv2d_forward_kernel[grid](
            input,
            weight,
            output,
            bias_pointer,
            in_n,
            input_height,
            input_width,
            out_c,
            out_height,
            out_width,
            *input.stride(),
            *weight.stride(),
            *output.stride(),
            weight_c,
            weight_height,
            weight_width,
            stride_height,
            stride_width,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        ctx.save_for_backward(weight, input, bias)

        ctx.stride = (stride_height, stride_width)
        ctx.padding = (padding_height, padding_width)
        ctx.dilation = (dilation_height, dilation_width)

        # Note: weight_info stores the *per-group* output channel count.
        ctx.weight_info = (out_per_group_c, weight_c, weight_height, weight_width)
        ctx.input_info = (in_n, input_height, input_width)
        ctx.out_info = (out_height, out_width)

        ctx.device = input.device
        ctx.groups = groups

        return output

    @staticmethod
    def backward(ctx, out_grad):
        """Gradients for (input, weight, bias, stride, padding, dilation, groups).

        The four non-tensor arguments receive ``None``. The bias gradient is
        ``None`` when forward was called without a bias.
        """
        logger.debug("GEMS CONV2D VJP")
        (weight, input, bias) = ctx.saved_tensors
        # out_c is the per-group output channel count (see forward).
        out_c, weight_c, weight_height, weight_width = ctx.weight_info
        in_n, input_height, input_width = ctx.input_info
        out_height, out_width = ctx.out_info

        device = ctx.device
        groups = ctx.groups

        stride_height, stride_width = ctx.stride
        dilation_height, dilation_width = ctx.dilation
        padding_height, padding_width = ctx.padding

        # ---- Input gradient -------------------------------------------------
        # A stride-1 convolution of the stride-dilated out_grad with the
        # flipped weight, padded by dilation * (kernel - 1) - padding.
        revert_padding_height = dilation_height * (weight_height - 1) - padding_height
        revert_padding_width = dilation_width * (weight_width - 1) - padding_width
        # Spatial flip of the kernel (flag_gems counterpart of
        # torch.flip(weight, dims=[2, 3])).
        revert_weight = flag_gems.flip(weight.clone(), dims=[2, 3]).contiguous()

        # Swap the in/out channel roles inside each group:
        # (groups * out_c, weight_c, kh, kw) -> (groups * weight_c, out_c, kh, kw)
        if groups != 1:
            revert_weight = revert_weight.reshape(
                groups, out_c, weight_c, weight_height, weight_width
            )
            revert_weight = revert_weight.transpose(1, 2)
            revert_weight = revert_weight.reshape(
                groups * weight_c, out_c, weight_height, weight_width
            ).contiguous()
        else:
            revert_weight = revert_weight.transpose(0, 1).contiguous()

        # Spread out_grad so that (stride - 1) zeros separate neighbours.
        new_out_height = out_grad.shape[2] + (stride_height - 1) * (
            out_grad.shape[2] - 1
        )
        new_out_width = out_grad.shape[3] + (stride_width - 1) * (out_grad.shape[3] - 1)

        if stride_height > 1 or stride_width > 1:
            new_out = torch.zeros(
                out_grad.shape[0],
                out_grad.shape[1],
                new_out_height,
                new_out_width,
                device=device,
                dtype=out_grad.dtype,
            )
            # One strided copy rather than a Python loop that issues one
            # indexed assignment per (i, j) position.
            new_out[:, :, ::stride_height, ::stride_width] = out_grad
        else:
            new_out = out_grad

        input_back = torch.zeros(
            in_n,
            weight_c * groups,
            input_height,
            input_width,
            dtype=torch.float32,
            device=device,
        )

        grid = lambda META: (
            triton.cdiv(
                out_grad.shape[0] * input_height * input_width, META["BLOCK_NI_HO_WO"]
            ),
            triton.cdiv(int(weight_c), META["BLOCK_CO"]),
            groups,
        )
        # The transposed convolution has no bias term.
        bias_zero = torch.zeros(groups * weight_c, device=device, dtype=out_grad.dtype)
        conv2d_forward_kernel[grid](
            new_out,
            revert_weight,
            input_back,
            bias_zero,
            out_grad.shape[0],
            new_out_height,
            new_out_width,
            groups * weight_c,
            input_height,
            input_width,
            *new_out.stride(),
            *revert_weight.stride(),
            *input_back.stride(),
            out_c,
            weight_height,
            weight_width,
            1,
            1,
            revert_padding_height,
            revert_padding_width,
            dilation_height,
            dilation_width,
            groups=groups,
        )

        # ---- Weight gradient ------------------------------------------------
        weight_back = torch.zeros(
            out_c * groups,
            weight_c,
            weight_height,
            weight_width,
            dtype=weight.dtype,
            device=device,
        )

        grid_weight = lambda meta: (
            triton.cdiv(
                weight_c * weight_height * weight_width, meta["BLOCK_CI_HK_WK"]
            ),
            groups,
            triton.cdiv(out_c, meta["BLOCK_CO"]),
        )
        conv2d_backward_kernel_weight[grid_weight](
            input,
            out_grad,
            weight_back,
            *input.stride(),
            # Index the freshly allocated weight_back with its own strides.
            # The saved weight may be non-contiguous (e.g. channels_last), and
            # its strides would scatter the gradient to the wrong elements.
            *weight_back.stride(),
            *out_grad.stride(),
            input_height,
            input_width,
            weight_height,
            weight_width,
            weight_c,
            in_n,
            stride_height,
            stride_width,
            out_height,
            out_width,
            out_c,
            padding_height,
            padding_width,
            dilation_height,
            dilation_width,
        )

        # ---- Bias gradient --------------------------------------------------
        bias_grad = out_grad.sum(dim=(0, 2, 3)) if bias is not None else None
        return (
            input_back,
            weight_back,
            bias_grad,
            None,
            None,
            None,
            None,
        )
# TODO: add tests for per-dimension (SymInt[2]) stride and padding values.
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2D convolution entry point.

    Args:
        input: Input tensor of shape (N, C_in, H, W).
        weight: Weight tensor of shape (C_out, C_in // groups, kH, kW).
        bias: Optional bias of shape (C_out,).
        stride: An int or an (h, w) pair.
        padding: Accepted forms:
            - an int or an (h, w) pair;
            - ``"valid"``, meaning no padding;
            - ``"same"``, meaning the output keeps the input's spatial size.
              Only stride 1 is supported in this mode.
        dilation: An int or an (h, w) pair.
        groups: Number of channel groups.

    Returns:
        The convolution output.

    Raises:
        ValueError: If ``padding`` is a string other than "valid"/"same".
        AssertionError: If ``padding="same"`` is combined with a stride != 1.
    """
    if not isinstance(padding, str):
        return Conv2d.apply(input, weight, bias, stride, padding, dilation, groups)

    if padding == "valid":
        return Conv2d.apply(input, weight, bias, stride, 0, dilation, groups)

    if padding != "same":
        raise ValueError(
            f"Unsupported padding string: {padding}, only 'valid'/'same' are allowed."
        )

    # Accept both int and (h, w) forms of stride and dilation.
    if isinstance(stride, (list, tuple)):
        stride_h, stride_w = stride
    else:
        stride_h = stride_w = stride
    assert stride_h == 1 and stride_w == 1, (
        "Doesn't support any stride values other than 1 in padding = 'same' "
        f"mode, received stride value {stride}"
    )

    if isinstance(dilation, (list, tuple)):
        dilation_h, dilation_w = dilation
    else:
        dilation_h = dilation_w = dilation

    # Total padding per axis that keeps output size == input size at stride 1.
    total_pad_h = dilation_h * (weight.shape[-2] - 1)
    total_pad_w = dilation_w * (weight.shape[-1] - 1)
    # Conv2d only pads symmetrically, so pad ceil(total / 2) on both sides of
    # each axis independently. When the total is odd this produces one extra
    # leading row/column, which the slice below removes.
    padding_h = math.ceil(total_pad_h / 2)
    padding_w = math.ceil(total_pad_w / 2)
    output = Conv2d.apply(
        input, weight, bias, stride, (padding_h, padding_w), dilation, groups
    )
    return output[..., total_pad_h % 2 :, total_pad_w % 2 :]