Coverage for src/flag_gems/runtime/backend/_spacemit/ops/addmm.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.extra.smt as smt 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, libtuner 

10 

# Module-scoped logger; named after this module so handlers/levels can be set per-backend.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm_spacemit"),
    key=["M", "N", "K"],
)
@triton.jit
def addmm_kernel(
    a_ptr,
    b_ptr,
    bias_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    SUB_BLK_K: tl.constexpr,
):
    """Write one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``alpha * (A @ B) + beta * bias``.

    Program id 0 selects the tile row and program id 1 the tile column. A is
    addressed as an (M, K) matrix and B as a (K, N) matrix through block
    pointers whose block covers the full BLOCK_SIZE_K extent. The tiles are
    handed to ``smt.dot`` after ``smt.view`` with the (MICRO_M, MICRO_K) /
    (MICRO_K, MICRO_N) sub-shapes.

    ``EVEN_K`` (constexpr, so only one branch is compiled) chooses between:
      * a single ``smt.dot`` over the whole BLOCK_SIZE_K panel, or
      * a loop over ``ceil(K / SUB_BLK_K)`` slices of width SUB_BLK_K,
        accumulated into a zero tile of ``a_ptr``'s element type.

    NOTE(review): EVEN_K is not passed by the host wrapper in this file;
    presumably it comes from the tuned configs — verify.
    """
    row0 = tl.program_id(0) * BLOCK_SIZE_M
    col0 = tl.program_id(1) * BLOCK_SIZE_N

    a_panel = tl.make_block_ptr(
        base=a_ptr,
        shape=[M, K],
        strides=[stride_am, stride_ak],
        offsets=[row0, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        order=[1, 0],
    )
    b_panel = tl.make_block_ptr(
        base=b_ptr,
        shape=[K, N],
        strides=[stride_bk, stride_bn],
        offsets=[0, col0],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
        order=[1, 0],
    )

    if EVEN_K:
        a_tile = smt.view(
            smt.descriptor_load(a_panel, (0, 0)),
            (0, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_K),
            (MICRO_M, MICRO_K),
        )
        b_tile = smt.view(
            smt.descriptor_load(b_panel, (0, 0)),
            (0, 0),
            (BLOCK_SIZE_K, BLOCK_SIZE_N),
            (MICRO_K, MICRO_N),
        )
        acc = smt.dot(a_tile, b_tile)
    else:
        # NOTE(review): the accumulator takes a_ptr's element type rather than
        # float32; confirm this precision is intended for fp16/bf16 inputs.
        acc = smt.view(
            tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty),
            (0, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_N),
            (MICRO_M, MICRO_N),
        )
        n_slices = (K + SUB_BLK_K - 1) // SUB_BLK_K
        for step in tl.range(0, n_slices):
            k0 = step * SUB_BLK_K
            # Each step re-views a SUB_BLK_K-wide window of the same panels.
            a_tile = smt.view(
                smt.descriptor_load(a_panel, (0, 0)),
                (0, k0),
                (BLOCK_SIZE_M, SUB_BLK_K),
                (MICRO_M, MICRO_K),
            )
            b_tile = smt.view(
                smt.descriptor_load(b_panel, (0, 0)),
                (k0, 0),
                (SUB_BLK_K, BLOCK_SIZE_N),
                (MICRO_K, MICRO_N),
            )
            acc = acc + smt.dot(a_tile, b_tile)

    # Re-view the accumulator with (1, 1) sub-shapes before the element-wise
    # epilogue (presumably undoing the MICRO_M x MICRO_N tiling).
    acc = smt.view(acc, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1))

    bias_tile_ptr = tl.make_block_ptr(
        base=bias_ptr,
        shape=[M, N],
        strides=[stride_im, stride_in],
        offsets=[row0, col0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
        order=[1, 0],
    )
    out_tile_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=[M, N],
        strides=[stride_cm, stride_cn],
        offsets=[row0, col0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
        order=[1, 0],
    )

    # boundary_check masks the ragged edge tiles in both M and N.
    bias_tile = tl.load(bias_tile_ptr, boundary_check=(0, 1))
    result = acc * alpha + bias_tile * beta
    tl.store(out_tile_ptr, result.to(c_ptr.dtype.element_ty), boundary_check=(0, 1))

132 

133 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` on the SpacemiT backend.

    Args:
        bias: Tensor broadcastable to ``(M, N)``.
        mat1: Matrix of shape ``(M, K)``.
        mat2: Matrix of shape ``(K, N)``.
        beta: Scale for ``bias``. As in ``torch.addmm``, when ``beta == 0`` the
            contents of ``bias`` are ignored, so NaN/Inf in it do not propagate
            (its shape must still broadcast to ``(M, N)``).
        alpha: Scale for the matrix product.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.

    Raises:
        AssertionError: if ``mat1.shape[1] != mat2.shape[0]``.
    """
    logger.debug("GEMS_SPACEMIT ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    M, K = mat1.shape
    _, N = mat2.shape

    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Broadcasting first also validates that bias is compatible with (M, N).
    bias = bias.broadcast_to(out.shape)
    if beta == 0:
        # The kernel computes `bias * beta`; with beta == 0 a NaN/Inf in bias
        # would still leak through (NaN * 0 == NaN). torch.addmm ignores the
        # input in this case, so substitute zeros of the same dtype.
        bias = torch.zeros((M, N), dtype=bias.dtype, device=bias.device)
    else:
        bias = bias.contiguous()

    if M == 0 or N == 0:
        # Nothing to compute; avoid launching/autotuning on an empty grid.
        return out
    if K == 0:
        # Empty reduction: mat1 @ mat2 is all zeros, so the result is
        # beta * bias. Launching the kernel here would be invalid because
        # triton.next_power_of_2(0) == 0 gives a zero-sized K block and a
        # division by zero in the slice count.
        return out.copy_(bias * beta)

    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # One power-of-two block spans all of K; the kernel's non-EVEN_K path
    # walks it in slices of at most 1024.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    SUB_BLK_K = min(1024, BLOCK_SIZE_K)

    addmm_kernel[grid](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        M,
        N,
        K,
        mat1.stride(0),
        mat1.stride(1),
        mat2.stride(0),
        mat2.stride(1),
        bias.stride(0),
        bias.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return out