Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

# Module logger: a fixed backend prefix followed by the last dotted component
# of this module's import name.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rpartition(".")[2]}'
)

# Normalised path of "bmm_mthreads_expand.yaml" located one directory above
# the directory containing this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "bmm_mthreads_expand.yaml")
)

21 

22 

def is_supported_sqmma_layout(tensor):
    """Return whether ``tensor`` has a memory layout accepted by the SQMMA path.

    A tensor qualifies when it is contiguous, or when its first dimension has
    unit stride and the stride of its second dimension equals the size of the
    first dimension. Only the first two dimensions are inspected in the
    non-contiguous case.
    """
    if tensor.is_contiguous():
        return True
    leading_is_unit_stride = tensor.stride(0) == 1
    return leading_is_unit_stride and tensor.stride(1) == tensor.shape[0]

27 

28 

def is_sqmma_compatible(a, b, N, K):
    """Return whether operands ``a`` and ``b`` may take the SQMMA kernel path.

    All of the following must hold, checked in this order:
      * both operands share one dtype,
      * that dtype is float16 or bfloat16,
      * both operands pass ``is_supported_sqmma_layout``,
      * ``N`` and ``K`` are multiples of 8.
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0

38 

39 

# Batched matrix multiply kernel: each program computes one TILE_M x TILE_N
# tile of O = A @ B for one batch entry.
#
# Grid axes: axis 0 / axis 1 select the output tile, axis 2 selects the batch.
# The pointer arithmetic below indexes A with row stride K, B with row stride N
# and O with row stride N, and steps batches by M*K, K*N and M*N elements, i.e.
# it addresses the operands as densely packed row-major (batch, rows, cols).
#
# DIVISIBLE_M / DIVISIBLE_N / DIVISIBLE_K are compile-time flags that select
# whether bounds masks are built for the corresponding dimension; they are
# supplied outside this function (presumably by the heuristics decorator —
# confirm against the "bmm" heuristic config).
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # batch offsets
    # Advance the three base pointers to the start of this program's batch.
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        # No regrouping: grid coordinates map directly to tile coordinates.
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs
        # Linearise the 2-D program id, then re-map it so that consecutive
        # programs cover up to GROUP_M row-tiles before moving to the next
        # column-tile.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The final group holds only gridx % GROUP_M row-tiles when it would
        # otherwise run past gridx.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Row / column / reduction indices covered by this tile.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialised for dimensions that are not tile
    # multiples; every later use is guarded by the same flag.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float64 when requested, otherwise in float32.
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build the load masks for this K-step from whichever of the M/N/K
        # masks are required.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        # NOTE(review): no `other=` value is passed, so the contents of
        # masked-out lanes depend on the backend's default; these lanes feed
        # tl.dot below — confirm they are guaranteed to be zero.
        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Step to the next K-tile: TILE_K columns of A, TILE_K rows of B.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Store mask mirrors the M/N masks used for the loads.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

148 

149 

def bmm_fma(A, B):
    """Compute the batched product of ``A`` and ``B`` with ``bmm_kernel``.

    ``A`` is unpacked as (batch, M, K) and ``B`` as (_, _, N); both are made
    contiguous before launch. The result is a new (batch, M, N) tensor with
    the dtype and device of ``A``.
    """
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    _, _, N = B.shape

    lhs = A.contiguous()
    rhs = B.contiguous()
    result = torch.empty((batch, M, N), dtype=lhs.dtype, device=lhs.device)
    # The kernel switches its accumulator to float64 only for float64 inputs.
    accumulate_fp64 = lhs.dtype == torch.float64

    def launch_grid(meta):
        # One program per output tile, replicated across the batch on axis 2.
        tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
        tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
        return (tiles_m, tiles_n, batch)

    with torch_device_fn.device(lhs.device):
        bmm_kernel[launch_grid](lhs, rhs, result, M, N, K, IS_FP64=accumulate_fp64)
    return result

166 

167 

def bmm_sqmma_descriptor_pre_hook(nargs):
    """Set ``block_shape`` on the three descriptors from the chosen block sizes.

    ``nargs`` is the kernel-argument mapping; it must contain the descriptors
    under "a_desc", "b_desc", "c_desc" and the integers "BLOCK_SIZE_M",
    "BLOCK_SIZE_N", "BLOCK_SIZE_K". Descriptors are updated in place:
    a -> [M, K], b -> [K, N], c -> [M, N]. Returns ``None``.
    """
    layout = (
        ("a_desc", "BLOCK_SIZE_M", "BLOCK_SIZE_K"),
        ("b_desc", "BLOCK_SIZE_K", "BLOCK_SIZE_N"),
        ("c_desc", "BLOCK_SIZE_M", "BLOCK_SIZE_N"),
    )
    for desc_key, rows_key, cols_key in layout:
        nargs[desc_key].block_shape = [nargs[rows_key], nargs[cols_key]]

172 

173 

# Batched matrix multiply kernel driven by tensor descriptors.
#
# Tuning: when the environment variable USE_FLAGTUNE is "1", configs and the
# first three strategy entries are read from EXPAND_CONFIG_FILENAME; otherwise
# a single 128x128x64 config with "align32" strategies is used. Every config
# installs bmm_sqmma_descriptor_pre_hook, which sets the descriptors'
# block shapes from the selected BLOCK_SIZE_* values.
#
# Grid: axis 0 enumerates (row-tile, column-tile) pairs with the row-tile index
# varying fastest; axis 1 is the batch index. The `batch` argument is not read
# inside the kernel body.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "bmm_sqmma",
        pre_hook=bmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=bmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("bmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def bmm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    batch,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Row offsets include batch_index * M (for A and C) or batch_index * K
    # (for B): the descriptors are indexed as 2-D tensors with all batches
    # stacked along the row dimension, matching the reshape done in bmm_sqmma.
    # NOTE(review): nothing in this kernel limits a tile to the rows of its own
    # batch. When M is not a multiple of BLOCK_SIZE_M, the last row-tile of a
    # batch appears to extend into the next batch's rows, including for the
    # store below — confirm the caller guarantees divisibility or that the
    # descriptor prevents this.
    offs_am = (pid_m * BLOCK_SIZE_M + batch_index * M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    offs_ak = 0
    offs_ak = offs_ak.to(tl.int32)
    offs_bk = (batch_index * K).to(tl.int32)
    # Accumulate in float32 regardless of the operand dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_ak])
        b = tl.load_tensor_descriptor(b_desc, [offs_bk, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        # Advance both operands by one K-tile.
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    # Cast to the output descriptor's element type and write the tile.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))

230 

231 

def bmm_sqmma(A, B, elem_type, batch, M, N, K):
    """Compute the batched product of ``A`` and ``B`` with ``bmm_sqmma_kernel``.

    Args:
        A: left operand, viewed as (batch * M, K) for the descriptor.
        B: right operand, viewed as (batch * K, N) for the descriptor.
        elem_type: element dtype of the inputs; selects the output dtype.
        batch, M, N, K: problem sizes.

    Returns:
        A new (batch, M, N) tensor on the same device as ``A``. Its dtype is
        ``elem_type``, except that bfloat16 inputs produce a float16 result.
    """
    logger.debug("GEMS_MTHREADS BMM(SQMMA)")
    # NOTE(review): bfloat16 inputs yield a float16 output, which differs from
    # the input dtype — confirm this is an intentional hardware constraint.
    c_type = torch.float16 if elem_type == torch.bfloat16 else elem_type
    # Allocate directly in the output dtype on the device that holds A, rather
    # than on whichever device happens to be current.
    C = torch.empty((batch, M, N), dtype=c_type, device=A.device)
    # Block shapes are placeholders; the autotuner pre-hook overwrites them
    # with the selected BLOCK_SIZE_* values before launch.
    desc_a = TensorDescriptor.from_tensor(A.reshape(batch * M, K), [1, 1])
    desc_b = TensorDescriptor.from_tensor(B.reshape(batch * K, N), [1, 1])
    desc_c = TensorDescriptor.from_tensor(C.reshape(batch * M, N), [1, 1])
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        batch,
        1,
    )
    # Launch on A's device, matching bmm_fma.
    with torch_device_fn.device(A.device):
        bmm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            batch,
            M,
            N,
            K,
        )
    return C

254 

255 

def bmm(a, b):
    """Batched matrix multiply of ``a`` (batch, M, K) by ``b`` (_, _, N).

    Dispatches to ``bmm_sqmma`` when ``is_sqmma_compatible`` accepts the
    operands and ``M`` is at least 128; every other case goes to ``bmm_fma``.
    Only the last dimension of ``b`` is read here; its batch and K sizes are
    not checked against ``a``.
    """
    batch, M, K = a.shape
    _, _, N = b.shape
    take_sqmma_path = is_sqmma_compatible(a, b, N, K) and M >= 128
    if not take_sqmma_path:
        return bmm_fma(a, b)
    return bmm_sqmma(a, b, a.dtype, batch, M, N, K)