Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/mm.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10 

# Module-level logger, created as a child of the "flag_gems" logger; lstrip(".")
# removes any leading dots from __name__ before it is used as the child suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

# Tiled GEMM kernel: each program computes one BLOCK_M x BLOCK_N tile of
# C = A @ B, restricted to one of SPLIT_K interleaved slices of the K dimension.
#
# libtuner autotunes over runtime.get_tuned_config("mm"), keyed on the problem
# sizes and operand strides. NOTE(review): the seven "align32" entries
# presumably pair one-to-one with the seven `key` names and bucket each value
# to a multiple of 32 for config lookup — confirm in flag_gems.utils.libtuner.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "stride_ak", "stride_bn"],
    strategy=[
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
    ],
    warmup=1,
    rep=2,
)
# BLOCK_M/N/K, SPLIT_K and EVEN_K are not passed by the Python wrappers, so they
# must come from the tuned configs and/or this "mm" heuristic table.
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,  # left operand, M x K, addressed via (stride_am, stride_ak)
    B,  # right operand, K x N, addressed via (stride_bk, stride_bn)
    C,  # output, M x N, addressed via (stride_cm, stride_cn)
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,  # accumulator dtype for tl.dot
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # number of tile-rows grouped by the pid re-ordering
    SPLIT_K: tl.constexpr,  # number of K slices; equals grid axis 1
    # When true the K-dimension masks are skipped, which is only in-bounds if K
    # is a multiple of BLOCK_K * SPLIT_K (presumably what the heuristic checks).
    EVEN_K: tl.constexpr,
    # When true, program ids are widened to int64 so pointer offsets computed
    # from them cannot overflow int32.
    UPCAST: tl.constexpr,
):
    # matrix multiplication
    if UPCAST:
        pid = tl.program_id(0).to(tl.int64)
        pid_z = tl.program_id(1).to(tl.int64)
    else:
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    # (grouped ordering: consecutive pids walk down a band of up to GROUP_M
    # tile-rows before moving to the next tile-column; the last band may be
    # shorter, hence group_size).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # K slice pid_z starts at K-block pid_z; the loop below then strides by
    # SPLIT_K K-blocks, so slices interleave over the reduction dimension.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A, mask=(rm < M)[:, None], other=0.0)
            b = tl.load(B, mask=(rn < N)[None, :], other=0.0)
        else:
            # rk already contains this slice's pid_z * BLOCK_K offset, so it is
            # compared against K minus what the previous iterations consumed.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(
                A, mask=(rk[None, :] < k_remaining) & (rm < M)[:, None], other=0.0
            )
            b = tl.load(
                B, mask=(rk[:, None] < k_remaining) & (rn < N)[None, :], other=0.0
            )

        # Mixed-dtype operands are both cast to C's element type before the dot.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # Partial sums from the SPLIT_K slices are added into C.
        # NOTE(review): this assumes C starts at zero — confirm the caller or a
        # config pre-hook zero-initialises C whenever SPLIT_K > 1.
        tl.atomic_add(C, acc, mask=mask)

107 

108 

# Floating dtypes understood by get_higher_dtype, lowest precedence first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of ``a`` and ``b`` ranks later in ``_ordered_datatypes``.

    Identical dtypes are returned as-is without any check. Otherwise both must
    be listed; e.g. float16 with bfloat16 yields bfloat16, and anything paired
    with float32 yields float32.

    Raises:
        AssertionError: if the dtypes differ and either is not in
            ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The first listed dtype matching either argument is the lower-ranked one;
    # the answer is then the *other* argument.
    lower = next((dt for dt in _ordered_datatypes if dt is a or dt is b), None)
    if lower is None:
        # Only reachable when assertions are stripped (python -O).
        return None
    return b if lower is a else a

124 

125 

def _maybe_contiguous(t):
    """Return ``t`` itself unless it is strided in both dimensions.

    ``mm_kernel`` addresses its operands through explicit row/column strides,
    so a 2-D tensor with a stride <= 1 in either dimension (e.g. a row-major
    matrix or its transpose) is used in place. Only a tensor whose two strides
    are both greater than one is copied into a contiguous buffer.
    """
    if t.stride(0) > 1 and t.stride(1) > 1:
        return t.contiguous()
    return t


def _mm_launch(a, b, c, M, N, K):
    """Launch ``mm_kernel`` so that ``c`` receives ``a @ b``.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N).
        c: destination of shape (M, N); written through its own strides and
            cast to its own dtype.
        M, N, K: problem sizes taken from ``a.shape`` and ``b.shape``.

    NOTE(review): if a tuned "mm" config selects SPLIT_K > 1, the kernel
    accumulates into ``c`` with ``tl.atomic_add`` and so relies on ``c``
    starting at zero, which neither caller guarantees (``mm`` allocates with
    ``torch.empty``; ``mm_out`` reuses the caller's ``out``). Confirm the
    tuned configs only use SPLIT_K == 1, or zero ``c`` in a config pre-hook.
    """
    # Switch the kernel to int64 index arithmetic when an offset along any
    # dimension could reach 2**31 and overflow int32.
    upcast = (
        M * max(a.stride(0), c.stride(0)) >= 1 << 31
        or N * max(b.stride(1), c.stride(1)) >= 1 << 31
        or K * max(a.stride(1), b.stride(0)) >= 1 << 31
    )

    def grid(meta):
        # Axis 0: one program per BLOCK_M x BLOCK_N tile of c.
        # Axis 1: one program per SPLIT_K slice of the K dimension.
        return (
            triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
            meta["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=tl.float32,
            GROUP_M=8,
            UPCAST=upcast,
        )


def mm(a, b):
    """Return the matrix product ``a @ b`` of two 2-D tensors.

    The result is a new (M, N) tensor on ``a.device`` whose dtype is
    ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``, or if the dtypes
            differ and are not both supported by ``get_higher_dtype``.
    """
    logger.debug("GEMS_TSINGMICRO MM")
    a = _maybe_contiguous(a)
    b = _maybe_contiguous(b)
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)
    if M == 0 or N == 0:
        # Empty result: skip launching (and autotuning) a zero-size grid.
        return c
    if K == 0:
        # Empty reduction: every entry of the product is zero.
        return c.zero_()
    _mm_launch(a, b, c, M, N, K)
    return c


def mm_out(a, b, *, out):
    """Compute ``a @ b`` into ``out`` and return ``out``.

    ``out`` is resized to (M, N) when its shape differs, following PyTorch's
    ``out=`` convention; the product is cast to ``out``'s dtype.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_TSINGMICRO MM_OUT")
    a = _maybe_contiguous(a)
    b = _maybe_contiguous(b)
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The kernel writes every (m, n) with m < M, n < N through out's strides
    # without knowing out's real extent, so a wrongly-shaped out (commonly an
    # empty placeholder) would be written out of bounds. Resize it first.
    if out.shape != (M, N):
        out.resize_((M, N))
    if M == 0 or N == 0:
        return out
    if K == 0:
        # Empty reduction: every entry of the product is zero.
        return out.zero_()
    _mm_launch(a, b, out, M, N, K)
    return out