Coverage for src/flag_gems/ops/cudnn_convolution.py: 84%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3from flag_gems.ops.conv1d import conv1d 

4from flag_gems.ops.conv2d import conv2d 

5from flag_gems.ops.conv3d import conv3d 

6 

# Module-scoped logger, named after this module so its output can be filtered.
logger = logging.getLogger(name=__name__)

8 

9 

def _unwrap_single(param):
    """Collapse a one-element list/tuple to its sole element.

    Any other value (a scalar, or a list/tuple whose length is not 1) is
    returned unchanged.

    Args:
        param: A convolution hyper-parameter such as stride, padding, or
            dilation.

    Returns:
        ``param[0]`` if ``param`` is a list or tuple of length 1,
        otherwise ``param`` itself.
    """
    if isinstance(param, (list, tuple)) and len(param) == 1:
        return param[0]
    return param


def cudnn_convolution(
    input,
    weight,
    padding,
    stride,
    dilation,
    groups,
    benchmark,
    deterministic,
    allow_tf32,
):
    """
    CUDNN convolution operation.

    This is a lower-level convolution operation that does not include bias.
    It supports 1D, 2D, and 3D convolutions based on the input dimensionality.

    Args:
        input: Input tensor of shape (N, C_in, *spatial_dims)
        weight: Weight tensor of shape (C_out, C_in/groups, *kernel_dims)
        padding: Padding for each spatial dimension
        stride: Stride for each spatial dimension
        dilation: Dilation for each spatial dimension
        groups: Number of groups for grouped convolution
        benchmark: cuDNN benchmark flag (ignored in Triton implementation)
        deterministic: cuDNN deterministic flag (ignored in Triton implementation)
        allow_tf32: Allow TF32 computation flag (ignored in Triton implementation)

    Returns:
        Output tensor after convolution

    Raises:
        ValueError: If ``input`` does not have 1, 2, or 3 spatial dimensions
            (i.e. ``input.ndim - 2`` is not in {1, 2, 3}).
    """
    logger.debug("GEMS CUDNN_CONVOLUTION")

    # The first two dimensions are batch and channels; the rest are spatial.
    ndim = input.ndim - 2

    if ndim == 1:
        # cudnn_convolution receives per-dimension lists (e.g. ``[2]``); for
        # 1D, single-element lists are unwrapped to scalars before calling
        # conv1d. Anything else is forwarded unchanged.
        return conv1d(
            input,
            weight,
            bias=None,
            stride=_unwrap_single(stride),
            padding=_unwrap_single(padding),
            dilation=_unwrap_single(dilation),
            groups=groups,
        )
    if ndim == 2:
        # 2D/3D paths forward the parameter lists as received.
        return conv2d(
            input,
            weight,
            bias=None,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
    if ndim == 3:
        return conv3d(
            input,
            weight,
            bias=None,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
    raise ValueError(
        "cudnn_convolution only supports 1D, 2D, and 3D convolutions, "
        f"got input with {ndim} spatial dimensions"
    )