Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_attention.py: 0%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Monkey-patch F.scaled_dot_product_attention to route the M=1 BF16
2decode path through the existing flash_attn_decode_bf16 TLE C kernel,
3replacing the bmm + softmax + bmm sequence (9% of decode time per
4profiler).
6This is a minimal patch that only swaps SDPA — does NOT pull in the
7full FlagGems _arm.ops package (which would also register Triton mm /
8addmm kernels that are slower than ATen for our small decode shapes).
10Other shapes (prefill, M>1, non-BF16, with attn_mask) fall through to
11the original ATen SDPA without recursion (we capture the original
12function pointer at patch time).
13"""
14import logging
16import torch
17import torch.nn.functional as F
logger = logging.getLogger(__name__)

# Resolved once, at module import. If the installed triton-cpu build does not
# ship the TLE runtime op (older builds), the import raises ImportError,
# ``_flash_attn_kernel`` is set to None, and every SDPA call falls through to
# the original ATen implementation.
try:
    import triton
    import triton.language as tl
    from triton.language.extra.cpu.tle_ops import (
        flash_attn_decode as _tle_flash_attn_decode,
    )

    # Thin Triton wrapper around the coarse TLE op. It is launched from
    # _patched_sdpa with a grid of (1,). Every tl.constexpr argument is part
    # of Triton's specialization key, so each distinct (head_dim, sm_scale,
    # head counts, strides) combination is compiled separately. Only seq_len
    # is a runtime value.
    @triton.jit
    def _flash_attn_kernel(
        q_ptr,  # query; flattened to (num_heads, head_dim) at the launch site
        k_ptr,  # keys; (num_kv_heads, seq_len, head_dim) at the launch site
        v_ptr,  # values; (num_kv_heads, seq_len, head_dim) at the launch site
        out_ptr,  # (num_heads, head_dim) BF16 buffer allocated uninitialized by the caller; presumably fully written by the op
        seq_len,  # number of cached key/value positions (runtime, grows per token)
        head_dim: tl.constexpr,
        sm_scale: tl.constexpr,  # SDPA scale; caller defaults it to head_dim**-0.5
        num_heads: tl.constexpr,  # query heads
        num_kv_heads: tl.constexpr,  # KV heads; may be fewer than num_heads (GQA) -- head mapping is inside the TLE op, not visible here
        stride_kn: tl.constexpr,  # element stride between consecutive key positions (k_flat.stride(1) at launch)
        stride_vn: tl.constexpr,  # element stride between consecutive value positions (v_flat.stride(1) at launch)
    ):
        # Coarse TLE op: the whole M=1 flash-attention decode in one launch.
        # seq_len is runtime (grows per token); the rest are constexpr.
        # NOTE(review): dtype expectations (BF16 per the module docstring) are
        # enforced by the TLE op / caller, not by this wrapper -- confirm.
        _tle_flash_attn_decode(
            q_ptr,
            k_ptr,
            v_ptr,
            out_ptr,
            seq_len,
            head_dim,
            sm_scale,
            num_heads,
            num_kv_heads,
            stride_kn,
            stride_vn,
        )

except ImportError:
    # Sentinel checked by _patched_sdpa and patch_qwen3_5_attention.
    _flash_attn_kernel = None
# Capture the *original* SDPA before any patching so our fallback call does
# not recurse into the patched wrapper.
#
# ``globals().get`` keeps the value captured on first import if this module
# is re-executed (e.g. ``importlib.reload``) while the patch is active. Reading
# ``F.scaled_dot_product_attention`` at that point would capture our own
# wrapper, and every fallback call would then recurse without end.
_orig_sdpa = globals().get("_orig_sdpa", F.scaled_dot_product_attention)

# True while F.scaled_dot_product_attention is our patched wrapper. Preserved
# across a module re-execution for the same reason as _orig_sdpa, so that
# unpatching still restores the real original.
_PATCHED = globals().get("_PATCHED", False)
def _call_orig_sdpa(
    query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
):
    """Forward a call to the SDPA captured at import time (never the patch).

    ``enable_gqa`` is forwarded only when it is truthy. The keyword was added
    in PyTorch 2.5, so always passing ``enable_gqa=False`` would make every
    fallback call raise ``TypeError`` on older torch builds.
    """
    extra = {"enable_gqa": enable_gqa} if enable_gqa else {}
    return _orig_sdpa(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        **extra,
    )


def _use_decode_fast_path(query, key, value, attn_mask, dropout_p, is_causal):
    """Return True if this SDPA call can be served by ``_flash_attn_kernel``.

    The kernel is launched with one (Hq, D) query row and contiguous
    (Hkv, S, D) key/value buffers, and it writes a (Hq, D) BF16 output.
    Anything it cannot represent is rejected here:

    * every input must be 4-D (B, H, L, E). SDPA accepts other ranks, and
      those must fall through rather than fail on tuple unpacking.
    * B == 1 and query length M == 1. Key and value must share one shape
      whose head dim matches the query's, with S > 0 and Hq divisible by Hkv.
    * all three tensors must be BF16, contiguous, and on CPU (the TLE op is a
      triton-cpu kernel).
    * there must be no attn_mask, no dropout, and ``is_causal`` must be False.
      ATen aligns the causal mask top-left, so a single query row with
      ``is_causal=True`` attends only to the first key, which the kernel does
      not model.
    * no autograd graph may be required, because the kernel output has no
      grad_fn.
    """
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        return False
    if attn_mask is not None or dropout_p != 0.0 or is_causal:
        return False

    B, Hq, M, D = query.shape
    Bk, Hkv, S, Dk = key.shape
    if B != 1 or M != 1 or Bk != 1 or Dk != D or S == 0:
        return False
    if value.shape != key.shape:
        return False
    if Hkv == 0 or Hq % Hkv != 0:
        return False

    tensors = (query, key, value)
    if any(t.dtype != torch.bfloat16 for t in tensors):
        return False
    if any(t.device.type != "cpu" for t in tensors):
        return False
    if not all(t.is_contiguous() for t in tensors):
        return False
    if torch.is_grad_enabled() and any(t.requires_grad for t in tensors):
        return False
    return True


def _patched_sdpa(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """SDPA with M=1 BF16 fast path using flash_attn_decode_bf16.

    The signature mirrors ``torch.nn.functional.scaled_dot_product_attention``.
    Calls accepted by ``_use_decode_fast_path`` run through the TLE
    flash-attention decode kernel in a single launch. Every other call, and
    every call when the kernel failed to import, is forwarded unchanged to
    the ATen SDPA captured at import time.

    Returns:
        On the fast path, a (1, Hq, 1, D) BF16 tensor on the query's device.
        Otherwise, whatever the original SDPA returns.
    """
    if _flash_attn_kernel is None or not _use_decode_fast_path(
        query, key, value, attn_mask, dropout_p, is_causal
    ):
        return _call_orig_sdpa(
            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa
        )

    _, Hq, _, D = query.shape
    _, Hkv, seq_len, _ = key.shape
    # Same default as ATen: 1 / sqrt(E) with E = query.size(-1).
    sm_scale = scale if scale is not None else D**-0.5

    # Inputs are contiguous with size-1 batch/query dims, so these are
    # zero-copy views: q -> (Hq, D), k/v -> (Hkv, S, D).
    q_flat = query.view(Hq, D)
    k_flat = key.view(Hkv, seq_len, D)
    v_flat = value.view(Hkv, seq_len, D)
    out_flat = torch.empty(Hq, D, dtype=torch.bfloat16, device=query.device)

    _flash_attn_kernel[(1,)](
        q_flat,
        k_flat,
        v_flat,
        out_flat,
        seq_len,
        head_dim=D,
        sm_scale=sm_scale,
        num_heads=Hq,
        num_kv_heads=Hkv,
        stride_kn=k_flat.stride(1),
        stride_vn=v_flat.stride(1),
    )
    # Restore SDPA's (B, H, L, E) output layout.
    return out_flat.view(1, Hq, 1, D)
def patch_qwen3_5_attention(model=None) -> int:
    """Swap ``F.scaled_dot_product_attention`` for the decode-fast-path SDPA.

    Args:
        model: Unused. Accepted only so this patch has the same call shape as
            the other model patches.

    Returns:
        0 if the TLE flash-attention kernel is unavailable, in which case
        nothing is patched and a warning is logged. Otherwise 1, both when
        this call installs the patch and when it was already installed.
    """
    global _PATCHED

    kernel_missing = _flash_attn_kernel is None
    if kernel_missing:
        logger.warning("flash_attn_decode_bf16 not available; SDPA patch skipped")
        return 0

    # Idempotent: only the first successful call rebinds the attribute.
    if not _PATCHED:
        F.scaled_dot_product_attention = _patched_sdpa
        _PATCHED = True
        logger.info(
            "Patched F.scaled_dot_product_attention with TLE flash_attn_decode_bf16"
        )
    return 1
def unpatch_qwen3_5_attention(model=None) -> int:
    """Put back the SDPA that was captured when this module was imported.

    Args:
        model: Unused. Accepted for symmetry with ``patch_qwen3_5_attention``.

    Returns:
        1 if an installed patch was removed, 0 if no patch was installed.
    """
    global _PATCHED

    was_patched = _PATCHED
    if was_patched:
        F.scaled_dot_product_attention = _orig_sdpa
        _PATCHED = False
    return 1 if was_patched else 0