Coverage for src/flag_gems/ops/bmm.py: 38%
96 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
# Batched matrix-multiply kernel: for each batch index b, computes
# O[b] = A[b] @ B[b], where A[b] is (M, K), B[b] is (K, N) and O[b] is (M, N).
#
# Launch layout (as read from the program_id calls below):
#   axis 0 / axis 1 -> tile coordinates over the (M, N) output plane
#   axis 2          -> batch index
#
# Arguments:
#   A, B, O              base pointers of the two inputs and the output.
#   M, N, K              problem sizes.
#   stride_{a,b,o}*      element strides for the batch / row / column
#                        dimensions of each tensor.
#   TILE_M/TILE_N/TILE_K tile extents (compile-time constants).
#   GROUP_M              number of row tiles grouped together when the
#                        program ids are reordered; 1 disables reordering.
#   DIVISIBLE_M/N/K      compile-time flags; when true the corresponding
#                        bounds mask is skipped entirely.
#                        NOTE(review): presumably produced by the heuristics
#                        config as "dimension % tile == 0" - confirm there.
#   IS_FP64              accumulate in float64 instead of float32.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("bmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("bmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        "log",
        "log",
        "log",
        "align32",
        "align32",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_ob,
    stride_om,
    stride_on,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # batch offsets: move all three base pointers to this program's batch
    pid_b = tle.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearise the 2D program id, cut the linear range into
        # groups of (gridy * GROUP_M) programs, and inside each group let the
        # row-tile index vary fastest.
        # NOTE(review): presumably done for cache reuse between neighbouring
        # programs - not verifiable from this file.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group is narrower when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialised for dimensions that need them.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    for _ in range(num_iters):
        # Every masked load passes other=0.0 so that masked-off elements have
        # a defined value. This matters for the K tail: those elements are
        # multiplied and summed into valid outputs by tl.dot, so they must
        # contribute exactly zero. `other` may not be given without a mask,
        # hence the separate unmasked loads.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                a = tl.load(a_ptrs)
            else:
                a = tl.load(a_ptrs, mask=mask_m[:, None], other=0.0)
            if DIVISIBLE_N:
                b = tl.load(b_ptrs)
            else:
                b = tl.load(b_ptrs, mask=mask_n[None, :], other=0.0)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        # Advance to the next K tile.
        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: return ``A @ B`` for every batch entry.

    Args:
        A: tensor of shape ``(batch, M, K)``.
        B: tensor of shape ``(batch, K, N)``.

    Returns:
        A newly allocated tensor of shape ``(batch, M, N)`` with ``A``'s
        dtype, placed on ``A``'s device.

    Raises:
        AssertionError: if either input is not 3-dimensional, or the batch
            or inner (K) dimensions of ``A`` and ``B`` do not agree.
    """
    logger.debug("GEMS BMM")
    # Reject wrong-rank inputs up front; otherwise the shape indexing below
    # fails with an unhelpful IndexError / unpacking ValueError.
    assert A.dim() == 3 and B.dim() == 3, "bmm expects 3D tensors"
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    N = B.shape[2]
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # One program per (TILE_M x TILE_N) output tile, per batch entry.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            IS_FP64=A.dtype == torch.float64,
        )
    return out
def bmm_out(A, B, out):
    """Batched matrix multiply writing the result into a caller-supplied tensor.

    Args:
        A: tensor of shape ``(batch, M, K)``.
        B: tensor of shape ``(batch, K, N)``.
        out: destination tensor of shape ``(batch, M, N)``; it is written
            through its own strides and is not resized.

    Returns:
        ``out``, after the kernel launch that fills it.

    Raises:
        AssertionError: if any tensor is not 3-dimensional, the batch or
            inner (K) dimensions disagree, or ``out`` does not have shape
            ``(batch, M, N)``.
    """
    logger.debug("GEMS BMM_OUT")
    assert (
        A.dim() == 3 and B.dim() == 3 and out.dim() == 3
    ), "bmm_out expects 3D tensors"
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    N = B.shape[2]
    # The kernel stores an (M, N) tile grid per batch entry using out's
    # strides; an `out` with smaller M/N extents would be written outside its
    # bounds, so the full shape must match before launching.
    assert tuple(out.shape) == (batch, M, N), "out shape mismatch"

    def grid_fn(meta):
        # One program per (TILE_M x TILE_N) output tile, per batch entry.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            IS_FP64=A.dtype == torch.float64,
        )
    return out