Coverage for src/flag_gems/runtime/backend/_ascend/ops/mm.py: 0%
92 statements
Report generated by coverage.py v7.6.9 at 2026-06-05 07:36 +0800 (« prev ^ index » next)
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu
10from flag_gems.utils import libentry, libtuner
# Per-op logger. Its name is the fixed prefix below plus the last dotted
# component of this module's __name__ (i.e. the op's module file name).
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
# Tiled matrix multiply C = A @ B, launched by mm()/mm_out() in this file.
#
# Launch grid, as built by the callers in this file:
#   axis 0 -- one program per (BLOCK_M x BLOCK_N) output tile, flattened;
#   axis 1 -- split-K slice index, of size SPLIT_K (default 1).
# M/N/K and every stride are tl.constexpr, so each distinct value is
# specialized into its own compiled kernel. libtuner picks a config keyed on
# (M, N, K). The heuristics in HEURISTICS_CONFIGS["mm"] fill in the remaining
# constexprs (presumably including EVEN_K -- confirm in heuristics_config_utils).
#
# Parameters:
#   A, B, C       -- pointers to operands addressed as 2-D (M, K), (K, N), (M, N)
#                    matrices via the given row/column strides.
#   dot_out_dtype -- accumulator and tl.dot output dtype (callers pass tl.float32).
#   BLOCK_M/N/K   -- tile sizes.
#   GROUP_M       -- number of tile-rows grouped in the pid -> tile mapping.
#   SPLIT_K       -- number of axis-1 programs that share one K reduction.
#   EVEN_K        -- when true, loads skip the K-bound mask. This is only safe if
#                    K is a multiple of BLOCK_K * SPLIT_K (assumed to be
#                    guaranteed by the heuristic -- TODO confirm).
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["mm"])
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)  # split-K slice handled by this program
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Grouped ordering: GROUP_M consecutive pids cover GROUP_M tile-rows of the
    # same tile-column before moving to the next column. The last group may be
    # shorter (group_size < GROUP_M) when grid_m is not a multiple of GROUP_M.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    ram = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rbn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # This slice starts at K offset pid_z * BLOCK_K. Each iteration below
    # advances by BLOCK_K * SPLIT_K, so the SPLIT_K slices interleave along K.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # Only the M/N edges need masking; the K extent is assumed to divide evenly.
            a = tl.load(A, mask=(ram < M)[:, None], other=0.0)
            b = tl.load(B, mask=(rbn < N)[None, :], other=0.0)
        else:
            # rk < k_remaining  <=>  global K index (iteration base + rk) < K.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(
                A,
                mask=(rk[None, :] < k_remaining) & (ram < M)[:, None],
                other=0.0,
            )
            b = tl.load(
                B,
                mask=(rk[:, None] < k_remaining) & (rbn < N)[None, :],
                other=0.0,
            )
        # Mixed-dtype inputs: cast both tiles to C's element type so tl.dot
        # receives matching operand types.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        # TF32 explicitly disabled for the dot.
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # Each split-K slice adds its partial sum into C.
        # NOTE(review): this assumes C is zero-initialized beforehand, but the
        # callers allocate C with torch.empty or pass a user `out` tensor.
        # Confirm that the "mm" configs never set SPLIT_K > 1, or that they
        # zero C in a pre-hook.
        tl.atomic_add(C, acc, mask=mask)
# Floating-point dtypes understood by get_higher_dtype, lowest rank first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Pick the result dtype for a matmul whose operands have dtypes ``a`` and ``b``.

    Identical dtypes are returned unchanged, even when they are not listed in
    ``_ordered_datatypes``. Otherwise both must be listed, and the one that
    appears later wins: float16 < bfloat16 < float32.

    NOTE(review): ``torch.promote_types(float16, bfloat16)`` is float32, whereas
    this helper returns bfloat16 for that pair.

    Raises:
        AssertionError: if the dtypes differ and either one is not listed.
    """
    if a is b:
        return a
    assert a in _ordered_datatypes
    assert b in _ordered_datatypes
    # Distinct listed dtypes have distinct positions, so there is never a tie.
    return max((a, b), key=_ordered_datatypes.index)
def _prepare_mm_operands(a, b):
    """Validate the operands of ``a @ b`` and return ``(a, b, M, N, K)``.

    ``a`` and ``b`` are treated as 2-D matrices of shape (M, K) and (K, N).
    An operand is copied with ``.contiguous()`` only when neither of its
    strides is 1. Otherwise it is returned as-is, and the kernel reads it
    through its strides.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    # Check shapes first so a mismatched call does not pay for a copy.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    # Handle non-contiguous inputs if necessary.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    M, K = a.shape
    _, N = b.shape
    return a, b, M, N, K


def _launch_mm(a, b, c, M, N, K):
    """Write ``a @ b`` into the preallocated (M, N) tensor ``c`` and return ``c``.

    Degenerate shapes are resolved on the host: an empty output is returned
    untouched, and an empty reduction (K == 0) yields zeros. As a result,
    mm_kernel is never compiled or tuned for a zero-sized problem.
    """
    if M == 0 or N == 0:
        return c
    if K == 0:
        return c.zero_()

    def grid(META):
        # axis 0: one program per output tile; axis 1: split-K slices.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META.get("SPLIT_K", 1),
        )

    # NOTE(review): with SPLIT_K > 1 the kernel accumulates into `c` via
    # atomic_add, which assumes `c` starts zeroed. Confirm the tuned "mm"
    # configs keep SPLIT_K == 1, or zero `c` in a pre-hook.
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=tl.float32,  # accumulate in fp32
            GROUP_M=8,
        )
    return c


def mm(a, b):
    """Return the matrix product ``a @ b`` of (M, K) ``a`` and (K, N) ``b``.

    The result is a new (M, N) tensor on ``a.device``. Its dtype is
    ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: on incompatible inner dimensions or unsupported dtype mixes.
    """
    logger.debug("GEMS_ASCEND MM")
    a, b, M, N, K = _prepare_mm_operands(a, b)
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)
    return _launch_mm(a, b, c, M, N, K)


def mm_out(a, b, *, out):
    """Compute ``a @ b`` into ``out`` and return ``out``.

    If ``out`` does not already have shape (M, N), it is resized first,
    following PyTorch's ``out=`` convention. The kernel addresses ``out``
    through its strides using the inputs' M and N, so a wrongly sized
    ``out`` would otherwise be written out of bounds or only partially.
    The result is cast to ``out``'s dtype inside the kernel.

    Raises:
        AssertionError: on incompatible inner dimensions.
    """
    logger.debug("GEMS_ASCEND MM_OUT")
    a, b, M, N, K = _prepare_mm_operands(a, b)
    if out.shape != (M, N):
        out.resize_((M, N))
    return _launch_mm(a, b, out, M, N, K)