Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_attention.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Monkey-patch F.scaled_dot_product_attention to route the M=1 BF16 

2decode path through the existing flash_attn_decode_bf16 TLE C kernel, 

3replacing the bmm + softmax + bmm sequence (9% of decode time per 

4profiler). 

5 

6This is a minimal patch that only swaps SDPA — does NOT pull in the 

7full FlagGems _arm.ops package (which would also register Triton mm / 

8addmm kernels that are slower than ATen for our small decode shapes). 

9 

10Other shapes (prefill, M>1, non-BF16, with attn_mask) fall through to 

11the original ATen SDPA without recursion (we capture the original 

12function pointer at patch time). 

13""" 

14import logging 

15 

16import torch 

17import torch.nn.functional as F 

18 

19logger = logging.getLogger(__name__) 

20 

21# Imported at module load. If triton-cpu lacks the runtime module 

22# (older build), keep flash_attn_decode unavailable and fall through 

23# to ATen. 

try:
    # All three imports must succeed for the fast path to exist; any
    # ImportError below disables it by leaving _flash_attn_kernel as None.
    import triton
    import triton.language as tl
    from triton.language.extra.cpu.tle_ops import (
        flash_attn_decode as _tle_flash_attn_decode,
    )

    # Thin JIT wrapper: it performs no computation of its own and forwards
    # every argument, unchanged and in the same order, to the TLE op.
    # _patched_sdpa launches it with a single-program grid `(1,)`.
    @triton.jit
    def _flash_attn_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        # The only non-constexpr scalar: it changes on every decode step.
        seq_len,
        # Everything below is a compile-time constant of the kernel.
        head_dim: tl.constexpr,
        sm_scale: tl.constexpr,
        num_heads: tl.constexpr,
        num_kv_heads: tl.constexpr,
        stride_kn: tl.constexpr,
        stride_vn: tl.constexpr,
    ):
        # Coarse TLE op: the whole M=1 flash-attention decode in one launch.
        # seq_len is runtime (grows per token); the rest are constexpr.
        _tle_flash_attn_decode(
            q_ptr,
            k_ptr,
            v_ptr,
            out_ptr,
            seq_len,
            head_dim,
            sm_scale,
            num_heads,
            num_kv_heads,
            stride_kn,
            stride_vn,
        )

except ImportError:
    # Sentinel checked by _patched_sdpa and patch_qwen3_5_attention: None
    # means "fast path unavailable, always use the original SDPA".
    _flash_attn_kernel = None

63 

# Capture the SDPA callable that is installed when this module is imported,
# before this module patches anything, so the fallback path in _patched_sdpa
# calls it directly instead of recursing into the patched function.
# NOTE(review): if F.scaled_dot_product_attention has already been replaced
# by the time this line runs (another patch, or a re-import of this module),
# that replacement is what gets captured here - confirm import order.
_orig_sdpa = F.scaled_dot_product_attention

# True while _patched_sdpa is installed on torch.nn.functional; toggled by
# patch_qwen3_5_attention / unpatch_qwen3_5_attention so both are idempotent.
_PATCHED = False

69 

70 

def _patched_sdpa(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Drop-in SDPA replacement with a single-token BF16 decode fast path.

    The fast path launches ``_flash_attn_kernel`` once and is taken only
    when every one of these holds:

    * the kernel was importable (``_flash_attn_kernel is not None``);
    * ``query``, ``key`` and ``value`` are 4-D, contiguous and bfloat16;
    * ``query`` has shape ``(1, Hq, 1, D)`` (batch 1, one query token);
    * ``key`` and ``value`` share the shape ``(1, Hkv, S, D)`` and ``Hq``
      is a multiple of ``Hkv``;
    * ``attn_mask is None``, ``dropout_p == 0.0`` and ``is_causal`` is false.

    Any other call is forwarded unchanged to the SDPA implementation that
    was captured in ``_orig_sdpa`` at import time, so unsupported inputs
    behave (and fail) exactly as they would without the patch.

    Args:
        query, key, value: attention inputs, as for
            ``torch.nn.functional.scaled_dot_product_attention``.
        attn_mask: optional mask; any non-None value disables the fast path.
        dropout_p: dropout probability; non-zero disables the fast path.
        is_causal: causal flag; when true the call is delegated (see the
            comment on the eligibility check for why).
        scale: softmax scale; defaults to ``D ** -0.5`` on the fast path.
        enable_gqa: forwarded to the original SDPA only when truthy.

    Returns:
        The attention output; on the fast path a new bfloat16 tensor of
        shape ``(1, Hq, 1, D)`` on the same device as ``query``.
    """

    def _fallback():
        # Forward enable_gqa only when it was actually requested: leaving it
        # out is equivalent to its default of False, and it keeps the
        # fallback usable on PyTorch builds whose SDPA does not yet accept
        # that keyword (presumably releases before 2.5 - TODO confirm).
        extra = {"enable_gqa": enable_gqa} if enable_gqa else {}
        return _orig_sdpa(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            **extra,
        )

    if _flash_attn_kernel is None:
        return _fallback()

    # SDPA accepts inputs of other ranks; unpacking four dims from them would
    # raise before we ever reach the fallback, so check the rank first.
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        return _fallback()

    B, Hq, M, D = query.shape
    Hkv = key.shape[1]
    seq_len = key.shape[2]

    eligible = (
        M == 1
        and B == 1
        # The reference SDPA builds its causal mask top-left aligned, so a
        # single query row with is_causal=True only sees key 0. The kernel
        # receives no causal flag, so delegate instead of silently attending
        # to every key.
        and not is_causal
        and attn_mask is None
        and dropout_p == 0.0
        and query.dtype == torch.bfloat16
        and key.dtype == torch.bfloat16
        and value.dtype == torch.bfloat16
        # The kernel is handed raw pointers plus one head_dim and one
        # (num_heads, num_kv_heads) pair, so K/V must really have the layout
        # (1, Hkv, S, D) and the head counts must group evenly.
        and key.shape[0] == 1
        and key.shape[3] == D
        and value.shape == key.shape
        and Hkv > 0
        and Hq % Hkv == 0
        and query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
    )
    if not eligible:
        # Non-decode shapes: fall back to original ATen SDPA.
        return _fallback()

    sm_scale = scale if scale is not None else D**-0.5

    # The inputs were verified contiguous with size-1 batch (and, for the
    # query, size-1 token) dims, so these are copy-free views.
    q_flat = query.view(Hq, D)
    k_flat = key.view(Hkv, seq_len, D)
    v_flat = value.view(Hkv, seq_len, D)
    # Allocate next to the inputs rather than on whatever the process-wide
    # default device happens to be.
    out_flat = torch.empty((Hq, D), dtype=torch.bfloat16, device=query.device)

    _flash_attn_kernel[(1,)](
        q_flat,
        k_flat,
        v_flat,
        out_flat,
        seq_len,
        head_dim=D,
        sm_scale=sm_scale,
        num_heads=Hq,
        num_kv_heads=Hkv,
        stride_kn=k_flat.stride(1),
        stride_vn=v_flat.stride(1),
    )
    # Restore the (B=1, Hq, M=1, D) layout SDPA callers expect.
    return out_flat.view(1, Hq, 1, D)

139 

140 

def patch_qwen3_5_attention(model=None) -> int:
    """Replace ``F.scaled_dot_product_attention`` with ``_patched_sdpa``.

    ``model`` is unused; it exists only so this function can be called the
    same way as the other patch entry points.

    The call is idempotent: once the patch is in place, further calls do
    nothing and still report success.

    Returns:
        0 when the flash_attn_decode kernel could not be imported (a warning
        is logged and nothing is patched), otherwise 1.
    """
    global _PATCHED

    kernel_missing = _flash_attn_kernel is None
    if kernel_missing:
        logger.warning("flash_attn_decode_bf16 not available; SDPA patch skipped")
        return 0

    if not _PATCHED:
        # First successful call: swap the function and remember that we did.
        F.scaled_dot_product_attention = _patched_sdpa
        _PATCHED = True
        logger.info(
            "Patched F.scaled_dot_product_attention with TLE flash_attn_decode_bf16"
        )

    return 1

160 

161 

def unpatch_qwen3_5_attention(model=None) -> int:
    """Undo ``patch_qwen3_5_attention`` by reinstalling ``_orig_sdpa``.

    ``model`` is unused and accepted only to mirror the patch function's
    signature.

    Returns:
        1 if the patch was active and has been removed, 0 if there was
        nothing to undo.
    """
    global _PATCHED

    if _PATCHED:
        # Put back the callable captured at import time and clear the flag.
        F.scaled_dot_product_attention = _orig_sdpa
        _PATCHED = False
        return 1

    return 0