Coverage for src/flag_gems/ops/scaled_mm.py: 42%

213 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.fused.cutlass_scaled_mm import cutlass_scaled_mm as _csmm 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

logger = logging.getLogger(__name__)

# Tile-row group size used by both kernels when remapping program ids
# (grouped tile ordering).
GROUP_M = 8

# Values passed to the kernels as SCALE_A_MODE / SCALE_B_MODE:
# a single broadcast scalar vs. a per-row / per-column vector.
SCALAR_SCALE, VECTOR_SCALE = 0, 1

# Ascend (NPU) fast path: M, N and K must all be multiples of
# ASCEND_ALIGNED_BLOCK, the kernel itself runs with square tiles of
# ASCEND_ALIGNED_KERNEL_BLOCK, and the path is only taken once M * N * K
# reaches ASCEND_ALIGNED_MIN_VOLUME.
ASCEND_ALIGNED_BLOCK = 128
ASCEND_ALIGNED_KERNEL_BLOCK = 64
ASCEND_ALIGNED_MIN_VOLUME = 512**3

21 

22 

23def _heur_even_k(args): 

24 return args["K"] % args["BLOCK_K"] == 0 

25 

26 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("scaled_mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
@triton.heuristics({"EVEN_K": _heur_even_k})
@triton.jit
def scaled_mm_kernel(
    A,
    B,
    ScaleA,
    ScaleB,
    Bias,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    ACC_DTYPE: tl.constexpr,
    SCALE_A_MODE: tl.constexpr,
    SCALE_B_MODE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = (A @ B) * scale_a * scale_b (+ bias).

    General (masked) kernel: handles any M, N, K. BLOCK_M / BLOCK_N / BLOCK_K
    come from the libtuner configs. EVEN_K comes from ``_heur_even_k``.
    M, N, K and all strides are tl.constexpr, so each distinct shape/stride
    combination is compiled as its own specialization.

    SCALE_A_MODE / SCALE_B_MODE: 0 means a single scalar broadcast over the
    tile. Any other value means a per-row (A) / per-column (B) vector indexed
    by the output row / column.
    """
    # Grouped tile ordering: program ids are remapped so that, for each
    # tile-column, up to GROUP_M consecutive tile-rows are visited before
    # moving on (same scheme as Triton's matmul tutorial).
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may contain fewer than GROUP_M tile-rows.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Widened to int64, presumably so offs * stride cannot overflow int32 on
    # large tensors.
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # The accumulator uses ACC_DTYPE. The host side picks float32 for
    # floating-point inputs and int32 otherwise.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            # K is a multiple of BLOCK_K, so only the M/N edges need masking.
            a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
            b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
        else:
            # Ragged last K tile: zero-fill the columns/rows past K.
            k_remaining = K - k * BLOCK_K
            a = tl.load(
                a_ptrs,
                mask=(offs_m[:, None] < M) & (offs_k[None, :] < k_remaining),
                other=0.0,
            )
            b = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_n[None, :] < N),
                other=0.0,
            )
        acc += tl.dot(a, b, out_dtype=ACC_DTYPE, allow_tf32=False)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Scaling and bias are applied in float32 regardless of ACC_DTYPE.
    acc = acc.to(tl.float32)

    # Literal 0 here corresponds to the module-level SCALAR_SCALE flag.
    if SCALE_A_MODE == 0:
        scale_a = tl.full((BLOCK_M,), tl.load(ScaleA), dtype=tl.float32)
    else:
        scale_a = tl.load(ScaleA + offs_m, mask=offs_m < M, other=0.0)

    if SCALE_B_MODE == 0:
        scale_b = tl.full((BLOCK_N,), tl.load(ScaleB), dtype=tl.float32)
    else:
        scale_b = tl.load(ScaleB + offs_n, mask=offs_n < N, other=0.0)

    acc = acc * scale_a[:, None] * scale_b[None, :]

    if HAS_BIAS:
        # Bias is a length-N vector added to every output row.
        bias = tl.load(Bias + offs_n, mask=offs_n < N, other=0.0)
        acc += bias[None, :]

    # tl.store converts the float32 tile to C's element type.
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

123 

124 

@libentry()
@triton.jit
def scaled_mm_aligned_kernel(
    A,
    B,
    ScaleA,
    ScaleB,
    Bias,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    ACC_DTYPE: tl.constexpr,
    SCALE_A_MODE: tl.constexpr,
    SCALE_B_MODE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Unmasked variant of ``scaled_mm_kernel`` for block-aligned shapes.

    Computes the same tile C = (A @ B) * scale_a * scale_b (+ bias) but
    performs no bounds masking on any load or store. The caller must
    guarantee that M, N and K are exact multiples of BLOCK_M / BLOCK_N /
    BLOCK_K. The host dispatch launches this only on the Ascend path, where
    the dims are multiples of 128 and the blocks are 64. Not autotuned: the
    block sizes are passed explicitly by the caller.
    """
    # Grouped tile ordering, identical to scaled_mm_kernel.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # No masks: relies on K % BLOCK_K == 0 and in-range M/N tiles.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b, out_dtype=ACC_DTYPE, allow_tf32=False)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    acc = acc.to(tl.float32)

    # Literal 0 here corresponds to the module-level SCALAR_SCALE flag.
    if SCALE_A_MODE == 0:
        scale_a = tl.full((BLOCK_M,), tl.load(ScaleA), dtype=tl.float32)
    else:
        scale_a = tl.load(ScaleA + offs_m)

    if SCALE_B_MODE == 0:
        scale_b = tl.full((BLOCK_N,), tl.load(ScaleB), dtype=tl.float32)
    else:
        scale_b = tl.load(ScaleB + offs_n)

    acc = acc * scale_a[:, None] * scale_b[None, :]

    if HAS_BIAS:
        bias = tl.load(Bias + offs_n)
        acc += bias[None, :]

    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)

198 

199 

200def _resolve_out_dtype(self, out_dtype, out=None): 

201 if out_dtype is not None: 

202 if out is not None and out.dtype != out_dtype: 

203 raise RuntimeError( 

204 "out_dtype must be the same as the dtype of the provided out tensor" 

205 ) 

206 return out_dtype 

207 if out is not None: 

208 return out.dtype 

209 return self.dtype 

210 

211 

212def _normalize_scale(scale, expected_size, *, is_left_scale): 

213 if scale.numel() == 1: 

214 return scale.reshape(1).contiguous(), SCALAR_SCALE 

215 

216 valid_vector = scale.ndim == 1 and scale.shape[0] == expected_size 

217 if is_left_scale: 

218 valid_vector = valid_vector or ( 

219 scale.ndim == 2 and scale.shape == (expected_size, 1) 

220 ) 

221 else: 

222 valid_vector = valid_vector or ( 

223 scale.ndim == 2 and scale.shape == (1, expected_size) 

224 ) 

225 

226 if valid_vector: 

227 return scale.reshape(expected_size).contiguous(), VECTOR_SCALE 

228 

229 scale_name = "scale_a" if is_left_scale else "scale_b" 

230 expected_shape = ( 

231 f"({expected_size}, 1)" if is_left_scale else f"(1, {expected_size})" 

232 ) 

233 raise RuntimeError( 

234 f"{scale_name} must be a scalar tensor or have shape {expected_shape}" 

235 ) 

236 

237 

238def _normalize_bias(bias, cols): 

239 if bias is None: 

240 return None 

241 if bias.numel() != cols: 

242 raise RuntimeError(f"Bias must be size {cols} but got {bias.numel()}") 

243 return bias.reshape(cols).contiguous() 

244 

245 

246def _check_inputs(self, mat2): 

247 if self.ndim != 2: 

248 raise RuntimeError("self must be a matrix") 

249 if mat2.ndim != 2: 

250 raise RuntimeError("mat2 must be a matrix") 

251 if self.shape[1] != mat2.shape[0]: 

252 raise RuntimeError( 

253 f"mat1 and mat2 shapes cannot be multiplied ({self.shape[0]}x{self.shape[1]} " 

254 f"and {mat2.shape[0]}x{mat2.shape[1]})" 

255 ) 

256 if self.dtype != mat2.dtype: 

257 raise RuntimeError( 

258 f"self and mat2 must have the same dtype, but got {self.dtype} and {mat2.dtype}" 

259 ) 

260 

261 

262def _maybe_make_contiguous_for_kernel(self, mat2): 

263 if self.stride(0) > 1 and self.stride(1) > 1: 

264 self = self.contiguous() 

265 if mat2.stride(0) > 1 and mat2.stride(1) > 1: 

266 mat2 = mat2.contiguous() 

267 return self, mat2 

268 

269 

270def _can_use_cutlass_scaled_mm(self, mat2, scale_a, scale_b, bias, out): 

271 if self.device.type != "cuda": 

272 return False 

273 is_fp8 = hasattr(torch, "float8_e4m3fn") and self.dtype == torch.float8_e4m3fn 

274 if not (is_fp8 or self.dtype == torch.int8): 

275 return False 

276 if self.dtype != mat2.dtype: 

277 return False 

278 major, minor = torch.cuda.get_device_capability(self.device) 

279 sm_version_num = major * 10 + minor 

280 if not (90 <= sm_version_num < 100): 

281 return False 

282 if scale_a.dtype != torch.float32 or scale_b.dtype != torch.float32: 

283 return False 

284 if scale_a.numel() not in (1, self.shape[0]): 

285 return False 

286 if scale_b.numel() not in (1, mat2.shape[1]): 

287 return False 

288 if not scale_a.is_contiguous() or not scale_b.is_contiguous(): 

289 return False 

290 if self.stride(1) != 1 or out.stride(1) != 1: 

291 return False 

292 if mat2.stride(0) != 1: 

293 return False 

294 if out.stride(0) % 16 != 0 or mat2.stride(1) % 16 != 0: 

295 return False 

296 if bias is not None and (bias.ndim != 1 or not bias.is_contiguous()): 

297 return False 

298 return True 

299 

300 

301def _can_use_ascend_aligned_scaled_mm(self, mat2, out): 

302 if self.device.type != "npu" or runtime.device.vendor_name != "ascend": 

303 return False 

304 if not self.is_floating_point(): 

305 return False 

306 M, K = self.shape 

307 _, N = mat2.shape 

308 return ( 

309 M * N * K >= ASCEND_ALIGNED_MIN_VOLUME 

310 and M % ASCEND_ALIGNED_BLOCK == 0 

311 and N % ASCEND_ALIGNED_BLOCK == 0 

312 and K % ASCEND_ALIGNED_BLOCK == 0 

313 and self.stride(1) == 1 

314 and mat2.stride(1) == 1 

315 and out.stride(1) == 1 

316 ) 

317 

318 

def _scaled_mm_impl(
    self,
    mat2,
    scale_a,
    scale_b,
    bias,
    out_dtype,
    out,
):
    """Shared body of ``scaled_mm`` / ``scaled_mm_out``.

    Computes ``(self @ mat2) * scale_a * scale_b (+ bias)`` into ``out``.
    ``out`` is allocated when it is ``None``; otherwise its shape is checked.

    Dispatch order:
    1. The fused CUTLASS op, when ``_can_use_cutlass_scaled_mm`` accepts
       the inputs.
    2. Otherwise the unmasked ``scaled_mm_aligned_kernel`` on aligned
       Ascend shapes.
    3. Otherwise the autotuned ``scaled_mm_kernel``.

    Raises:
        RuntimeError: on invalid operands, scales, bias, or output shape.
    """
    _check_inputs(self, mat2)
    M, K = self.shape
    _, N = mat2.shape

    resolved_dtype = _resolve_out_dtype(self, out_dtype, out)
    if out is None:
        out = torch.empty((M, N), dtype=resolved_dtype, device=self.device)
    elif out.shape != (M, N):
        raise RuntimeError("Incompatible output shape")

    scale_a, scale_a_mode = _normalize_scale(scale_a, M, is_left_scale=True)
    scale_b, scale_b_mode = _normalize_scale(scale_b, N, is_left_scale=False)
    bias = _normalize_bias(bias, N)

    # Empty output: nothing to compute.
    if M == 0 or N == 0:
        return out

    if _can_use_cutlass_scaled_mm(self, mat2, scale_a, scale_b, bias, out):
        with torch_device_fn.device(self.device):
            _csmm(out, self, mat2, scale_a, scale_b, bias)
        return out

    self, mat2 = _maybe_make_contiguous_for_kernel(self, mat2)

    # Arguments common to both Triton kernels.
    launch_args = (
        self,
        mat2,
        scale_a,
        scale_b,
        bias,
        out,
        M,
        N,
        K,
        self.stride(0),
        self.stride(1),
        mat2.stride(0),
        mat2.stride(1),
        out.stride(0),
        out.stride(1),
    )
    launch_meta = dict(
        ACC_DTYPE=tl.float32 if self.is_floating_point() else tl.int32,
        SCALE_A_MODE=scale_a_mode,
        SCALE_B_MODE=scale_b_mode,
        HAS_BIAS=bias is not None,
        GROUP_M=GROUP_M,
    )

    with torch_device_fn.device(self.device):
        if _can_use_ascend_aligned_scaled_mm(self, mat2, out):
            tile = ASCEND_ALIGNED_KERNEL_BLOCK
            aligned_grid = (triton.cdiv(M, tile) * triton.cdiv(N, tile),)
            scaled_mm_aligned_kernel[aligned_grid](
                *launch_args,
                BLOCK_M=tile,
                BLOCK_N=tile,
                BLOCK_K=tile,
                **launch_meta,
            )
        else:

            def grid(meta):
                # One program per output tile; block sizes come from the tuner.
                return (
                    triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
                )

            scaled_mm_kernel[grid](*launch_args, **launch_meta)
    return out

409 

410 

def scaled_mm(
    self,
    mat2,
    scale_a,
    scale_b,
    bias=None,
    scale_result=None,
    out_dtype=None,
    use_fast_accum=False,
):
    """Return a new tensor ``(self @ mat2) * scale_a * scale_b (+ bias)``.

    ``scale_result`` and ``use_fast_accum`` are accepted for signature
    compatibility but are not forwarded to the implementation.
    NOTE(review): a non-None ``scale_result`` is silently ignored. Confirm
    this is acceptable for fp8-output callers.
    """
    logger.debug("GEMS SCALED_MM")
    return _scaled_mm_impl(
        self, mat2, scale_a, scale_b, bias, out_dtype=out_dtype, out=None
    )

423 

424 

def scaled_mm_out(
    self,
    mat2,
    scale_a,
    scale_b,
    bias=None,
    scale_result=None,
    out_dtype=None,
    use_fast_accum=False,
    *,
    out,
):
    """Write ``(self @ mat2) * scale_a * scale_b (+ bias)`` into ``out`` and return it.

    ``out`` must have shape (M, N). If ``out_dtype`` is given, it must match
    ``out.dtype``. ``scale_result`` and ``use_fast_accum`` are accepted for
    signature compatibility but are not forwarded to the implementation.
    """
    logger.debug("GEMS SCALED_MM_OUT")
    return _scaled_mm_impl(
        self, mat2, scale_a, scale_b, bias, out_dtype=out_dtype, out=out
    )