Coverage for src/flag_gems/ops/addmm.py: 50%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("addmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("addmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of alpha * (A @ B) + beta * I.

    Program ids 0 and 1 select the tile row and tile column. A, B, the bias
    I and the output C are all addressed through the explicit strides passed
    in, so no particular memory layout is assumed by the kernel itself.

    When IS_FP64 is set the running sum is kept in float64, but each loaded
    tile is narrowed to float32 before ``tl.dot``; otherwise the running sum
    is float32. The result is cast to the dtype of the loaded bias before
    being stored.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    # Row / column indices covered by this program; reused for A, B, I and C.
    rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ks = tl.arange(0, BLOCK_SIZE_K)

    # Bounds masks for the ragged last tile in each direction.
    row_in_bounds = rows[:, None] < M
    col_in_bounds = cols[None, :] < N

    a_tile_ptrs = a_ptr + (rows[:, None] * stride_am + ks[None, :] * stride_ak)
    b_tile_ptrs = b_ptr + (ks[:, None] * stride_bk + cols[None, :] * stride_bn)

    if IS_FP64:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of valid K elements left from the start of this step.
        k_left = K - step * BLOCK_SIZE_K
        a_tile = tl.load(
            a_tile_ptrs,
            mask=row_in_bounds & (ks[None, :] < k_left),
            other=0.0,
        )
        b_tile = tl.load(
            b_tile_ptrs,
            mask=(ks[:, None] < k_left) & col_in_bounds,
            other=0.0,
        )
        if IS_FP64:
            # The dot product itself runs on float32 operands.
            a_tile = a_tile.to(tl.float32)
            b_tile = b_tile.to(tl.float32)
        acc += tl.dot(a_tile, b_tile, allow_tf32=False)
        # Slide both operand windows one block further along K.
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    out_mask = row_in_bounds & col_in_bounds
    out_ptrs = c_ptr + stride_cm * rows[:, None] + stride_cn * cols[None, :]
    bias_ptrs = i_ptr + stride_im * rows[:, None] + stride_in * cols[None, :]
    bias_tile = tl.load(bias_ptrs, mask=out_mask, other=0.0)

    acc = acc * alpha + bias_tile * beta
    tl.store(out_ptrs, acc.to(bias_tile.dtype), mask=out_mask)

93 

94 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Scalar multiplier for ``bias``. When it is 0 the contents of
            ``bias`` are ignored, so NaN/Inf in it do not reach the result.
        alpha: Scalar multiplier for ``mat1 @ mat2``.

    Returns:
        A newly allocated ``(M, N)`` tensor on ``mat1``'s device with
        ``mat1``'s dtype.

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, or ``bias`` is not broadcastable to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # A 0-dim bias is broadcastable but has no stride(0) to query.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # mat2 is deliberately not copied: the kernel receives its real strides.
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    if beta == 0:
        # The kernel always evaluates bias * beta, and NaN * 0 is NaN.
        # Substitute a zero scalar so a (possibly uninitialised) bias is
        # truly ignored, as torch.addmm documents for beta == 0.
        bias = bias.new_zeros(())
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

144 

145 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Scalar multiplier for ``bias``. When it is 0 the contents of
            ``bias`` are ignored, so NaN/Inf in it do not reach the result.
        alpha: Scalar multiplier for ``mat1 @ mat2``.
        out: Destination tensor of shape ``(M, N)``. When ``None`` a new
            tensor with ``mat1``'s device and dtype is allocated.

    Returns:
        The tensor the result was written to.

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, ``bias`` is not broadcastable to ``(M, N)``, or a
            supplied ``out`` does not have shape ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # A 0-dim bias is broadcastable but has no stride(0) to query.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    if beta == 0:
        # The kernel always evaluates bias * beta, and NaN * 0 is NaN.
        # Substitute a zero scalar so a (possibly uninitialised) bias is
        # truly ignored, as torch.addmm documents for beta == 0.
        bias = bias.new_zeros(())
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out

196 

197 

def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Run addmm with an explicit result dtype, returning a fresh tensor.

    Allocates an uninitialised ``(mat1.shape[0], mat2.shape[1])`` tensor of
    ``out_dtype`` on ``mat1``'s device and hands everything to
    ``addmm_dtype_out``, which performs the validation and the computation.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    result = torch.empty(result_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias,
        mat1,
        mat2,
        out_dtype,
        beta=beta,
        alpha=alpha,
        out=result,
    )

206 

207 

def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Validate the dtype combination, then compute addmm into ``out``.

    The checks run in this order:
      1. ``mat1`` and ``mat2`` share a dtype;
      2. ``out`` already has dtype ``out_dtype``;
      3. ``out_dtype`` equals the input dtype, or is float32 while the
         inputs are float16 / bfloat16;
      4. ``bias`` has either ``out_dtype`` or the input dtype.

    ``bias`` is then converted to ``out_dtype`` and the work is delegated to
    ``addmm_out``.

    Raises:
        RuntimeError: If any of the checks above fails.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")
    in_dtype = mat1.dtype

    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )

    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    keeps_input_dtype = out_dtype == in_dtype
    widens_half_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not keeps_input_dtype and not widens_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )

    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    return addmm_out(
        bias.to(out_dtype),
        mat1,
        mat2,
        beta=beta,
        alpha=alpha,
        out=out,
    )