Coverage for src/flag_gems/ops/scaled_mm.py: 42%
213 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.fused.cutlass_scaled_mm import cutlass_scaled_mm as _csmm
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger(__name__)

# Number of tile-rows grouped together when the kernels map a 1-D program id
# onto (row-tile, column-tile) coordinates.
GROUP_M = 8

# Scale layouts passed to the kernels as SCALE_A_MODE / SCALE_B_MODE. The
# kernels test `== 0` literally, so SCALAR_SCALE must stay 0.
SCALAR_SCALE, VECTOR_SCALE = 0, 1

# Ascend fast path: M, N and K must be multiples of ASCEND_ALIGNED_BLOCK, the
# aligned kernel tiles with ASCEND_ALIGNED_KERNEL_BLOCK, and problems with
# M * N * K below ASCEND_ALIGNED_MIN_VOLUME stay on the generic kernel.
ASCEND_ALIGNED_BLOCK = 128
ASCEND_ALIGNED_KERNEL_BLOCK = 64
ASCEND_ALIGNED_MIN_VOLUME = 512**3
23def _heur_even_k(args):
24 return args["K"] % args["BLOCK_K"] == 0
# Generic scaled matmul kernel. Each program computes one BLOCK_M x BLOCK_N
# tile of
#     C = (A @ B) * scale_a[:, None] * scale_b[None, :]  (+ bias[None, :])
# where A is (M, K), B is (K, N) and C is (M, N), all addressed through the
# explicit stride arguments. Every load and store is masked, so arbitrary
# M / N / K are supported.
@libentry()
# Block sizes come from runtime.get_tuned_config("scaled_mm"); `key` lists
# the arguments presumably used as the tuning-cache key.
@libtuner(
    configs=runtime.get_tuned_config("scaled_mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
# EVEN_K is True when K is an exact multiple of the chosen BLOCK_K.
@triton.heuristics({"EVEN_K": _heur_even_k})
@triton.jit
def scaled_mm_kernel(
    A,
    B,
    ScaleA,
    ScaleB,
    Bias,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    ACC_DTYPE: tl.constexpr,
    SCALE_A_MODE: tl.constexpr,
    SCALE_B_MODE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Map the 1-D program id to a (pid_m, pid_n) output tile. Consecutive ids
    # first step through up to GROUP_M tile-rows before moving to the next
    # tile-column. This is the grouped ordering from the Triton matmul
    # tutorial, presumably chosen for cache reuse.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may hold fewer than GROUP_M tile-rows.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Widen the row/column offsets to int64, presumably so that
    # offset * stride cannot overflow int32 on large matrices.
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # ACC_DTYPE is chosen by the caller (tl.float32 for floating inputs,
    # tl.int32 otherwise).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            # Every K chunk is full, so only the M / N edges need masking.
            a = tl.load(a_ptrs, mask=offs_m[:, None] < M, other=0.0)
            b = tl.load(b_ptrs, mask=offs_n[None, :] < N, other=0.0)
        else:
            # The final chunk may be partial: zero-fill K positions past the
            # end so they contribute nothing to the dot product.
            k_remaining = K - k * BLOCK_K
            a = tl.load(
                a_ptrs,
                mask=(offs_m[:, None] < M) & (offs_k[None, :] < k_remaining),
                other=0.0,
            )
            b = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_n[None, :] < N),
                other=0.0,
            )
        acc += tl.dot(a, b, out_dtype=ACC_DTYPE, allow_tf32=False)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Scaling and the bias add are done in float32, including for the
    # integer-accumulator case.
    acc = acc.to(tl.float32)

    # Mode 0: one scale value broadcast across the tile. Otherwise: one scale
    # per output row (A) / per output column (B). Out-of-range entries are
    # loaded as 0; those rows/columns are masked out at the store.
    if SCALE_A_MODE == 0:
        scale_a = tl.full((BLOCK_M,), tl.load(ScaleA), dtype=tl.float32)
    else:
        scale_a = tl.load(ScaleA + offs_m, mask=offs_m < M, other=0.0)

    if SCALE_B_MODE == 0:
        scale_b = tl.full((BLOCK_N,), tl.load(ScaleB), dtype=tl.float32)
    else:
        scale_b = tl.load(ScaleB + offs_n, mask=offs_n < N, other=0.0)

    acc = acc * scale_a[:, None] * scale_b[None, :]

    if HAS_BIAS:
        # One bias value per output column, added after scaling.
        bias = tl.load(Bias + offs_n, mask=offs_n < N, other=0.0)
        acc += bias[None, :]

    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # tl.store converts the float32 tile to C's element type.
    tl.store(c_ptrs, acc, mask=c_mask)
# Unmasked variant of scaled_mm_kernel. It computes the same
#     C = (A @ B) * scale_a[:, None] * scale_b[None, :]  (+ bias[None, :])
# but performs no bounds checks on any load or store. The caller must
# guarantee that M, N and K are exact multiples of BLOCK_M, BLOCK_N and
# BLOCK_K. In this module it is only launched when
# _can_use_ascend_aligned_scaled_mm holds (M, N, K multiples of
# ASCEND_ALIGNED_BLOCK = 128, blocks of ASCEND_ALIGNED_KERNEL_BLOCK = 64).
@libentry()
@triton.jit
def scaled_mm_aligned_kernel(
    A,
    B,
    ScaleA,
    ScaleB,
    Bias,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    ACC_DTYPE: tl.constexpr,
    SCALE_A_MODE: tl.constexpr,
    SCALE_B_MODE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Same grouped pid -> (pid_m, pid_n) mapping as scaled_mm_kernel.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # int64 offsets, presumably to keep offset * stride from overflowing int32.
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # No masks: every tile is fully in bounds by the caller's contract.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b, out_dtype=ACC_DTYPE, allow_tf32=False)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    acc = acc.to(tl.float32)

    # Mode 0: a single broadcast scale. Otherwise: per-row / per-column vectors.
    if SCALE_A_MODE == 0:
        scale_a = tl.full((BLOCK_M,), tl.load(ScaleA), dtype=tl.float32)
    else:
        scale_a = tl.load(ScaleA + offs_m)

    if SCALE_B_MODE == 0:
        scale_b = tl.full((BLOCK_N,), tl.load(ScaleB), dtype=tl.float32)
    else:
        scale_b = tl.load(ScaleB + offs_n)

    acc = acc * scale_a[:, None] * scale_b[None, :]

    if HAS_BIAS:
        bias = tl.load(Bias + offs_n)
        acc += bias[None, :]

    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    # tl.store converts the float32 tile to C's element type.
    tl.store(c_ptrs, acc)
200def _resolve_out_dtype(self, out_dtype, out=None):
201 if out_dtype is not None:
202 if out is not None and out.dtype != out_dtype:
203 raise RuntimeError(
204 "out_dtype must be the same as the dtype of the provided out tensor"
205 )
206 return out_dtype
207 if out is not None:
208 return out.dtype
209 return self.dtype
212def _normalize_scale(scale, expected_size, *, is_left_scale):
213 if scale.numel() == 1:
214 return scale.reshape(1).contiguous(), SCALAR_SCALE
216 valid_vector = scale.ndim == 1 and scale.shape[0] == expected_size
217 if is_left_scale:
218 valid_vector = valid_vector or (
219 scale.ndim == 2 and scale.shape == (expected_size, 1)
220 )
221 else:
222 valid_vector = valid_vector or (
223 scale.ndim == 2 and scale.shape == (1, expected_size)
224 )
226 if valid_vector:
227 return scale.reshape(expected_size).contiguous(), VECTOR_SCALE
229 scale_name = "scale_a" if is_left_scale else "scale_b"
230 expected_shape = (
231 f"({expected_size}, 1)" if is_left_scale else f"(1, {expected_size})"
232 )
233 raise RuntimeError(
234 f"{scale_name} must be a scalar tensor or have shape {expected_shape}"
235 )
238def _normalize_bias(bias, cols):
239 if bias is None:
240 return None
241 if bias.numel() != cols:
242 raise RuntimeError(f"Bias must be size {cols} but got {bias.numel()}")
243 return bias.reshape(cols).contiguous()
246def _check_inputs(self, mat2):
247 if self.ndim != 2:
248 raise RuntimeError("self must be a matrix")
249 if mat2.ndim != 2:
250 raise RuntimeError("mat2 must be a matrix")
251 if self.shape[1] != mat2.shape[0]:
252 raise RuntimeError(
253 f"mat1 and mat2 shapes cannot be multiplied ({self.shape[0]}x{self.shape[1]} "
254 f"and {mat2.shape[0]}x{mat2.shape[1]})"
255 )
256 if self.dtype != mat2.dtype:
257 raise RuntimeError(
258 f"self and mat2 must have the same dtype, but got {self.dtype} and {mat2.dtype}"
259 )
262def _maybe_make_contiguous_for_kernel(self, mat2):
263 if self.stride(0) > 1 and self.stride(1) > 1:
264 self = self.contiguous()
265 if mat2.stride(0) > 1 and mat2.stride(1) > 1:
266 mat2 = mat2.contiguous()
267 return self, mat2
270def _can_use_cutlass_scaled_mm(self, mat2, scale_a, scale_b, bias, out):
271 if self.device.type != "cuda":
272 return False
273 is_fp8 = hasattr(torch, "float8_e4m3fn") and self.dtype == torch.float8_e4m3fn
274 if not (is_fp8 or self.dtype == torch.int8):
275 return False
276 if self.dtype != mat2.dtype:
277 return False
278 major, minor = torch.cuda.get_device_capability(self.device)
279 sm_version_num = major * 10 + minor
280 if not (90 <= sm_version_num < 100):
281 return False
282 if scale_a.dtype != torch.float32 or scale_b.dtype != torch.float32:
283 return False
284 if scale_a.numel() not in (1, self.shape[0]):
285 return False
286 if scale_b.numel() not in (1, mat2.shape[1]):
287 return False
288 if not scale_a.is_contiguous() or not scale_b.is_contiguous():
289 return False
290 if self.stride(1) != 1 or out.stride(1) != 1:
291 return False
292 if mat2.stride(0) != 1:
293 return False
294 if out.stride(0) % 16 != 0 or mat2.stride(1) % 16 != 0:
295 return False
296 if bias is not None and (bias.ndim != 1 or not bias.is_contiguous()):
297 return False
298 return True
301def _can_use_ascend_aligned_scaled_mm(self, mat2, out):
302 if self.device.type != "npu" or runtime.device.vendor_name != "ascend":
303 return False
304 if not self.is_floating_point():
305 return False
306 M, K = self.shape
307 _, N = mat2.shape
308 return (
309 M * N * K >= ASCEND_ALIGNED_MIN_VOLUME
310 and M % ASCEND_ALIGNED_BLOCK == 0
311 and N % ASCEND_ALIGNED_BLOCK == 0
312 and K % ASCEND_ALIGNED_BLOCK == 0
313 and self.stride(1) == 1
314 and mat2.stride(1) == 1
315 and out.stride(1) == 1
316 )
def _scaled_mm_impl(
    self,
    mat2,
    scale_a,
    scale_b,
    bias,
    out_dtype,
    out,
):
    """Shared body of ``scaled_mm`` and ``scaled_mm_out``.

    Computes ``(self @ mat2) * scale_a * scale_b (+ bias)``. Each scale is
    either a single value or a vector: per row for ``scale_a``, per column
    for ``scale_b``.

    Backend dispatch, first match wins:
      1. CUTLASS (``_csmm``), if ``_can_use_cutlass_scaled_mm`` accepts the call;
      2. ``scaled_mm_aligned_kernel``, if ``_can_use_ascend_aligned_scaled_mm``
         accepts it;
      3. the autotuned ``scaled_mm_kernel``.

    Returns:
        ``out`` when supplied, otherwise a newly allocated ``(M, N)`` tensor.

    Raises:
        RuntimeError: on invalid operands, scales, bias, or output shape/dtype.
    """
    _check_inputs(self, mat2)
    M, K = self.shape
    N = mat2.shape[1]

    output_dtype = _resolve_out_dtype(self, out_dtype, out)
    if out is not None and out.shape != (M, N):
        raise RuntimeError("Incompatible output shape")
    if out is None:
        out = torch.empty((M, N), dtype=output_dtype, device=self.device)

    scale_a, scale_a_mode = _normalize_scale(scale_a, M, is_left_scale=True)
    scale_b, scale_b_mode = _normalize_scale(scale_b, N, is_left_scale=False)
    bias = _normalize_bias(bias, N)

    # An empty output needs no kernel launch.
    if M == 0 or N == 0:
        return out

    if _can_use_cutlass_scaled_mm(self, mat2, scale_a, scale_b, bias, out):
        with torch_device_fn.device(self.device):
            _csmm(out, self, mat2, scale_a, scale_b, bias)
        return out

    self, mat2 = _maybe_make_contiguous_for_kernel(self, mat2)
    # Positional arguments shared by both Triton kernels, in signature order.
    launch_args = (
        self,
        mat2,
        scale_a,
        scale_b,
        bias,
        out,
        M,
        N,
        K,
        self.stride(0),
        self.stride(1),
        mat2.stride(0),
        mat2.stride(1),
        out.stride(0),
        out.stride(1),
    )
    meta_kwargs = dict(
        ACC_DTYPE=tl.float32 if self.is_floating_point() else tl.int32,
        SCALE_A_MODE=scale_a_mode,
        SCALE_B_MODE=scale_b_mode,
        HAS_BIAS=bias is not None,
    )

    def tuned_grid(META):
        # One program per BLOCK_M x BLOCK_N output tile of the tuned config.
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(self.device):
        if _can_use_ascend_aligned_scaled_mm(self, mat2, out):
            tile = ASCEND_ALIGNED_KERNEL_BLOCK
            tile_count = triton.cdiv(M, tile) * triton.cdiv(N, tile)
            scaled_mm_aligned_kernel[(tile_count,)](
                *launch_args,
                **meta_kwargs,
                BLOCK_M=tile,
                BLOCK_N=tile,
                BLOCK_K=tile,
                GROUP_M=GROUP_M,
            )
        else:
            scaled_mm_kernel[tuned_grid](
                *launch_args,
                **meta_kwargs,
                GROUP_M=GROUP_M,
            )
    return out
def scaled_mm(
    self,
    mat2,
    scale_a,
    scale_b,
    bias=None,
    scale_result=None,
    out_dtype=None,
    use_fast_accum=False,
):
    """Scaled matrix multiply that allocates and returns a new result tensor.

    The argument list appears to follow ``torch._scaled_mm``.
    ``scale_result`` and ``use_fast_accum`` are accepted but not used here.
    """
    logger.debug("GEMS SCALED_MM")
    return _scaled_mm_impl(
        self,
        mat2,
        scale_a,
        scale_b,
        bias,
        out_dtype,
        out=None,
    )
def scaled_mm_out(
    self,
    mat2,
    scale_a,
    scale_b,
    bias=None,
    scale_result=None,
    out_dtype=None,
    use_fast_accum=False,
    *,
    out,
):
    """Scaled matrix multiply that writes into the caller-provided ``out``.

    ``out`` must have shape ``(M, N)``. If ``out_dtype`` is given, it must
    match ``out.dtype``. ``scale_result`` and ``use_fast_accum`` are accepted
    but not used here. Returns ``out``.
    """
    logger.debug("GEMS SCALED_MM_OUT")
    return _scaled_mm_impl(
        self,
        mat2,
        scale_a,
        scale_b,
        bias,
        out_dtype,
        out=out,
    )