Coverage for src/flag_gems/fused/DSA/sparse_mla.py: 24%
173 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils.triton_version_utils import HAS_TLE
9if HAS_TLE:
10 import triton.experimental.tle.language as tle
11else:
12 tle = None
logger = logging.getLogger(__name__)

# Autotuning search space shared by both sparse-MLA forward kernels.
# The pipeline depth must be given through Config's own ``num_stages``
# argument: a "num_stages" entry placed in the meta-parameter dict is
# overwritten by ``Config.num_stages`` (which has its own default) when the
# autotuner merges the launch kwargs, so it would never take effect.
spar_mla_fwd_configs = [
    triton.Config({}, num_warps=8, num_stages=4),
    triton.Config({}, num_warps=4, num_stages=2),
]
@triton.autotune(  # Decorate the kernel
    configs=spar_mla_fwd_configs,
    key=["K", "is_causal"],
)
@triton.jit
def triton_sparse_mla_fwd(
    q,
    kv,
    indices,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kvb,
    stride_kvg,
    stride_kvn,
    stride_kvd,
    stride_tb,
    stride_tg,
    stride_tm,
    stride_tt,  # indices dim
    stride_ob,
    stride_oh,
    stride_om,
    stride_od,
    stride_lb,
    stride_lh,
    stride_lm,
    SQ: tl.constexpr,  # seqlen
    K: tl.constexpr,  # topk
    D: tl.constexpr,  # QKV dim
    TD: tl.constexpr,  # tail dim
    DP: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,  # group_size
    BK: tl.constexpr,
    BH: tl.constexpr,
    is_causal: tl.constexpr,
):
    """Sparse MLA forward pass (portable fallback used when TLE is absent).

    Program ids: axis 0 is the batch, axis 1 the query position, and axis 2
    enumerates ``cdiv(G, BH)`` blocks of ``BH`` query heads for every KV
    group. Each program attends its heads over the ``K`` KV rows listed in
    ``indices`` (processed ``BK`` at a time) with an online softmax.

    The first ``D`` features of a KV row serve both as the main key and as
    the value; the following ``TD`` features are a tail key dotted against
    the query's trailing ``TD`` features. Index entries that are negative or
    greater than ``max_col`` are masked out.

    q / kv tiles are converted to float16 for the dot products, while the
    softmax state and accumulators are kept in float32. ``output`` receives
    the normalized result and ``lse`` the log-sum-exp of the scaled scores
    expressed in base 2 (natural LSE / ln 2).
    """
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NH = tl.cdiv(G, BH)  # head blocks per KV group
    i_g, i_bh = i_gbh // NH, i_gbh % NH
    # First query head of this block. Group i_g owns heads [i_g * G, (i_g + 1) * G);
    # ``i_gbh * BH`` would only coincide with this when G is a multiple of BH.
    h0 = i_g * G + i_bh * BH
    q_base = q + i_b * stride_qb + i_sq * stride_qm + h0 * stride_qh
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
    o_base = output + i_b * stride_ob + i_sq * stride_om + h0 * stride_oh
    l_base = lse + i_b * stride_lb + i_sq * stride_lm + h0 * stride_lh

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD

    # Query tiles: main part [BH, DP] and tail part [BH, TDP].
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0).to(tl.float16)

    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, tq_msk, other=0.0).to(tl.float16)

    # Online-softmax state in float32: running row max of the base-2 scores,
    # running sum of exponentials, and the unnormalized output. sum_exp starts
    # at 1.0 so a row that never sees a valid index ends with a zero output
    # and lse == -inf rather than 0 / 0.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float32)
    acc = tl.zeros([BH, DP], dtype=tl.float32)

    # Fold log2(e) into the softmax scale so exp2 can replace exp.
    log_scale: tl.constexpr = sm_scale * 1.44269504

    # Largest KV position this query may use.
    # NOTE(review): both branches assume the KV length equals SQ.
    max_col = i_sq if is_causal else SQ - 1

    NK = tl.cdiv(K, BK)
    for ck in range(NK):
        # Bound-check the topk slot itself, not its stride-scaled offset.
        offs_k = BK * ck + offs_t
        kv_ids = tl.load(t_base + offs_k * stride_tt, offs_k < K, other=-1)
        mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

        # Skip the KV gather when no index of this block is usable.
        if tl.max(mask_ids, axis=0) > 0:
            kv_ptr = (
                kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tl.load(kv_ptr, kv_msk, other=0.0).to(tl.float16)  # [DP, BK]

            tkv_ptr = (
                tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tl.load(tkv_ptr, tkv_msk, other=0.0).to(tl.float16)  # [TDP, BK]

            qk = tl.dot(q_blk, kv_blk, out_dtype=tl.float32)
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float32) * log_scale
            qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]

            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            exp_qk = tl.math.exp2(qk - new_max[:, None])
            alpha = tl.math.exp2(max_log - new_max)  # rescale previous state
            sum_exp = sum_exp * alpha + tl.sum(exp_qk, axis=1)
            acc = acc * alpha[:, None]
            # [BH, BK] @ [BK, DP] -> [BH, DP]; the value is the first D features.
            acc = tl.dot(
                exp_qk.to(tl.float16), tl.trans(kv_blk), acc, out_dtype=tl.float32
            )
            max_log = new_max

    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_d[None, :]
    # Cast straight from float32 to the output element type (no fp16 detour).
    tl.store(o_ptr, out_vals.to(output.dtype.element_ty), o_msk)

    fin_log = max_log + tl.math.log2(sum_exp)  # lse / ln2
    l_ptr = l_base + offs_h * stride_lh
    tl.store(l_ptr, fin_log.to(lse.dtype.element_ty), mask_h)
if HAS_TLE:

    @triton.autotune(
        configs=spar_mla_fwd_configs,
        key=["K", "is_causal"],
    )
    @triton.jit
    def triton_sparse_mla_fwd_tle(
        q,
        kv,
        indices,
        sm_scale: tl.constexpr,
        output,
        lse,
        stride_qb,
        stride_qh,
        stride_qm,
        stride_qd,
        stride_kvb,
        stride_kvg,
        stride_kvn,
        stride_kvd,
        stride_tb,
        stride_tg,
        stride_tm,
        stride_tt,
        stride_ob,
        stride_oh,
        stride_om,
        stride_od,
        stride_lb,
        stride_lh,
        stride_lm,
        SQ: tl.constexpr,
        K: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DP: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        is_causal: tl.constexpr,
    ):
        """Sparse MLA forward pass using TLE loads for the KV gathers.

        Same launch grid, arguments and outputs as the fallback kernel:
        axis 0 is the batch, axis 1 the query position, axis 2 enumerates
        ``cdiv(G, BH)`` head blocks per KV group. q and kv are used in their
        native dtype; the softmax state and accumulators are float32. The
        main KV tile is fetched with ``tle.load(..., is_async=True)`` and the
        tail tile with ``is_async=False``.

        ``lse`` is written in base 2 (natural log-sum-exp / ln 2).
        """
        i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        NH = tl.cdiv(G, BH)  # head blocks per KV group (matches grid axis 2)
        i_g, i_bh = i_gbh // NH, i_gbh % NH
        # First query head of this block; group i_g owns heads
        # [i_g * G, (i_g + 1) * G).
        h0 = i_g * G + i_bh * BH
        q_base = q + i_b * stride_qb + i_sq * stride_qm + h0 * stride_qh
        tq_base = q_base + D * stride_qd
        kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
        tkv_base = kv_base + D * stride_kvd
        t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
        o_base = output + i_b * stride_ob + i_sq * stride_om + h0 * stride_oh
        l_base = lse + i_b * stride_lb + i_sq * stride_lm + h0 * stride_lh

        offs_h = tl.arange(0, BH)
        offs_d = tl.arange(0, DP)
        offs_td = tl.arange(0, TDP)
        offs_t = tl.arange(0, BK)
        mask_h = i_bh * BH + offs_h < G
        mask_d = offs_d < D
        mask_td = offs_td < TD

        q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
        q_msk = mask_h[:, None] & mask_d[None, :]
        q_blk = tl.load(q_ptr, q_msk, other=0.0)

        tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
        tq_msk = mask_h[:, None] & mask_td[None, :]
        tq_blk = tl.load(tq_ptr, tq_msk, other=0.0)

        # Online-softmax state (unscaled running max; scale applied in exp2).
        # sum_exp starts at 1.0 so rows without any valid index yield a zero
        # output and lse == -inf.
        max_prev = tl.full([BH], float("-inf"), dtype=tl.float32)
        sum_exp = tl.full([BH], 1.0, dtype=tl.float32)
        acc = tl.zeros([BH, DP], dtype=tl.float32)

        log_scale: tl.constexpr = sm_scale * 1.44269504

        # NOTE(review): both branches assume the KV length equals SQ.
        max_col = i_sq if is_causal else SQ - 1

        NK = tl.cdiv(K, BK)
        for ck in tl.range(NK, num_stages=0):
            # Decide from the index values themselves whether this block has
            # work: the topk slot number says nothing about the KV positions
            # it holds, and entering an all-invalid block before any valid
            # one would produce exp2(-inf - -inf) = NaN in alpha.
            offs_k = BK * ck + offs_t
            kv_ids = tl.load(t_base + offs_k * stride_tt, offs_k < K, other=-1)
            mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

            if tl.max(mask_ids, axis=0) > 0:
                kv_ptr = (
                    kv_base
                    + offs_d[:, None] * stride_kvd
                    + kv_ids[None, :] * stride_kvn
                )
                kv_msk = mask_d[:, None] & mask_ids[None, :]
                kv_blk = tle.load(kv_ptr, kv_msk, other=0.0, is_async=True)

                tkv_ptr = (
                    tkv_base
                    + offs_td[:, None] * stride_kvd
                    + kv_ids[None, :] * stride_kvn
                )
                tkv_msk = mask_td[:, None] & mask_ids[None, :]
                tkv_blk = tle.load(tkv_ptr, tkv_msk, other=0.0, is_async=False)

                qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32)
                qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32)
                qk = tl.where(mask_ids[None, :], qk, float("-inf"))

                new_max = tl.maximum(max_prev, tl.max(qk, axis=1))
                alpha = tl.math.exp2((max_prev - new_max) * log_scale)
                exp_qk = tl.math.exp2(qk * log_scale - new_max[:, None] * log_scale)
                sum_exp = sum_exp * alpha + tl.sum(exp_qk, axis=1)
                acc = acc * alpha[:, None]
                # Match the value tile's dtype instead of assuming bfloat16.
                acc = tl.dot(
                    exp_qk.to(kv_blk.dtype),
                    tl.trans(kv_blk),
                    acc,
                    out_dtype=tl.float32,
                )
                max_prev = new_max

        out_vals = acc / sum_exp[:, None]
        o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
        o_msk = mask_h[:, None] & mask_d[None, :]
        tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)

        fin_log = max_prev * log_scale + tl.math.log2(sum_exp)
        l_ptr = l_base + offs_h * stride_lh
        tl.store(l_ptr, fin_log.to(q_blk.dtype), mask_h)
def triton_sparse_mla_fwd_interface(
    q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512
):
    """Run the sparse MLA forward kernel and return ``(output, lse)``.

    Args:
        q: contiguous query tensor of shape (B, SQ, H, DT).
        kv: contiguous shared key/value tensor of shape (B, SKV, VG, DT).
        indices: contiguous tensor of shape (B, SQ, VG, K) holding, for each
            query, the KV positions it attends to; negative entries and
            positions beyond the causal limit are ignored by the kernels.
        sm_scale: softmax scale; defaults to ``DT ** -0.5``.
        return_p_sum: must be False (this path is forward only).
        d_v: number of leading features of each KV row used as the value
            (and main key); the remaining ``DT - d_v`` form the tail key.

    Returns:
        output of shape (B, SQ, H, d_v) and lse of shape (B, SQ, H), both in
        ``q.dtype``. lse is in base 2 (natural log-sum-exp / ln 2).
    """
    logger.debug("GEMS SPARSE_MLA_FWD_INTERFACE")
    is_causal = True
    assert return_p_sum is False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    B, SQ, H, DT = q.shape
    _, _, VG, _ = kv.shape

    D = d_v
    assert kv.shape[-1] == DT
    # Both the main (D) and tail (DT - D) parts must be non-empty for the
    # kernels' tile shapes to be valid.
    assert 0 < D < DT, "d_v must be positive and smaller than the q/kv head dim"
    TD = DT - D
    DP = triton.next_power_of_2(D)
    TDP = triton.next_power_of_2(TD)
    assert kv.shape[0] == B
    _, _, _, K = indices.shape
    assert indices.shape == (B, SQ, VG, K)
    # Otherwise G = H // VG leaves the trailing H % VG heads outside the
    # launch grid and they would never be written.
    assert H % VG == 0, "number of query heads must be a multiple of kv groups"
    G = H // VG
    if sm_scale is None:
        sm_scale = DT**-0.5
    BH = max(16, min(64, triton.next_power_of_2(G)))  # query heads per program
    NH = triton.cdiv(G, BH)  # head blocks per kv group
    BK = 32  # indices consumed per inner-loop step
    output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype)
    lse = torch.full((B, SQ, H), float("-inf"), device=q.device, dtype=q.dtype)
    grid = (B, SQ, VG * NH)
    # Strides are handed over in the kernels' (batch, head/group, seq, dim)
    # order, i.e. storage dims 1 and 2 are swapped for every tensor.
    kernel_args = (
        q,
        kv,
        indices,
        sm_scale,
        output,
        lse,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),  # [B, H, SQ, DT]
        kv.stride(0),
        kv.stride(2),
        kv.stride(1),
        kv.stride(3),  # [B, VG, SKV, DT]
        indices.stride(0),
        indices.stride(2),
        indices.stride(1),
        indices.stride(3),  # [B, VG, SQ, K]
        output.stride(0),
        output.stride(2),
        output.stride(1),
        output.stride(3),  # [B, H, SQ, D]
        lse.stride(0),
        lse.stride(2),
        lse.stride(1),  # [B, H, SQ]
        SQ,
        K,
        D,
        TD,
        DP,
        TDP,
        G,
        BK,
        BH,
        is_causal,
    )
    if HAS_TLE:
        triton_sparse_mla_fwd_tle[grid](*kernel_args)
    else:
        triton_sparse_mla_fwd[grid](*kernel_args)
    return output, lse