Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import broadcastable_to, libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# Default tuning strategy: no hand-written configs are listed; candidate
# configurations are presumably produced by the backend's "addmm" config
# generator (a Kunlunxin-specific `generate_configs` option — confirm), and
# the tuning result is keyed on the problem size (M, N, K).
autotune_decorator = triton.autotune(
    key=["M", "N", "K"],
    generate_configs="addmm",
    configs=[],
)

21 

22 

# Set KLX_USE_AUTOTUNE=0 in the environment (read once, at import time) to
# replace autotuning with the fixed block-size heuristics defined below.
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"


def _heur_block_mn(size):
    """Shared BLOCK_SIZE_M / BLOCK_SIZE_N rule for one output dimension.

    Sizes of 0 or 1 get a block of 2 (a block of 2 for size 1 is presumably
    a minimum tile size on this backend — confirm). Mapping size 0 to 2
    rather than 0 matters: the launch grid computes
    ``triton.cdiv(size, block)``, which divides by zero for a 0 block.
    Sizes up to 32 use the exact size; larger ones use 128.
    """
    # NOTE(review): for 2 < size <= 32 this can yield a non-power-of-two
    # block, which upstream Triton's tl.arange rejects; presumably the
    # Kunlunxin backend accepts it — confirm.
    if size <= 1:
        return 2
    if size <= 32:
        return size
    return 128


def heur_block_m(args):
    """Return BLOCK_SIZE_M for the launch arguments ``args`` (reads ``args["M"]``)."""
    return _heur_block_mn(args["M"])


def heur_block_n(args):
    """Return BLOCK_SIZE_N for the launch arguments ``args`` (reads ``args["N"]``)."""
    return _heur_block_mn(args["N"])


def heur_block_k(args):
    """Return BLOCK_SIZE_K: the whole reduction when K <= 128, else 128.

    Clamped to at least 1 so that K == 0 never yields a zero block size
    (``tl.arange(0, 0)`` / ``tl.cdiv(K, 0)`` in the kernel).
    """
    return max(1, min(args["K"], 128))


if not KLX_USE_AUTOTUNE:
    # Replace the autotuner with the fixed heuristics above.
    autotune_decorator = triton.heuristics(
        {
            "BLOCK_SIZE_M": heur_block_m,
            "BLOCK_SIZE_N": heur_block_n,
            "BLOCK_SIZE_K": heur_block_k,
        }
    )

55 

56 

@libentry()
@autotune_decorator
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    #     c = alpha * (a @ b) + beta * i
    # with a indexed as (M, K), b as (K, N) and i / c as (M, N), all
    # addressed through the explicit element strides passed in. alpha and
    # beta are excluded from value specialization (do_not_specialize).
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # The K reduction is accumulated in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Entries past M / N, or past K in the final chunk, load as 0 and
        # therefore contribute nothing to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    # Upcast the bias before scaling. Depending on Triton's scalar-promotion
    # rules, `bias * beta` on an fp16/bf16 tile may otherwise be computed (and
    # rounded) in the narrow type before it reaches the fp32 accumulator.
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0).to(tl.float32)

    accumulator = accumulator * alpha + bias * beta
    # Round once, directly to the output tensor's element type; the bias
    # dtype is not necessarily the output dtype.
    c = accumulator.to(c_ptr.dtype.element_ty)
    tl.store(c_ptrs, c, mask=c_mask)

117 

118 

def _check_addmm_shapes(bias, mat1, mat2):
    """Assert that ``mat1 @ mat2`` is defined and ``bias`` broadcasts to it."""
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"


def _addmm_into(bias, mat1, mat2, out, beta, alpha):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Shapes are validated by the caller: ``mat1`` is (M, K), ``mat2`` is
    (K, N), ``out`` is (M, N) and ``bias`` broadcasts to (M, N). ``out`` may
    be strided; its strides are passed to the kernel as-is.
    """
    M, K = mat1.shape
    _, N = mat2.shape

    # Empty result: nothing to write. Returning early also avoids launching
    # a zero-sized grid.
    if M == 0 or N == 0:
        return out

    if beta == 0:
        # torch.addmm ignores `input` when beta == 0, so NaN/Inf in it must
        # not leak into the result (0 * NaN is NaN). A 0-stride zero scalar
        # keeps the kernel unchanged without reading the real bias.
        bias = torch.zeros((), device=out.device, dtype=out.dtype).expand(M, N)
    else:
        bias = bias.broadcast_to(out.shape)

    if K == 0:
        # Empty reduction: mat1 @ mat2 is all zeros, so out = beta * bias.
        # Handled here without a kernel launch.
        out.copy_(bias)
        out.mul_(beta)
        return out

    mat1 = mat1.contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out


def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    ``mat1`` is (M, K), ``mat2`` is (K, N) and ``bias`` must broadcast to
    (M, N). The result has ``mat1``'s dtype and device. As with
    ``torch.addmm``, ``bias`` is ignored when ``beta == 0``, so NaN/Inf in
    it are not propagated.

    Raises:
        AssertionError: if the operand shapes are incompatible.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMM")
    _check_addmm_shapes(bias, mat1, mat2)
    M, _ = mat1.shape
    _, N = mat2.shape
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    return _addmm_into(bias, mat1, mat2, out, beta, alpha)


def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Like :func:`addmm`, but write the result into ``out`` and return it.

    When ``out`` is ``None`` a new (M, N) tensor with ``mat1``'s dtype and
    device is allocated. Otherwise ``out`` must have shape (M, N).

    Raises:
        AssertionError: if the operand or output shapes are incompatible.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMM_OUT")
    _check_addmm_shapes(bias, mat1, mat2)
    M, _ = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    return _addmm_into(bias, mat1, mat2, out, beta, alpha)