Coverage for src/flag_gems/ops/addmm.py: 49%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable_to, libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as ext 

11 

# Module-level logger named after this module; used for the debug traces in the addmm entry points.
logger = logging.getLogger(name=__name__)

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="addmm",
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one output tile of ``c = alpha * (a @ b) + beta * i``.

    Program ``(pid_m, pid_n)``, taken from ``ext.program_id(0)`` and
    ``ext.program_id(1)``, owns the ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of
    ``c`` that starts at row ``pid_m * BLOCK_SIZE_M`` and column
    ``pid_n * BLOCK_SIZE_N``. The K dimension is walked in ``BLOCK_SIZE_K``
    chunks. Out-of-range rows, columns and K entries are masked and read as 0.

    Args:
        a_ptr: (M, K) left operand, addressed via ``stride_am`` / ``stride_ak``.
        b_ptr: (K, N) right operand, addressed via ``stride_bk`` / ``stride_bn``.
        i_ptr: (M, N) additive input ("bias"), addressed via
            ``stride_im`` / ``stride_in``.
        c_ptr: (M, N) output, addressed via ``stride_cm`` / ``stride_cn``.
            The result is cast to this pointer's element type.
        alpha: Scale applied to the matrix product.
        beta: Scale applied to the bias. When ``beta == 0`` the bias is not
            read at all, so NaN/Inf in it cannot reach the output. This
            matches ``torch.addmm``.
        M, N, K: Problem sizes.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Tile sizes, supplied by the
            tuned "addmm" configs.
        IS_FP64: Keep the running sum in float64 instead of float32.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    # Global row / column indices covered by this program's output tile.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of K entries that are still in range for this chunk.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # NOTE(review): fp64 tiles are downcast to fp32 before tl.dot, so
            # each chunk's products only carry fp32 precision; only the running
            # sum is kept in fp64. Presumably a tl.dot backend limitation --
            # confirm before relying on full fp64 accuracy.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # The output tile uses the same row/column indices as the operand tiles.
    c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_am[:, None] + stride_in * offs_bn[None, :]
    # When beta == 0 every lane is masked off and reads 0. The bias is then
    # never touched, so NaN/Inf in an unused bias cannot leak into the result.
    bias = tl.load(i_ptrs, mask=c_mask & (beta != 0), other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # Cast straight to the output element type rather than the bias dtype.
    # This avoids a lossy round-trip when the bias is narrower than the output.
    c = accumulator.to(c_ptr.dtype.element_ty)
    tl.store(c_ptrs, c, mask=c_mask)

89 

90 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: Tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to ``mat1 @ mat2``.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype, on ``mat1``'s device.

    Raises:
        AssertionError: If the inner dimensions differ or ``bias`` cannot be
            broadcast to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Logging arguments are evaluated even when DEBUG is disabled. Guard
    # stride(0) so a 0-d (scalar) bias, which has no dimension 0, does not raise.
    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    if out.numel() == 0:
        # Nothing to compute; skip launching (and autotuning) an empty grid.
        return out

    # mat1 is copied to a contiguous layout. mat2 and bias are passed with
    # explicit strides, so they may keep any layout, including broadcast views.
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

140 

141 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Args:
        bias: Tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to ``mat1 @ mat2``.
        out: Destination of shape ``(M, N)``. When None, a new tensor with
            ``mat1``'s dtype and device is allocated.

    Returns:
        ``out``, filled with the result.

    Raises:
        AssertionError: If the inner dimensions differ, ``bias`` cannot be
            broadcast to ``(M, N)``, or ``out`` has a different shape.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    # Logging arguments are evaluated even when DEBUG is disabled. Guard
    # stride(0) so a 0-d (scalar) bias, which has no dimension 0, does not raise.
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    if out.numel() == 0:
        # Nothing to compute; skip launching (and autotuning) an empty grid.
        return out

    # mat1 is copied to a contiguous layout. mat2 and bias are passed with
    # explicit strides, so they may keep any layout, including broadcast views.
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

192 

193 

def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new ``out_dtype`` tensor.

    Allocates an ``(mat1.shape[0], mat2.shape[1])`` result on ``mat1``'s
    device with dtype ``out_dtype``. It then delegates all validation and
    computation to ``addmm_dtype_out``.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    result = torch.empty(result_shape, dtype=out_dtype, device=mat1.device)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, alpha=alpha, beta=beta, out=result
    )

202 

203 

def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` with dtype ``out_dtype``.

    Accepted dtype combinations:
      * ``mat1`` and ``mat2`` share a dtype.
      * ``out.dtype`` equals ``out_dtype``.
      * ``out_dtype`` is either the input dtype, or float32 when the inputs
        are float16/bfloat16.
      * ``bias`` has either ``out_dtype`` or the input dtype. It is converted
        to ``out_dtype`` before the product is computed by ``addmm_out``.

    Raises:
        RuntimeError: If any of the dtype rules above is violated.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")
    in_dtype = mat1.dtype
    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    # Half-precision inputs may be widened to a float32 result.
    is_half_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if out_dtype != in_dtype and not is_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    return addmm_out(
        bias.to(out_dtype), mat1, mat2, beta=beta, alpha=alpha, out=out
    )