Coverage for src/flag_gems/ops/cudnn_convolution.py: 84%
25 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3from flag_gems.ops.conv1d import conv1d
4from flag_gems.ops.conv2d import conv2d
5from flag_gems.ops.conv3d import conv3d
7logger = logging.getLogger(__name__)
10def cudnn_convolution(
11 input,
12 weight,
13 padding,
14 stride,
15 dilation,
16 groups,
17 benchmark,
18 deterministic,
19 allow_tf32,
20):
21 """
22 CUDNN convolution operation.
24 This is a lower-level convolution operation that does not include bias.
25 It supports 1D, 2D, and 3D convolutions based on the input dimensionality.
27 Args:
28 input: Input tensor of shape (N, C_in, *spatial_dims)
29 weight: Weight tensor of shape (C_out, C_in/groups, *kernel_dims)
30 padding: Padding for each spatial dimension
31 stride: Stride for each spatial dimension
32 dilation: Dilation for each spatial dimension
33 groups: Number of groups for grouped convolution
34 benchmark: cuDNN benchmark flag (ignored in Triton implementation)
35 deterministic: cuDNN deterministic flag (ignored in Triton implementation)
36 allow_tf32: Allow TF32 computation flag (ignored in Triton implementation)
38 Returns:
39 Output tensor after convolution
40 """
41 logger.debug("GEMS CUDNN_CONVOLUTION")
43 ndim = input.ndim - 2
45 # Extract values from lists if they are lists (cudnn_convolution receives lists)
46 def extract_param(param, expected_len):
47 if isinstance(param, (list, tuple)):
48 if len(param) == expected_len:
49 return param if expected_len > 1 else param[0]
50 elif len(param) == 1:
51 return param[0]
52 return param
54 if ndim == 1:
55 # For 1D convolution, extract single values from lists
56 stride_val = extract_param(stride, 1)
57 padding_val = extract_param(padding, 1)
58 dilation_val = extract_param(dilation, 1)
59 return conv1d(
60 input,
61 weight,
62 bias=None,
63 stride=stride_val,
64 padding=padding_val,
65 dilation=dilation_val,
66 groups=groups,
67 )
68 elif ndim == 2:
69 return conv2d(
70 input,
71 weight,
72 bias=None,
73 stride=stride,
74 padding=padding,
75 dilation=dilation,
76 groups=groups,
77 )
78 elif ndim == 3:
79 return conv3d(
80 input,
81 weight,
82 bias=None,
83 stride=stride,
84 padding=padding,
85 dilation=dilation,
86 groups=groups,
87 )
88 else:
89 raise ValueError(
90 f"cudnn_convolution only supports 1D, 2D, and 3D convolutions, "
91 f"got input with {ndim} spatial dimensions"
92 )