Coverage for src/flag_gems/runtime/backend/_arm/ops/attention.py: 0%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""
2attention.py — ARM CPU Flash Attention (Triton-CPU)
4Flash Attention v2 with online softmax — no O(M*N) intermediate matrix.
5Supports GQA (grouped-query attention), is_causal, BF16 inputs.
7Performance (M=512, D=128, H=16, OMP=6, CIX P1 CD8180):
8 ATen: ~179ms -> Triton: ~40ms (4.5x speedup)
9 BLOCK_M=32, BLOCK_N=16 chosen via sweep.
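Short-query path (M < BLOCK_M=32) falls back to ATen: tl.dot requires M>=4.
Exception: M=1, B=1 decode uses the triton-cpu C runtime kernel
(flash_attn_decode_bf16) when that module can be imported.
Non-BF16 inputs, attn_mask, or dropout_p != 0 also fall back to ATen.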
13"""
14import ctypes
15import logging
16import os
18import torch
19import torch.nn.functional as F
20import triton
21import triton.language as tl
23log = logging.getLogger(__name__)
25# Preload libsleef (tl.math.exp2 in Triton-CPU .so depends on SLEEF symbols).
def _ensure_sleef(sleef_dir=None):
    """Best-effort preload of the libsleef shipped next to Triton's C extension.

    Looks for ``libsleef.so.3`` in ``sleef_dir``. When the file exists, the
    directory is prepended to ``LD_LIBRARY_PATH`` (unless it is already one of
    its entries) and the library is loaded with ``ctypes.CDLL`` so that a later
    dlopen of a compiled kernel can resolve the SLEEF symbols.

    Args:
        sleef_dir: Directory to search. Defaults to ``<triton package>/_C``.

    Returns:
        None. Failures never propagate; they are reported at DEBUG level so a
        missing or unloadable library cannot break module import.
    """
    try:
        if sleef_dir is None:
            import triton as _t

            sleef_dir = os.path.join(os.path.dirname(_t.__file__), "_C")

        sleef_so = os.path.join(sleef_dir, "libsleef.so.3")
        if not os.path.exists(sleef_so):
            return

        ld = os.environ.get("LD_LIBRARY_PATH", "")
        # Compare whole path entries: a substring test would treat
        # "/x/_C_other" as if it already contained "/x/_C".
        if sleef_dir not in ld.split(":"):
            # Never emit a trailing ":" — an empty LD_LIBRARY_PATH entry makes
            # the dynamic loader search the current working directory.
            os.environ["LD_LIBRARY_PATH"] = f"{sleef_dir}:{ld}" if ld else sleef_dir

        ctypes.CDLL(sleef_so)  # preload so later dlopen can resolve symbols
    except Exception:
        # Deliberately best-effort, but leave a trace instead of failing silently.
        # Same logger object as the module-level `log`.
        logging.getLogger(__name__).debug(
            "libsleef preload skipped", exc_info=True
        )
# Run the preload at import time, before any kernel in this module is compiled or launched.
_ensure_sleef()

# Capture the original ATen SDPA now so the internal fallback keeps calling it
# (avoids infinite recursion once F.scaled_dot_product_attention is monkey-patched).
_aten_sdpa = F.scaled_dot_product_attention

# Import the C-runtime decode kernel once at module load. If this triton-cpu
# build lacks the runtime module (older build), the name is set to None and
# M=1 decode falls through to ATen.
try:
    from triton.language.extra.cpu.runtime import (
        flash_attn_decode_bf16 as _flash_attn_decode_bf16,
    )
except ImportError:
    _flash_attn_decode_bf16 = None

# log2(e) = 1/ln(2). Folding it into the QK scale lets the kernel use exp2 in
# place of exp (avoids SLEEF precision loss).
_LOG2E: float = 1.44269504089

# Tile sizes used for the kernel launch: BLOCK_M query rows by BLOCK_N key
# columns per step (BLOCK_N=16 chosen via sweep).
_BLOCK_M: int = 32
_BLOCK_N: int = 16
65# ── Flash Attention Triton Kernel ───────────────────────────────────────────
# Flash-attention forward kernel with online softmax.
#
# Launch grid: axis 0 = flattened (batch, Q-head) index, axis 1 = index of the
# BLOCK_M-row query tile. Each program loads one query tile, walks the keys and
# values in BLOCK_N-wide steps while maintaining a running row maximum, a
# running sum of exponentials and an un-normalized output accumulator, then
# divides once at the end and stores the tile as bfloat16.
#
# All softmax arithmetic is done in the log2 domain: the query tile is
# multiplied by sm_scale * LOG2E up front so tl.math.exp2 can be used directly.
#
# NOTE(review): when seqlen_k == 0 the key loop never executes, `lse` stays 0
# and the final division is 0/0. Callers are expected to pass seqlen_k > 0.
@triton.jit
def _flash_attn_fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    # Strides of Q, viewed as [B*Hq, M, D]
    stride_qh,
    stride_qm,
    stride_qk,
    # Strides of K, viewed as [B*Hkv, N, D]
    stride_kh,
    stride_kn,
    stride_kk,
    # Strides of V, viewed as [B*Hkv, N, D]
    stride_vh,
    stride_vn,
    stride_vk,
    # Strides of Out, viewed as [B*Hq, M, D]
    stride_oh,
    stride_om,
    stride_ok,
    seqlen_q,
    seqlen_k,
    q_numhead,
    kv_numhead,  # GQA support
    LOG2E: tl.constexpr,  # 1.44269504
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    IS_CAUSAL: tl.constexpr,  # compile-time constant: generates two code paths
):
    pid_bh = tl.program_id(0)  # batch * Q-head (flattened)
    pid_m = tl.program_id(1)  # M-tile index

    # GQA mapping: every (Hq // Hkv) Q-heads share one KV-head.
    head_id = pid_bh % q_numhead
    batch_id = pid_bh // q_numhead
    kv_head_id = head_id * kv_numhead // q_numhead

    # Base pointers of this program's (batch, head) slice in each tensor.
    Q_bh = Q + (batch_id * q_numhead + head_id) * stride_qh
    K_bh = K + (batch_id * kv_numhead + kv_head_id) * stride_kh
    V_bh = V + (batch_id * kv_numhead + kv_head_id) * stride_vh
    O_bh = Out + (batch_id * q_numhead + head_id) * stride_oh

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, HEAD_DIM)
    # Rows past seqlen_q (last, partial tile) are loaded as zeros and never stored.
    mask_m = offs_m < seqlen_q

    # Q: [BLOCK_M, HEAD_DIM], pre-multiplied by sm_scale*LOG2E (shifts into log2 domain).
    q = tl.load(
        Q_bh + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk,
        mask=mask_m[:, None],
        other=0.0,
    ).to(tl.float32) * (sm_scale * LOG2E)

    # Online softmax state (per-row, log2 domain).
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)  # running row maximum
    # Despite the name, `lse` holds the running SUM of exponentials (the softmax
    # denominator), not its logarithm.
    lse = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)  # un-normalized P @ V

    # Causal: only iterate up to the current Q-tile. Keys at or beyond
    # (pid_m + 1) * BLOCK_M are masked for every row of this tile, so skip them.
    if IS_CAUSAL:
        kv_end = tl.minimum(seqlen_k, (pid_m + 1) * BLOCK_M)
    else:
        kv_end = seqlen_k

    for start_n in range(0, kv_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seqlen_k  # columns past the end of K/V in the last step

        # K^T: [HEAD_DIM, BLOCK_N] — swapping k/n offsets gives the transposed load.
        k = tl.load(
            K_bh + offs_k[:, None] * stride_kk + offs_n[None, :] * stride_kn,
            mask=mask_n[None, :],
            other=0.0,
        ).to(tl.float32)

        # QK^T: [BLOCK_M, HEAD_DIM] x [HEAD_DIM, BLOCK_N] -> [BLOCK_M, BLOCK_N].
        # q is already in log2 domain (includes sm_scale*LOG2E), so exp2 can be applied directly.
        # Both operands are cast to bfloat16 for the dot; the result is widened back to float32.
        qk = tl.dot(q.to(tl.bfloat16), k.to(tl.bfloat16)).to(tl.float32)

        # Masked positions become -inf so they contribute exp2(-inf) = 0.
        if IS_CAUSAL:
            # Query row i may attend key column j only when i >= j (no offset
            # is applied for seqlen_k != seqlen_q).
            causal_ok = offs_m[:, None] >= offs_n[None, :]
            qk = tl.where(causal_ok & mask_n[None, :], qk, float("-inf"))
        else:
            qk = tl.where(mask_n[None, :], qk, float("-inf"))

        # Online softmax (log2 domain).
        m_new = tl.maximum(m_i, tl.max(qk, axis=1))  # [BLOCK_M]
        alpha = tl.math.exp2(m_i - m_new)  # rescale previous rows
        p = tl.math.exp2(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        # Bring the previous partial sums onto the new maximum before adding this step.
        lse = lse * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        # V: [BLOCK_N, HEAD_DIM]
        v = tl.load(
            V_bh + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk,
            mask=mask_n[:, None],
            other=0.0,
        ).to(tl.bfloat16)

        # P @ V: [BLOCK_M, BLOCK_N] x [BLOCK_N, HEAD_DIM] -> [BLOCK_M, HEAD_DIM]
        acc = tl.dot(p.to(tl.bfloat16), v, acc=acc)
        m_i = m_new

    # Normalize and write back.
    acc = acc / lse[:, None]
    tl.store(
        O_bh + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok,
        acc.to(tl.bfloat16),
        mask=mask_m[:, None],
    )
185# Python wrappers.
def _triton_flash_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    sm_scale: float,
    is_causal: bool,
) -> torch.Tensor:
    """Launch the flash-attention kernel and return the result shaped like ``query``.

    Performs no eligibility checks; the caller decides whether the Triton path
    applies. ``query`` is unpacked as ``[B, Hq, M, D]`` and the KV head count is
    read from ``key.shape[1]``.
    """
    batch, num_q_heads, seqlen_q, head_dim = query.shape
    num_kv_heads = key.shape[1]

    # Collapse (batch, head) into one leading axis: [B*H, seq, D].
    q_flat = query.reshape(batch * num_q_heads, seqlen_q, head_dim)
    k_flat, v_flat = (
        t.reshape(batch * num_kv_heads, -1, head_dim) for t in (key, value)
    )
    seqlen_k = k_flat.shape[1]
    out_flat = torch.empty_like(q_flat)

    # One program per (batch*head, query tile).
    launch_grid = (batch * num_q_heads, triton.cdiv(seqlen_q, _BLOCK_M))

    _flash_attn_fwd_kernel[launch_grid](
        q_flat,
        k_flat,
        v_flat,
        sm_scale,
        out_flat,
        # Each tensor is 3-D, so every stride() expands to (head, seq, dim).
        *q_flat.stride(),
        *k_flat.stride(),
        *v_flat.stride(),
        *out_flat.stride(),
        seqlen_q,
        seqlen_k,
        num_q_heads,
        num_kv_heads,
        _LOG2E,
        BLOCK_M=_BLOCK_M,
        BLOCK_N=_BLOCK_N,
        HEAD_DIM=head_dim,
        IS_CAUSAL=is_causal,
    )
    return out_flat.reshape(batch, num_q_heads, seqlen_q, head_dim)
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """
    aten::scaled_dot_product_attention — ARM CPU Flash Attention.

    Dispatches to one of three implementations:

    1. M=1 decode fast path (C runtime ``flash_attn_decode_bf16``), used when
       the runtime function was importable and:
         - B == 1 and M == 1
         - query, key and value are all bfloat16 and contiguous
         - is_causal is False, or there is only one key (a causal mask with a
           single query row keeps only key 0, which this path does not model)
    2. Prefill fast path (Triton kernel), used when:
         - seqlen_q >= BLOCK_M (=32)
         - head_dim in {16,32,64,128,256}
    3. ATen fallback for everything else.

    Both fast paths additionally require:
      - 4-D query/key/value with matching batch and head_dim, and
        value.shape == key.shape
      - query dtype = bfloat16
      - attn_mask = None
      - dropout_p = 0.0
      - at least one key position

    ``enable_gqa`` is only forwarded to ATen; the fast paths take the KV head
    count from ``key.shape[1]`` regardless of this flag.

    Returns:
        Attention output with the same shape as ``query`` (bfloat16 on the
        fast paths; whatever ATen returns on the fallback).
    """

    def _fallback():
        log.debug(
            "GEMS SDPA: ATen fallback (M=%d, dtype=%s, mask=%s)",
            query.shape[-2] if query.dim() >= 2 else -1,
            query.dtype,
            attn_mask is not None,
        )
        return _aten_sdpa(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )

    # ATen accepts other ranks; only the 4-D layout can be unpacked below.
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        return _fallback()

    B, Hq, M, D = query.shape
    Hkv = key.shape[1]
    seq_len = key.shape[2]

    # Preconditions shared by both fast paths.
    fast_path_ok = (
        query.dtype == torch.bfloat16
        and attn_mask is None
        and dropout_p == 0.0
        and seq_len > 0  # empty K/V would leave the softmax denominator at 0
        and key.shape[0] == B
        and key.shape[3] == D
        and value.shape == key.shape
    )
    if not fast_path_ok:
        return _fallback()

    sm_scale = scale if scale is not None else D**-0.5

    # M=1 decode fast path: C runtime flash_attn_decode_bf16 via triton-cpu.
    # Measured +1.2% E2E on Qwen3-1.7B INT8 vs ATen fallback (3 rounds A/B).
    # Requires BF16 Q/K/V (the C kernel is handed the raw buffers), contiguous
    # inputs, and no effective causal masking.
    if (
        _flash_attn_decode_bf16 is not None
        and M == 1
        and B == 1
        and (not is_causal or seq_len == 1)
        and key.dtype == torch.bfloat16
        and value.dtype == torch.bfloat16
        and query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
    ):
        q_flat = query.squeeze(0).squeeze(1).contiguous()  # [Hq, D]
        k_flat = key.squeeze(0).contiguous()  # [Hkv, N, D]
        v_flat = value.squeeze(0).contiguous()  # [Hkv, N, D]
        out_flat = torch.empty(Hq, D, dtype=torch.bfloat16, device=query.device)
        _flash_attn_decode_bf16(
            q_flat,
            k_flat,
            v_flat,
            out_flat,
            seq_len,
            D,
            sm_scale,
            Hq,
            Hkv,
            k_flat.stride(1),
            v_flat.stride(1),
        )
        # Restore the [B=1, Hq, M=1, D] layout.
        return out_flat.unsqueeze(0).unsqueeze(2)

    # Prefill fast path: Triton Flash Attention kernel (requires M >= BLOCK_M).
    if M < _BLOCK_M or D not in {16, 32, 64, 128, 256}:
        return _fallback()

    log.debug(
        "GEMS SDPA: Triton Flash Attention (M=%d, N=%d, D=%d, causal=%s, Hq=%d, Hkv=%d)",
        M,
        seq_len,
        D,
        is_causal,
        Hq,
        Hkv,
    )
    return _triton_flash_attn(query, key, value, sm_scale, is_causal)