Coverage for src/flag_gems/runtime/backend/_ascend/ops/baddbmm.py: 0%
147 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.mul import mul
9from flag_gems.runtime import torch_device_fn
10from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
14from .bmm import bmm
# Module logger; the autograd forward/backward below emit debug traces via it.
logger: logging.Logger = logging.getLogger(name=__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["baddbmm"])
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    """Compute one (TILE_M, TILE_N) tile of ``O = alpha * (A @ B) + beta * bias``.

    Program axis 2 selects the batch; axes 0/1 select the M/N tile (optionally
    remapped in groups of GROUP_M rows). Per batch, A, B and O are addressed as
    dense row-major (M, K), (K, N) and (M, N) matrices; ``bias`` is addressed
    through the explicit ``bias_*_stride`` arguments. Accumulation is in fp32.
    """
    # Advance every base pointer to this program's batch.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Remap the linear program id so GROUP_M row tiles form a group
        # (grouped tile ordering, presumably for cache reuse of B tiles).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group may hold fewer than GROUP_M row tiles.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            # No K tail: out-of-range M rows / N columns only affect output
            # elements that the masked store below discards.
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            a = tl.load(a_ptrs, mask=mask_a)
            b = tl.load(b_ptrs, mask=mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Masked-off K lanes are summed by tl.dot, so they must read as
            # zero; without ``other`` their value is undefined.
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)
        accumulator += tl.dot(a, b, allow_tf32=False)
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha
    # torch.baddbmm ignores ``input`` when beta == 0 (NaN/Inf in it must not
    # propagate), so select 0 rather than computing 0 * bias.
    out += tl.where(beta != 0, bi * beta, 0.0)
    # Cast directly to the output element type (not the bias dtype) to avoid
    # a second rounding step when the two dtypes differ.
    tl.store(o_ptrs, out.to(O.dtype.element_ty), mask=mask_c)
class BaddbmmFunction(torch.autograd.Function):
    """Autograd wrapper computing ``beta * bias + alpha * (A @ B)`` batch-wise.

    The forward pass launches ``baddbmm_kernel``; the backward pass delegates to
    ``compute_bias_grad``, ``compute_A_grad`` and ``compute_B_grad``.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Run the fused kernel.

        Args:
            ctx: autograd context; stores ``A``, ``B``, ``bias``, ``alpha``, ``beta``.
            bias: tensor broadcastable to ``(batch, M, N)``.
            A: ``(batch, M, K)`` tensor.
            B: ``(batch, K, N)`` tensor.
            beta: multiplier for ``bias``.
            alpha: multiplier for ``A @ B``.

        Returns:
            A new ``(batch, M, N)`` tensor with ``A``'s dtype and device.
        """
        logger.debug("GEMS_ASCEND BADDBMM FORWARD")

        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        batch, M, K = A.shape
        _, _, N = B.shape
        # The kernel indexes A and B as densely packed matrices.
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        # Materialise the broadcast so bias is a dense (batch, M, N) tensor whose
        # strides are handed to the kernel explicitly. Done before the empty
        # check so a non-broadcastable bias still raises.
        bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()

        # Nothing to write: skip a launch whose grid would have a zero dimension.
        if out.numel() == 0:
            return out

        bias_batch_stride = bbias.stride(0)
        bias_M_stride = bbias.stride(1)
        bias_N_stride = bbias.stride(-1)

        grid = lambda meta: (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )
        with torch_device_fn.device(A.device):
            baddbmm_kernel[grid](
                A,
                B,
                out,
                bbias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=bias_batch_stride,
                bias_M_stride=bias_M_stride,
                bias_N_stride=bias_N_stride,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(bias, A, B, beta, alpha)``.

        Only inputs flagged in ``ctx.needs_input_grad`` get a gradient; the
        scalar ``beta``/``alpha`` never do.
        """
        logger.debug("GEMS_ASCEND BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors

        grad_A = None
        grad_B = None
        grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_bias = compute_bias_grad(grad_output, ctx.beta, bias)
        if ctx.needs_input_grad[1]:
            grad_A = compute_A_grad(grad_output, B, ctx.alpha)
        if ctx.needs_input_grad[2]:
            grad_B = compute_B_grad(A, grad_output, ctx.alpha)

        return grad_bias, grad_A, grad_B, None, None
def compute_bias_grad(d_output, beta, bias):
    """Return ``beta * d_output`` reduced back to ``bias``'s shape.

    The forward pass broadcasts ``bias`` to the output shape, so its gradient is
    the scaled output gradient summed over every broadcast dimension.

    Args:
        d_output: gradient with respect to the forward output.
        beta: multiplier applied to ``bias`` in the forward pass.
        bias: the saved bias tensor; only its shape and rank are used.

    Returns:
        A tensor with exactly ``bias.shape``.
    """
    grad_bias = mul(d_output, beta)
    if grad_bias.shape != bias.shape:
        # Leading dimensions that broadcasting prepended to bias.
        while grad_bias.dim() > bias.dim():
            grad_bias = grad_bias.sum(dim=0)
        # Size-1 dimensions of bias that were expanded. Use ``!= 1`` rather than
        # ``> 1`` so a dimension expanded to size 0 is reduced too (to size 1);
        # otherwise the final reshape to bias.shape fails.
        for i, size in enumerate(bias.shape):
            if size == 1 and grad_bias.shape[i] != 1:
                grad_bias = grad_bias.sum(dim=i, keepdim=True)
    return grad_bias.reshape(bias.shape)
def compute_A_grad(d_output, B, alpha):
    """Return the gradient of ``alpha * (A @ B)`` with respect to ``A``.

    Evaluates ``alpha * (d_output @ B^T)`` batch-wise, where ``B^T`` swaps the
    last two dimensions. float16 operands are multiplied in float32 and the
    result is cast back to float16.
    """
    rhs = B.transpose(1, 2).contiguous()
    lhs = d_output
    upcast = B.dtype == torch.float16
    if upcast:
        rhs = rhs.to(torch.float32)
        lhs = lhs.to(torch.float32)
    grad_A = mul(bmm(lhs, rhs), alpha)
    if upcast:
        grad_A = grad_A.to(torch.float16)
    return grad_A
def compute_B_grad(A, d_output, alpha):
    """Return the gradient of ``alpha * (A @ B)`` with respect to ``B``.

    Evaluates ``alpha * (A^T @ d_output)`` batch-wise, where ``A^T`` swaps the
    last two dimensions. float16 operands are multiplied in float32 and the
    result is cast back to float16.
    """
    lhs = A.transpose(1, 2).contiguous()
    rhs = d_output
    upcast = A.dtype == torch.float16
    if upcast:
        lhs = lhs.to(torch.float32)
        rhs = rhs.to(torch.float32)
    grad_B = mul(bmm(lhs, rhs), alpha)
    if upcast:
        grad_B = grad_B.to(torch.float16)
    return grad_B
def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Compute ``beta * bias + alpha * (A @ B)`` batch-wise (``torch.baddbmm``).

    Args:
        bias: tensor broadcastable to ``(batch, M, N)``.
        A: ``(batch, M, K)`` tensor.
        B: ``(batch, K, N)`` tensor.
        beta: multiplier for ``bias``.
        alpha: multiplier for ``A @ B``.

    Returns:
        The ``(batch, M, N)`` result, differentiable through ``BaddbmmFunction``.

    Raises:
        RuntimeError: if ``A`` or ``B`` is not 3-D, or their batch or inner
            dimensions disagree. The kernel indexes both as dense
            ``(batch, M, K)`` / ``(batch, K, N)`` buffers, so a mismatch would
            otherwise read out of bounds.
    """
    if A.dim() != 3 or B.dim() != 3:
        raise RuntimeError(
            f"baddbmm: expected 3-D batch tensors, got A.dim()={A.dim()} "
            f"and B.dim()={B.dim()}"
        )
    if A.shape[0] != B.shape[0] or A.shape[2] != B.shape[1]:
        raise RuntimeError(
            f"baddbmm: incompatible shapes A={tuple(A.shape)} "
            f"and B={tuple(B.shape)}"
        )
    return BaddbmmFunction.apply(
        bias.contiguous(),
        A.contiguous(),
        B.contiguous(),
        beta,
        alpha,
    )