Coverage for src/flag_gems/ops/bmm.py: 37%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=[
        "log",
        "log",
        "log",
        "align32",
        "align32",
    ],
    flagtune_op_name="bmm",
    flagtune_expand_op_name="bmm",
    flagtune_pre_hook=None,
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_ob,
    stride_om,
    stride_on,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one TILE_M x TILE_N tile of ``O[b] = A[b] @ B[b]``.

    Program axis 2 selects the batch entry; axes 0 and 1 select the output
    tile (remapped when ``GROUP_M != 1``).

    Args:
        A, B, O: base pointers of the two inputs and the output.
        M, N, K: problem sizes; A[b] is M x K, B[b] is K x N, O[b] is M x N.
        stride_a*, stride_b*, stride_o*: element strides for the batch
            (``b``) and the two matrix dimensions of each operand.
        TILE_M, TILE_N, TILE_K: tile extents.
        GROUP_M: number of M-tiles per program group; 1 disables regrouping.
        DIVISIBLE_M, DIVISIBLE_N, DIVISIBLE_K: when true, the matching bounds
            mask is skipped. Presumably supplied by the "bmm" heuristics
            config when the size is a multiple of the tile -- TODO confirm.
        IS_FP64: accumulate in float64 instead of float32.
    """
    # batch offsets
    pid_b = ext.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearize the 2-D program id, then hand out tiles so
        # that each group covers GROUP_M consecutive M-tiles across all
        # N-tiles.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group is narrower when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column bounds masks are only materialized when actually needed.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            # Lanes masked off here belong to rows/columns that are also
            # masked off at the final store, so their value is irrelevant.
            a = tl.load(a_ptrs, mask_a)
            b = tl.load(b_ptrs, mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Lanes past K feed the dot product of *valid* output elements,
            # so they must be an explicit zero rather than the unspecified
            # value a masked load yields without `other`.
            a = tl.load(a_ptrs, mask_a, other=0.0)
            b = tl.load(b_ptrs, mask_b, other=0.0)

        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    # Only write the part of the tile that lies inside the M x N output.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: return a new tensor holding ``A[b] @ B[b]``.

    ``A`` is unpacked as (batch, M, K) and ``B`` as (batch, K, N); the result
    is allocated as (batch, M, N) with ``A``'s dtype on ``A``'s device.

    Raises:
        AssertionError: if the batch or K extents of ``A`` and ``B`` differ.
    """
    logger.debug("GEMS BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _b_batch, _b_k, N = B.shape
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def launch_grid(meta):
        # One program per (M-tile, N-tile, batch entry).
        tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
        tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
        return (tiles_m, tiles_n, batch)

    use_fp64 = A.dtype == torch.float64
    with torch_device_fn.device(A.device):
        # Strides are forwarded in (batch, row, col) order for each operand.
        bmm_kernel[launch_grid](
            A,
            B,
            out,
            M,
            N,
            K,
            *A.stride(),
            *B.stride(),
            *out.stride(),
            IS_FP64=use_fp64,
        )
    return out
def bmm_out(A, B, out):
    """Batched matrix multiply writing ``A[b] @ B[b]`` into ``out``.

    ``A`` is unpacked as (batch, M, K) and ``B`` as (batch, K, N). ``out``
    must already have shape (batch, M, N): the kernel addresses it using M and
    N taken from the inputs, so any other shape would be written out of bounds
    or only partially filled.

    Returns:
        ``out`` itself, after the kernel launch.

    Raises:
        AssertionError: if the batch or K extents disagree, or if ``out`` is
            not shaped (batch, M, N).
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # Guard the launch: the kernel trusts M/N from the inputs when storing.
    assert tuple(out.shape) == (batch, M, N), "Output shape mismatch"

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            IS_FP64=A.dtype == torch.float64,
        )
    return out