Coverage for src/flag_gems/runtime/backend/_ascend/fused/flash_mla.py: 0%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
vendor_name = device.vendor_name
# NOTE: `device` is rebound here from the imported runtime `device` object to
# its `.name`; `vendor_name` must therefore be read first (line above).
# Anything below that refers to `device` sees this name, not the runtime object.
device = device.name
logger = logging.getLogger(__name__)
# Autotuning is disabled; the launcher in this file passes fixed BLOCK_H /
# BLOCK_N values instead. The original search space is kept for reference:
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s)
#         for h in [32, 64, 128]
#         for n in [32, 64, 128]
#         for w in [4, 8]
#         for s in [1, 2]
#     ],
#     key=["head_num"]
# )
#
# Paged MLA attention kernel: one program per (tile of BLOCK_H heads, query row).
#   * program_id(0) selects heads [pid0 * BLOCK_H, pid0 * BLOCK_H + BLOCK_H);
#     program_id(1) (`cur_batch_id`) selects the row of Q, O, Req_to_tokens and
#     B_seq_len that this program works on.
#   * Q: channels [0, HEAD_DIM_V) of each head are the non-rope ("nope") part,
#     channels [HEAD_DIM_V, HEAD_DIM) the rope part. The innermost channel
#     stride is assumed to be 1 (no stride argument is taken for it).
#   * Kv_cache: token rows are `stride_kv_bs` elements apart. Channels
#     [0, HEAD_DIM_V) of a row are used both as the key (transposed) and as the
#     value; channels [HEAD_DIM_V, HEAD_DIM) as the rope key.
#   * Req_to_tokens: row `cur_batch_id` maps logical page `n // PAGE_SIZE` to a
#     physical page number; token n lives at cache row
#     `page * PAGE_SIZE + n % PAGE_SIZE`.
#   * B_seq_len[cur_batch_id]: number of KV tokens this row attends to. No
#     other (e.g. causal) masking is applied inside the kernel.
#   * O: written with row/head strides; `stride_o_s` is accepted but unused
#     (the channel stride is assumed to be 1).
# Softmax is computed online: `e_max` is the running row max, `e_sum` the
# running sum of exponentials, and `acc` is rescaled whenever the max grows.
# NOTE(review): tl.arange needs power-of-two lengths, so HEAD_DIM_V and
# HEAD_DIM - HEAD_DIM_V presumably must both be powers of two — confirm.
@triton.heuristics(
    values={
        # True when head_num is a multiple of BLOCK_H, so head masking is skipped.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,
    Kv_cache,
    Req_to_tokens,
    B_seq_len,
    O,
    sm_scale,
    head_num,
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    cur_head_id = tle.program_id(0)
    cur_batch_id = tle.program_id(1)
    # Point at this row's page table.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Offsets of the non-rope channels [0, HEAD_DIM_V) of each head in Q.
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Offsets of the rope channels [HEAD_DIM_V, HEAD_DIM) of each head in Q.
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # The last head tile may overrun head_num. Masked rows load unspecified
        # values (no `other=`), but those rows are never stored (see the end).
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Online-softmax state, accumulated in fp32.
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    # Full BLOCK_N tiles run unmasked in the loop; the tail is handled below.
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical token positions to physical cache rows via the page table.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        # The same [BLOCK_N, HEAD_DIM_V] tile is used as V and, transposed, as K.
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Online-softmax update: rescale previous state to the new running max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail tile: positions >= cur_batch_seq_len are masked on load (their
        # page lookup falls back to page 0 only for address computation).
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Out-of-range tokens get -inf scores, i.e. zero softmax weight.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        # e_max is not updated here because it is not read again.
        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # NOTE(review): if B_seq_len[row] is 0, neither the loop nor the tail runs,
    # e_sum stays 0 and the division below yields NaN — confirm callers never
    # pass zero-length rows.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        # Only heads < head_num are written.
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])
def _expand_rows_for_causal_queries(block_table, cache_seqlens, s_q):
    """Return per-query-row ``(block_table, cache_seqlens)`` for the kernel.

    ``flash_mla_attn_kernel`` treats every row of ``q`` (viewed as
    ``[b * s_q, h, d]``) independently: row ``r`` reads ``block_table[r]`` and
    attends to the first ``cache_seqlens[r]`` cached tokens. With one query
    token per sequence (``s_q == 1``) rows and sequences coincide, so the
    inputs are returned unchanged. Otherwise each sequence's page-table row is
    repeated ``s_q`` times, and query token ``i`` of a sequence holding ``L``
    tokens is limited to the first ``L - (s_q - 1 - i)`` tokens. This is a
    causal mask aligned to the end of the cache: the last query token sees
    all ``L`` tokens.
    """
    if s_q == 1:
        return block_table, cache_seqlens
    block_table = block_table.repeat_interleave(s_q, dim=0)
    # s_q - 1, ..., 1, 0: how many trailing cache tokens each query token skips.
    causal_offsets = torch.arange(
        s_q - 1, -1, -1, device=cache_seqlens.device, dtype=cache_seqlens.dtype
    )
    cache_seqlens = (cache_seqlens[:, None] - causal_offsets[None, :]).reshape(-1)
    return block_table, cache_seqlens


def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Paged multi-head latent attention (MLA) over a blocked KV cache.

    The shapes actually launched are taken from ``q``. The ``b``, ``s_q``,
    ``h_q`` and ``d`` arguments are expected to agree with ``q.shape``.

    Args:
        q: Query of shape ``[b, s_q, h_q, d]``. Channels ``[0, dv)`` are
            matched against the cached latent, channels ``[dv, d)`` against
            the cached rope part.
        block_table: Page table with one row per sequence. Entry ``j`` is the
            physical page holding tokens
            ``[j * block_size, (j + 1) * block_size)``.
        blocked_k: Paged KV cache, reshaped here to ``[-1, d]`` token rows.
            Its first ``dv`` channels serve as both keys and values.
        max_seqlen_pad: Unused; kept for interface compatibility.
        block_size: Number of tokens per cache page.
        b: Batch size.
        s_q: Query tokens per sequence. ``s_q > 1`` is handled with a causal
            mask aligned to the end of the cache.
        cache_seqlens: Number of valid cached tokens per sequence, one entry
            per sequence.
        h_q: Number of query heads.
        h_kv: Unused; kept for interface compatibility.
        d: Head dim including the rope part.
        dv: Head dim of the value / non-rope part; must be smaller than ``d``.
        causal: Must be ``True``; non-causal attention is not implemented.

    Returns:
        Tensor of shape ``[b, s_q, h_q, dv]`` with ``q.dtype`` on ``q.device``.

    Raises:
        AssertionError: If ``causal`` is false or ``d <= dv``.
    """
    logger.debug("GEMS_ASCEND FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    # The tensor's own shape is authoritative for everything launched below.
    batch_size, s_q, head_num, d = q.shape
    num_rows = batch_size * s_q

    # reshape (not view) so non-viewable inputs are copied instead of raising.
    q = q.reshape(-1, head_num, d).contiguous()
    blocked_k = blocked_k.reshape(-1, d).contiguous()
    # The kernel indexes page table and lengths by q row, so give each of the
    # b * s_q rows its own entry (identity when s_q == 1).
    block_table, cache_seqlens = _expand_rows_for_causal_queries(
        block_table, cache_seqlens, s_q
    )
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(d)

    # Allocate on q's own device (not only the backend's default device) so
    # inputs on a non-default device index get their output next to them.
    o = torch.empty([num_rows, head_num, dv], dtype=q.dtype, device=q.device)
    if o.numel() == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return o.view([batch_size, s_q, head_num, dv])

    # Fixed tile sizes (the kernel's autotune decorator is disabled).
    BLOCK_H = 16
    BLOCK_N = 16
    num_stages = 1

    # One program per (BLOCK_H-head tile, query row).
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        num_rows,
    )
    with torch_device_fn.device(q.device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # stride
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=8,
            num_stages=num_stages,
        )

    return o.view([batch_size, s_q, head_num, dv])