Coverage for src/flag_gems/ops/baddbmm.py: 30%
159 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from .. import runtime
8from ..runtime import torch_device_fn
9from ..utils import libentry, libtuner
10from ..utils import triton_lang_extension as ext
11from .bmm import bmm
12from .mul import mul
14logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="baddbmm",
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Computes one (TILE_M, TILE_N) tile of O[b] = alpha * (A[b] @ B[b]) + beta * bias[b].
    #
    # Addressing used below:
    #   A    -- dense row-major (batch, M, K): element (b, m, k) at b*M*K + m*K + k
    #   B    -- dense row-major (batch, K, N): element (b, k, n) at b*K*N + k*N + n
    #   O    -- dense row-major (batch, M, N): element (b, m, n) at b*M*N + m*N + n
    #   bias -- addressed through bias_batch_stride / bias_M_stride / bias_N_stride
    #
    # Program axis 2 selects the batch; axes 0 and 1 select the output tile.
    # DIVISIBLE_M / DIVISIBLE_N / DIVISIBLE_K switch off the corresponding
    # bounds masks (presumably supplied by the "baddbmm" heuristic config --
    # confirm against runtime.get_heuristic_config).
    # NOTE(review): IS_FP64 selects a float64 accumulator, but the launcher
    # visible in this file never passes it -- confirm where it is expected to
    # be set for float64 inputs.

    # batch offsets
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Re-map the linearised program id so that consecutive programs walk
        # GROUP_M tiles along M before advancing along N.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group may hold fewer than GROUP_M tiles along M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            # Rows/columns masked here only feed output elements that are
            # masked again at the final store, so their loaded values never
            # reach memory and no fill value is needed.
            a = tl.load(a_ptrs, mask=mask_a)
            b = tl.load(b_ptrs, mask=mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Out-of-range K lanes are summed into *valid* output elements by
            # tl.dot, so they must be explicitly zero rather than whatever a
            # masked load without `other` happens to return.
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)
        accumulator += tl.dot(a, b, allow_tf32=False)
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha + bi * beta
    # The stored value takes the bias element type.
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)
def _baddbmm_launch(bias, A, B, beta, alpha, out):
    """Launch ``baddbmm_kernel`` to fill ``out`` from ``bias``, ``A`` and ``B``.

    Args:
        bias: Tensor broadcastable to ``(batch, M, N)``.
        A: 3-D tensor of shape ``(batch, M, K)``.
        B: 3-D tensor whose last dimension is ``N``.
        beta: Scalar forwarded to the kernel as the bias multiplier.
        alpha: Scalar forwarded to the kernel as the product multiplier.
        out: Destination tensor of shape ``(batch, M, N)``; written in place.

    Returns:
        None. The result is stored in ``out``.
    """
    batch, M, K = A.shape
    _, _, N = B.shape

    # Nothing to write for an empty result; skip launching a zero-sized grid.
    if out.numel() == 0:
        return

    A = A.contiguous()
    B = B.contiguous()
    # The broadcast bias is materialised so that it has regular strides.
    bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()

    # The kernel addresses the output as a dense row-major (batch, M, N)
    # buffer. If the caller's ``out`` has any other layout, compute into a
    # dense scratch tensor and copy back afterwards instead of scattering
    # values to the wrong elements.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty((batch, M, N), dtype=out.dtype, device=out.device)

    def grid(meta):
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        baddbmm_kernel[grid](
            A,
            B,
            target,
            bbias,
            alpha,
            beta,
            M,
            N,
            K,
            bias_batch_stride=bbias.stride(0),
            bias_M_stride=bbias.stride(1),
            bias_N_stride=bbias.stride(-1),
        )

    if target is not out:
        out.copy_(target)
class BaddbmmFunction(torch.autograd.Function):
    """Autograd function pairing the baddbmm launcher with its gradients."""

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Allocate the ``(batch, M, N)`` result and fill it via the launcher.

        ``A``, ``B`` and ``bias`` are saved for the backward pass; ``alpha``
        and ``beta`` are stashed on ``ctx`` as plain attributes.
        """
        logger.debug("GEMS BADDBMM FORWARD")

        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        n_batch, n_rows, _ = A.shape
        _, _, n_cols = B.shape
        result = torch.empty((n_batch, n_rows, n_cols), dtype=A.dtype, device=A.device)
        _baddbmm_launch(bias, A, B, beta, alpha, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(bias, A, B, beta, alpha)``.

        Only the gradients flagged in ``ctx.needs_input_grad`` are computed;
        the rest, and the two scalar slots, are ``None``.
        """
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors
        want_bias, want_A, want_B = ctx.needs_input_grad[:3]

        grad_bias = (
            compute_bias_grad(grad_output, ctx.beta, bias) if want_bias else None
        )
        grad_A = compute_A_grad(grad_output, B, ctx.alpha) if want_A else None
        grad_B = compute_B_grad(A, grad_output, ctx.alpha) if want_B else None

        return grad_bias, grad_A, grad_B, None, None
def compute_bias_grad(d_output, beta, bias):
    """Return ``beta * d_output`` reduced to the shape of ``bias``.

    Extra leading dimensions are summed away first; then every axis where
    ``bias`` has extent 1 but the gradient does not is summed with
    ``keepdim=True``.
    """
    scaled = mul(d_output, beta)
    wanted = bias.shape
    if scaled.shape == wanted:
        return scaled.view(wanted)

    # Drop leading axes that bias does not have.
    while scaled.dim() > bias.dim():
        scaled = scaled.sum(dim=0)

    # Collapse axes along which bias was broadcast.
    for axis, extent in enumerate(wanted):
        if extent == 1 and scaled.shape[axis] > 1:
            scaled = scaled.sum(dim=axis, keepdim=True)

    return scaled.view(wanted)
def compute_A_grad(d_output, B, alpha):
    """Return ``alpha * (d_output @ B^T)``, transposing B's last two dims.

    When ``B`` is float16, both operands are promoted to float32 for the
    product and the scaling, and the result is cast back to float16.
    """
    b_transposed = B.transpose(1, 2)
    if B.dtype != torch.float16:
        return mul(bmm(d_output, b_transposed), alpha)

    # Half precision: do the arithmetic in float32, then narrow the result.
    product = bmm(d_output.to(torch.float32), b_transposed.to(torch.float32))
    return mul(product, alpha).to(torch.float16)
def compute_B_grad(A, d_output, alpha):
    """Return ``alpha * (A^T @ d_output)``, transposing A's last two dims.

    When ``A`` is float16, both operands are promoted to float32 for the
    product and the scaling, and the result is cast back to float16.
    """
    a_transposed = A.transpose(1, 2)
    if A.dtype != torch.float16:
        return mul(bmm(a_transposed, d_output), alpha)

    # Half precision: do the arithmetic in float32, then narrow the result.
    product = bmm(a_transposed.to(torch.float32), d_output.to(torch.float32))
    return mul(product, alpha).to(torch.float16)
def baddbmm_out(bias, A, B, *, beta=1.0, alpha=1.0, out):
    """Write the baddbmm result into the caller-provided ``out`` and return it.

    ``out`` must have shape ``(batch, M, N)`` and the dtype of ``A``;
    otherwise the assertion below fails.
    """
    logger.debug("GEMS BADDBMM_OUT")
    n_batch, n_rows, _ = A.shape
    _, _, n_cols = B.shape
    expected_shape = (n_batch, n_rows, n_cols)
    assert (
        out.shape == expected_shape and out.dtype == A.dtype
    ), "Incompatible output shape or dtype for baddbmm.out"

    dense_bias = bias.contiguous()
    dense_A = A.contiguous()
    dense_B = B.contiguous()
    _baddbmm_launch(dense_bias, dense_A, dense_B, beta, alpha, out)
    return out
def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Run baddbmm through ``BaddbmmFunction`` on contiguous copies of the inputs."""
    dense_bias = bias.contiguous()
    dense_A = A.contiguous()
    dense_B = B.contiguous()
    return BaddbmmFunction.apply(dense_bias, dense_A, dense_B, beta, alpha)