Coverage for src/flag_gems/runtime/backend/_arm/ops/attention.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

"""
attention.py — ARM CPU Flash Attention (Triton-CPU)

Flash Attention v2 with online softmax — no O(M*N) intermediate matrix.
Supports GQA (grouped-query attention), is_causal, BF16 inputs.

Performance (M=512, D=128, H=16, OMP=6, CIX P1 CD8180):
  ATen: ~179ms -> Triton: ~40ms (4.5x speedup)
  BLOCK_M=32, BLOCK_N=16 chosen via sweep.

Dispatch in scaled_dot_product_attention():
  - M == 1 decode: when the triton-cpu runtime module exposing
    ``flash_attn_decode_bf16`` is importable and its preconditions hold
    (e.g. B == 1, BF16, contiguous Q/K/V), the C runtime kernel is used.
  - Prefill: M >= BLOCK_M (=32) with head_dim in {16, 32, 64, 128, 256}
    runs the Triton kernel in this file.
  - Everything else falls back to ATen: other query lengths below BLOCK_M
    (tl.dot requires M >= 4), non-BF16 inputs, attn_mask, or dropout_p != 0.
"""

14import ctypes 

15import logging 

16import os 

17 

18import torch 

19import torch.nn.functional as F 

20import triton 

21import triton.language as tl 

22 

log = logging.getLogger(__name__)


def _prepend_ld_path(entry: str, current: str) -> str:
    """Return the search path ``current`` with ``entry`` prepended.

    ``current`` is an ``os.pathsep``-separated list (e.g. LD_LIBRARY_PATH).
    If ``entry`` is already one of its components, ``current`` is returned
    unchanged. Membership is tested per component, not as a substring, so
    ``/a`` is not considered present in ``/ab``. When ``current`` is empty,
    the result is just ``entry``. This avoids introducing an empty trailing
    component, which the dynamic loader interprets as the current directory.
    """
    if entry in current.split(os.pathsep):
        return current
    return f"{entry}{os.pathsep}{current}" if current else entry


# Preload libsleef (tl.math.exp2 in Triton-CPU .so depends on SLEEF symbols).
def _ensure_sleef() -> None:
    """Best-effort preload of the ``libsleef.so.3`` bundled in ``triton/_C``.

    Does nothing if triton cannot be imported or the library file is absent.
    Otherwise it prepends the library directory to LD_LIBRARY_PATH and loads
    the library with ``ctypes.CDLL``, so that the Triton-CPU shared objects
    compiled later can resolve SLEEF symbols. The environment update
    presumably matters for child processes, since the running process's
    loader typically reads LD_LIBRARY_PATH only at startup.

    A failed load is logged at debug level and otherwise ignored; this module
    must stay importable without SLEEF.
    """
    try:
        import triton as _t
    except ImportError:
        return

    pkg_file = getattr(_t, "__file__", None)
    if not pkg_file:  # e.g. a namespace package with no __init__ file
        return

    sleef_dir = os.path.join(os.path.dirname(pkg_file), "_C")
    sleef_so = os.path.join(sleef_dir, "libsleef.so.3")
    if not os.path.exists(sleef_so):
        return

    os.environ["LD_LIBRARY_PATH"] = _prepend_ld_path(
        sleef_dir, os.environ.get("LD_LIBRARY_PATH", "")
    )
    try:
        ctypes.CDLL(sleef_so)  # preload so later dlopen can resolve symbols
    except OSError as exc:
        log.debug("GEMS SDPA: failed to preload %s: %s", sleef_so, exc)


_ensure_sleef()

45 

# Snapshot ATen's SDPA now, before any monkey-patching of F, so the fallback
# path below can never re-enter this module's replacement (infinite recursion).
_aten_sdpa = F.scaled_dot_product_attention

# Optional C-runtime decode kernel from triton-cpu. Older builds lack the
# runtime module; the name then stays None and M=1 decode goes through ATen.
_flash_attn_decode_bf16 = None
try:
    from triton.language.extra.cpu.runtime import (
        flash_attn_decode_bf16 as _flash_attn_decode_bf16,
    )
except ImportError:
    pass

# 1 / ln(2): the kernel folds this into its Q scaling so it can use exp2
# instead of exp (avoids SLEEF precision loss).
_LOG2E: float = 1.44269504089

# Kernel tile sizes: BLOCK_M query rows x BLOCK_N key rows per step
# (BLOCK_N=16 chosen via sweep).
_BLOCK_M: int = 32
_BLOCK_N: int = 16

64 

# ── Flash Attention Triton Kernel ───────────────────────────────────────────
#
# Forward-only flash-attention kernel. Each program instance handles one
# flattened (batch, Q-head) pair (grid axis 0) and one BLOCK_M-row tile of
# queries (grid axis 1). It streams K/V in BLOCK_N-row chunks and keeps an
# online softmax in the log2 domain, so no [M, N] score matrix is materialized.
#
# NOTE(review): q is scaled by sm_scale*LOG2E in fp32 and then rounded back to
# bf16 before tl.dot. That adds one bf16 rounding of the scaled values compared
# with scaling the fp32 qk product after the dot. Confirm the accuracy is
# acceptable against ATen.
# NOTE(review): if seqlen_k == 0 the KV loop never runs, lse stays 0 and the
# final division produces NaN; callers should keep such inputs off this path.


@triton.jit
def _flash_attn_fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    # [B*Hq, M, D]
    stride_qh,
    stride_qm,
    stride_qk,
    # [B*Hkv, N, D]
    stride_kh,
    stride_kn,
    stride_kk,
    # [B*Hkv, N, D]
    stride_vh,
    stride_vn,
    stride_vk,
    # [B*Hq, M, D]
    stride_oh,
    stride_om,
    stride_ok,
    seqlen_q,
    seqlen_k,
    q_numhead,
    kv_numhead,  # GQA support
    LOG2E: tl.constexpr,  # 1.44269504
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    IS_CAUSAL: tl.constexpr,  # compile-time constant: generates two code paths
):
    pid_bh = tl.program_id(0)  # batch * Q-head (flattened)
    pid_m = tl.program_id(1)  # M-tile index

    # GQA mapping: every (Hq // Hkv) Q-heads share one KV-head.
    # head_id * kv_numhead // q_numhead equals head_id // (q_numhead // kv_numhead)
    # only when q_numhead is a multiple of kv_numhead — assumed, not checked here.
    head_id = pid_bh % q_numhead
    batch_id = pid_bh // q_numhead
    kv_head_id = head_id * kv_numhead // q_numhead

    # Base pointers of this (batch, head) slice.
    # (batch_id * q_numhead + head_id) is simply pid_bh for Q/Out.
    Q_bh = Q + (batch_id * q_numhead + head_id) * stride_qh
    K_bh = K + (batch_id * kv_numhead + kv_head_id) * stride_kh
    V_bh = V + (batch_id * kv_numhead + kv_head_id) * stride_vh
    O_bh = Out + (batch_id * q_numhead + head_id) * stride_oh

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, HEAD_DIM)
    # Rows past seqlen_q (tail of the last tile) are masked on load and store.
    mask_m = offs_m < seqlen_q

    # Q: [BLOCK_M, HEAD_DIM], pre-multiplied by sm_scale*LOG2E (shifts into log2 domain).
    # Masked-out rows load as 0, so their scores stay finite (no NaN); they are never stored.
    q = tl.load(
        Q_bh + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk,
        mask=mask_m[:, None],
        other=0.0,
    ).to(tl.float32) * (sm_scale * LOG2E)

    # Online softmax state (per-row, log2 domain).
    # m_i: running row max; lse: running softmax denominator (a sum of exp2
    # terms, despite the name, not a log-sum-exp); acc: unnormalized output.
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    lse = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Causal: only iterate up to the current Q-tile.
    # Key blocks that start after the tile's last row are fully masked, so skip them.
    if IS_CAUSAL:
        kv_end = tl.minimum(seqlen_k, (pid_m + 1) * BLOCK_M)
    else:
        kv_end = seqlen_k

    for start_n in range(0, kv_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seqlen_k

        # K^T: [HEAD_DIM, BLOCK_N] — swapping k/n offsets gives the transposed load.
        # NOTE: loaded as fp32 and re-cast to bf16 for tl.dot below; for bf16 K
        # that round-trip is lossless (i.e. redundant).
        k = tl.load(
            K_bh + offs_k[:, None] * stride_kk + offs_n[None, :] * stride_kn,
            mask=mask_n[None, :],
            other=0.0,
        ).to(tl.float32)

        # QK^T: [BLOCK_M, HEAD_DIM] x [HEAD_DIM, BLOCK_N] -> [BLOCK_M, BLOCK_N].
        # q is already in log2 domain (includes sm_scale*LOG2E), so exp2 can be applied directly.
        qk = tl.dot(q.to(tl.bfloat16), k.to(tl.bfloat16)).to(tl.float32)

        # Out-of-range keys (and, if causal, keys after the query position:
        # a top-left aligned mask where row i sees keys 0..i) get -inf so that
        # exp2 maps them to 0.
        if IS_CAUSAL:
            causal_ok = offs_m[:, None] >= offs_n[None, :]
            qk = tl.where(causal_ok & mask_n[None, :], qk, float("-inf"))
        else:
            qk = tl.where(mask_n[None, :], qk, float("-inf"))

        # Online softmax (log2 domain).
        # On the first block m_i is -inf, so alpha = exp2(-inf) = 0 and the
        # zero-initialized state is discarded cleanly.
        m_new = tl.maximum(m_i, tl.max(qk, axis=1))  # [BLOCK_M]
        alpha = tl.math.exp2(m_i - m_new)  # rescale previous rows
        p = tl.math.exp2(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        lse = lse * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None]

        # V: [BLOCK_N, HEAD_DIM]
        v = tl.load(
            V_bh + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk,
            mask=mask_n[:, None],
            other=0.0,
        ).to(tl.bfloat16)

        # P @ V: [BLOCK_M, BLOCK_N] × [BLOCK_N, HEAD_DIM] → [BLOCK_M, HEAD_DIM]
        # (bf16 operands, fp32 accumulation into acc).
        acc = tl.dot(p.to(tl.bfloat16), v, acc=acc)
        m_i = m_new

    # Normalize and write back.
    acc = acc / lse[:, None]
    tl.store(
        O_bh + offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok,
        acc.to(tl.bfloat16),
        mask=mask_m[:, None],
    )

183 

184 

# Python wrappers.


def _triton_flash_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    sm_scale: float,
    is_causal: bool,
) -> torch.Tensor:
    """Launch ``_flash_attn_fwd_kernel`` on 4-D query/key/value tensors.

    The caller must already have verified that the Triton path applies;
    nothing is re-checked here.

    Args:
        query: ``[B, Hq, M, D]`` tensor.
        key: tensor reshaped to ``[B * Hkv, N, D]`` (Hkv = ``key.shape[1]``).
        value: tensor reshaped to ``[B * Hkv, N, D]`` like ``key``.
        sm_scale: softmax scale applied to the Q·Kᵀ scores.
        is_causal: whether the kernel applies its causal mask.

    Returns:
        A new ``[B, Hq, M, D]`` tensor with ``query``'s dtype.
    """
    batch, n_q_heads, seq_q, head_dim = query.shape
    n_kv_heads = key.shape[1]

    # Fold batch and head into one leading axis: [B*H, seq, D].
    q3 = query.reshape(batch * n_q_heads, seq_q, head_dim)
    k3 = key.reshape(batch * n_kv_heads, -1, head_dim)
    v3 = value.reshape(batch * n_kv_heads, -1, head_dim)
    seq_k = k3.shape[1]
    o3 = torch.empty_like(q3)

    # One program per (batch*Q-head, BLOCK_M-row tile of queries).
    launch_grid = (batch * n_q_heads, triton.cdiv(seq_q, _BLOCK_M))

    # Strides are passed (head, row, col) per tensor, in Q, K, V, Out order.
    _flash_attn_fwd_kernel[launch_grid](
        q3,
        k3,
        v3,
        sm_scale,
        o3,
        *q3.stride(),
        *k3.stride(),
        *v3.stride(),
        *o3.stride(),
        seq_q,
        seq_k,
        n_q_heads,
        n_kv_heads,
        _LOG2E,
        BLOCK_M=_BLOCK_M,
        BLOCK_N=_BLOCK_N,
        HEAD_DIM=head_dim,
        IS_CAUSAL=is_causal,
    )
    return o3.reshape(batch, n_q_heads, seq_q, head_dim)

237 

238 

def _bf16_qkv_compatible(query, key, value) -> bool:
    """Return True when Q/K/V have the layout both fast paths assume.

    Requirements:
      - all three tensors are 4-D ``[B, H, seq, D]`` and bfloat16 (the decode
        path is the ``_bf16`` runtime entry point and the Triton path writes
        bf16 output);
      - key's batch and head_dim match query's, and value has exactly key's
        shape (both kernels read V with query's head_dim and produce an
        output of query's shape);
      - the KV-head count is non-zero and divides the Q-head count (the GQA
        head mapping is integer division);
      - the key sequence is non-empty (the Triton kernel would divide by a
        zero softmax denominator).

    Anything else is left to ATen, which either handles it or raises the
    appropriate error.
    """
    if not (query.dim() == key.dim() == value.dim() == 4):
        return False
    if not (query.dtype == key.dtype == value.dtype == torch.bfloat16):
        return False
    batch, q_heads, _, head_dim = query.shape
    k_batch, kv_heads, kv_len, k_head_dim = key.shape
    return (
        k_batch == batch
        and k_head_dim == head_dim
        and value.shape == key.shape
        and kv_heads > 0
        and q_heads % kv_heads == 0
        and kv_len > 0
    )


def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """
    aten::scaled_dot_product_attention — ARM CPU Flash Attention.

    Shared preconditions for both fast paths (otherwise fall back to ATen):
      - attn_mask = None
      - dropout_p = 0.0
      - query/key/value are 4-D bfloat16 with matching batch and head_dim,
        value.shape == key.shape, Hq % Hkv == 0 and seqlen_k > 0
        (see _bf16_qkv_compatible)

    M=1 decode path (triton-cpu C runtime ``flash_attn_decode_bf16``):
      - runtime module importable, B == 1, seqlen_q == 1
      - is_causal = False: ATen's causal mask is top-left aligned, so a
        single query row would attend only to key 0, while the runtime call
        takes no causal flag and attends to every key
      - contiguous Q/K/V

    Triton prefill path:
      - seqlen_q >= BLOCK_M (=32)
      - head_dim in {16,32,64,128,256}

    Returns a ``[B, Hq, M, D]`` tensor from either fast path, otherwise
    whatever ATen returns.

    NOTE(review): GQA (Hkv < Hq) is served on the fast paths even when
    ``enable_gqa`` is False; ATen's behaviour differs in that case. Confirm
    this leniency is intended.
    """
    fast_ok = (
        attn_mask is None
        and dropout_p == 0.0
        and _bf16_qkv_compatible(query, key, value)
    )

    if fast_ok:
        B, Hq, M, D = query.shape
        Hkv = key.shape[1]
        seq_len = key.shape[2]
        sm_scale = scale if scale is not None else D**-0.5

        # M=1 decode fast path: C runtime flash_attn_decode_bf16 via triton-cpu.
        # Measured +1.2% E2E on Qwen3-1.7B INT8 vs ATen fallback (3 rounds A/B).
        if (
            _flash_attn_decode_bf16 is not None
            and M == 1
            and B == 1
            and not is_causal
            and query.is_contiguous()
            and key.is_contiguous()
            and value.is_contiguous()
        ):
            q_flat = query.squeeze(0).squeeze(1).contiguous()  # [Hq, D]
            k_flat = key.squeeze(0).contiguous()  # [Hkv, N, D]
            v_flat = value.squeeze(0).contiguous()  # [Hkv, N, D]
            out_flat = torch.empty(Hq, D, dtype=torch.bfloat16)
            _flash_attn_decode_bf16(
                q_flat,
                k_flat,
                v_flat,
                out_flat,
                seq_len,
                D,
                sm_scale,
                Hq,
                Hkv,
                k_flat.stride(1),
                v_flat.stride(1),
            )
            return out_flat.unsqueeze(0).unsqueeze(2)

        # Prefill fast path: Triton Flash Attention kernel (requires M >= BLOCK_M).
        if M >= _BLOCK_M and D in {16, 32, 64, 128, 256}:
            log.debug(
                "GEMS SDPA: Triton Flash Attention (M=%d, N=%d, D=%d, causal=%s, Hq=%d, Hkv=%d)",
                M,
                seq_len,
                D,
                is_causal,
                Hq,
                Hkv,
            )
            return _triton_flash_attn(query, key, value, sm_scale, is_causal)

    log.debug(
        "GEMS SDPA: ATen fallback (M=%d, dtype=%s, mask=%s)",
        # query may not be 4-D here; ATen validates the shape itself.
        query.shape[-2] if query.dim() >= 2 else -1,
        query.dtype,
        attn_mask is not None,
    )
    return _aten_sdpa(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )