Coverage for src/flag_gems/ops/baddbmm.py: 31%
160 statements
coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from .. import runtime
9from ..runtime import torch_device_fn
10from ..utils import libentry, libtuner
11from ..utils import triton_lang_extension as ext
12from .bmm import bmm
13from .mul import mul
# Module-level logger; the forward, backward and ``.out`` entry points in this
# module emit their debug messages through it.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    # With USE_FLAGTUNE=1 the config space and search strategy come from the
    # runtime's expanded config; otherwise the pre-tuned baddbmm list is used.
    configs=runtime.ops_get_configs("baddbmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("baddbmm"),
    # Tuning is keyed on the problem sizes.
    key=["M", "N", "K"],
    # NOTE(review): presumably one strategy entry per element of `key`
    # (M, N, K); confirm against libtuner.
    strategy=runtime.get_expand_config("baddbmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    # Benchmark warmup / repetition counts handed to libtuner.
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
# alpha and beta are runtime scalars; do_not_specialize keeps Triton from
# specializing the compiled kernel on their values.
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,  # batch b's (M, K) block starts at b * M * K, read row-major
    B,  # batch b's (K, N) block starts at b * K * N, read row-major
    O,  # batch b's (M, N) block starts at b * M * N, written row-major
    bias,  # read through the explicit bias_*_stride values below
    alpha,  # scale applied to the A @ B product
    beta,  # scale applied to bias
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # 1 = direct (pidx, pidy) mapping; >1 = grouped order
    DIVISIBLE_M: tl.constexpr,  # True skips the M-boundary mask
    DIVISIBLE_N: tl.constexpr,  # True skips the N-boundary mask
    DIVISIBLE_K: tl.constexpr,  # True skips the K-boundary mask
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
    IS_FP64: tl.constexpr = False,  # accumulate in float64 instead of float32
):
    # Computes one (TILE_M, TILE_N) tile of
    #     O[b] = alpha * (A[b] @ B[b]) + beta * bias[b]
    # Grid axis 2 selects the batch b; axes 0/1 select the tile along M/N.
    # NOTE(review): the DIVISIBLE_* flags are presumably filled in by the
    # "baddbmm" heuristics config (M/N/K multiples of the tile sizes) - confirm.

    # Move every base pointer to this program's batch entry.
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Grouped ordering: linearize (pidx, pidy), then let consecutive
        # programs walk GROUP_SIZE M-tiles for one N-tile before moving to the
        # next N-tile (presumably for cache reuse). The last group is shorter
        # when gridx is not a multiple of GROUP_M.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Boundary masks are only built for dimensions that are not tile-aligned.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    # Row-major tile pointers: A is (M, K), B is (K, N), O is (M, N).
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    # Accumulate A-tile @ B-tile over K in TILE_K steps.
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        # allow_tf32=False disables TF32 in tl.dot.
        accumulator += tl.dot(a, b, allow_tf32=False)
        offs_k += TILE_K
        # Step A right by TILE_K columns and B down by TILE_K rows.
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    # Bias tile addressed purely through its strides (a stride of 0 repeats
    # the same values along that axis).
    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    # Output/bias mask: None when both M and N are tile-aligned, otherwise it
    # drops out-of-range rows and/or columns.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    # NOTE(review): bias is scaled by beta even when beta == 0, so NaN/inf in
    # bias reach the output here (0 * nan = nan); torch.baddbmm ignores its
    # input in that case, so callers must handle beta == 0.
    out = accumulator * alpha + bi * beta
    # The result is cast to the bias element type before the store.
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)
def _baddbmm_launch(bias, A, B, beta, alpha, out):
    """Run ``baddbmm_kernel`` to write ``beta * bias + alpha * (A @ B)`` into ``out``.

    Args:
        bias: Tensor broadcastable to ``(batch, M, N)``.
        A: ``(batch, M, K)`` tensor.
        B: ``(batch, K, N)`` tensor.
        beta: Scalar multiplier for ``bias``. When it equals zero, ``bias`` is
            replaced by a zero scalar so NaN/inf in it are not propagated,
            matching ``torch.baddbmm``.
        alpha: Scalar multiplier for ``A @ B``.
        out: Preallocated ``(batch, M, N)`` tensor. The kernel writes it with
            dense row-major offsets, so it must be contiguous.
    """
    batch, M, K = A.shape
    _, _, N = B.shape
    # A and B are addressed with dense row-major offsets inside the kernel.
    A = A.contiguous()
    B = B.contiguous()

    if beta == 0:
        # torch.baddbmm ignores ``input`` when beta == 0. Feed a zero bias so
        # 0 * nan / 0 * inf inside the kernel cannot leak into the result.
        bias = torch.zeros((), dtype=bias.dtype, device=bias.device)

    # The kernel reads bias through explicit strides, so a broadcast *view* is
    # enough: broadcast axes get stride 0 and no (batch, M, N) copy is made.
    bbias = torch.broadcast_to(bias, (batch, M, N))
    bias_batch_stride, bias_M_stride, bias_N_stride = bbias.stride()

    # Grid: (tiles along M, tiles along N, batch); batch is program axis 2.
    grid = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        baddbmm_kernel[grid](
            A,
            B,
            out,
            bbias,
            alpha,
            beta,
            M,
            N,
            K,
            bias_batch_stride=bias_batch_stride,
            bias_M_stride=bias_M_stride,
            bias_N_stride=bias_N_stride,
        )
class BaddbmmFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton baddbmm launch.

    Forward computes ``beta * bias + alpha * (A @ B)``. Backward returns
    gradients for ``bias``, ``A`` and ``B``; ``beta`` and ``alpha`` receive none.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        logger.debug("GEMS BADDBMM FORWARD")

        # Everything backward needs: the three tensors plus both scales.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha, ctx.beta = alpha, beta

        batch, M, _ = A.shape
        _, _, N = B.shape
        result = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)
        _baddbmm_launch(bias, A, B, beta, alpha, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors
        wants_bias, wants_A, wants_B = ctx.needs_input_grad[:3]

        grad_bias = (
            compute_bias_grad(grad_output, ctx.beta, bias) if wants_bias else None
        )
        grad_A = compute_A_grad(grad_output, B, ctx.alpha) if wants_A else None
        grad_B = compute_B_grad(A, grad_output, ctx.alpha) if wants_B else None

        # One entry per forward input: (bias, A, B, beta, alpha).
        return grad_bias, grad_A, grad_B, None, None
def compute_bias_grad(d_output, beta, bias):
    """Return the gradient w.r.t. ``bias``: ``beta * d_output`` reduced to ``bias.shape``.

    ``bias`` may have been broadcast up to the output shape in the forward
    pass. Any leading dimensions it lacks, and any axis where it had size 1,
    are summed away.
    """
    grad = mul(d_output, beta)
    if grad.shape == bias.shape:
        return grad.view(bias.shape)

    # Drop leading dimensions bias does not have, one at a time.
    for _ in range(grad.dim() - bias.dim()):
        grad = grad.sum(dim=0)

    # Collapse axes where bias was broadcast from size 1, keeping the axis.
    for axis, size in enumerate(bias.shape):
        if size == 1 and grad.shape[axis] > 1:
            grad = grad.sum(dim=axis, keepdim=True)

    return grad.view(bias.shape)
def compute_A_grad(d_output, B, alpha):
    """Return the gradient w.r.t. ``A``: ``alpha * (d_output @ B^T)`` per batch.

    For float16 ``B``, the product and scaling are carried out in float32 and
    the result is cast back to float16.
    """
    B_T = B.transpose(1, 2)
    if B.dtype != torch.float16:
        return mul(bmm(d_output, B_T), alpha)

    B_T32 = B_T.to(torch.float32)
    product = bmm(d_output.to(torch.float32), B_T32)
    return mul(product, alpha).to(torch.float16)
def compute_B_grad(A, d_output, alpha):
    """Return the gradient w.r.t. ``B``: ``alpha * (A^T @ d_output)`` per batch.

    For float16 ``A``, the product and scaling are carried out in float32 and
    the result is cast back to float16.
    """
    A_T = A.transpose(1, 2)
    if A.dtype != torch.float16:
        return mul(bmm(A_T, d_output), alpha)

    product = bmm(A_T.to(torch.float32), d_output.to(torch.float32))
    return mul(product, alpha).to(torch.float16)
def baddbmm_out(bias, A, B, *, beta=1.0, alpha=1.0, out):
    """Compute ``beta * bias + alpha * (A @ B)`` into the caller-provided ``out``.

    Args:
        bias: Tensor broadcastable to ``(batch, M, N)``.
        A: ``(batch, M, K)`` tensor.
        B: ``(batch, K, N)`` tensor.
        beta: Scalar multiplier for ``bias``.
        alpha: Scalar multiplier for ``A @ B``.
        out: Destination with shape ``(batch, M, N)`` and dtype ``A.dtype``.
            It may be non-contiguous.

    Returns:
        ``out``.

    Raises:
        AssertionError: if ``out`` has the wrong shape or dtype. This is an
            explicit ``raise`` rather than ``assert`` so the check still runs
            under ``python -O``; the exception type is kept for existing callers.
    """
    logger.debug("GEMS BADDBMM_OUT")
    batch, M, _ = A.shape
    _, _, N = B.shape
    if out.shape != (batch, M, N) or out.dtype != A.dtype:
        raise AssertionError("Incompatible output shape or dtype for baddbmm.out")

    # The kernel stores with dense row-major offsets, so writing straight into
    # a non-contiguous ``out`` (e.g. a transposed view) would scramble the
    # result. In that case compute into a contiguous buffer and copy back.
    if out.is_contiguous():
        dest = out
    else:
        dest = torch.empty_like(out, memory_format=torch.contiguous_format)

    _baddbmm_launch(
        bias.contiguous(),
        A.contiguous(),
        B.contiguous(),
        beta,
        alpha,
        dest,
    )
    if dest is not out:
        out.copy_(dest)
    return out
def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Differentiable ``beta * bias + alpha * (A @ B)`` over batched matrices.

    The three tensor operands are made contiguous before entering
    ``BaddbmmFunction``; ``beta`` and ``alpha`` are passed through unchanged.
    """
    bias_c, A_c, B_c = (t.contiguous() for t in (bias, A, B))
    return BaddbmmFunction.apply(bias_c, A_c, B_c, beta, alpha)