Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_rope.py: 0%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

"""Monkey-patch transformers.models.qwen3.modeling_qwen3.apply_rotary_pos_emb
to use a TLE @triton.jit RoPE kernel for the decode (T=1) BF16 hot path.

Replaces 4 ATen muls + 2 adds + slice/cat of rotate_half with a single
fused Triton kernel launch that rotates both q and k. Goes through the
@triton.jit TLE path (NOT ctypes) per project requirement.

Decode (B=1, T=1, BF16 q/k contiguous) hits the fast path. Other shapes fall
back to the original PyTorch implementation.
"""

11import logging 

12 

13import torch 

14import triton 

15import triton.language as tl 

16 

17from flag_gems.utils import triton_lang_extension as tle 

18 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Patch bookkeeping shared by the patch/unpatch functions in this module:
#   "original"    -> the apply_rotary_pos_emb callable that was replaced
#   <module name> -> True once that module has been patched
_PATCHED: dict = {}

22 

23 

@triton.jit
def _rope_qk_bf16_kernel(
    q_ptr,
    k_ptr,
    cos_ptr,
    sin_ptr,
    n_heads_q,
    head_dim,
    half,
    BLOCK_HALF: tl.constexpr,
):
    """Rotate one q or k head in place; one program per head.

    Programs whose id is below ``n_heads_q`` work on q head ``pid``; the
    remaining programs work on k head ``pid - n_heads_q``. Each head is
    addressed as ``head_dim`` consecutive elements starting at
    ``base + head_index * head_dim``.

    Element ``d`` of the first half is paired with element ``d + half``:

        out[d]        = x[d] * cos[d] - x[d + half] * sin[d]
        out[d + half] = x[d] * sin[d] + x[d + half] * cos[d]

    Only the first ``half`` entries of cos/sin are read. Arithmetic is done
    in float32 and the result is cast to q_ptr's element type before being
    stored, for q and k rows alike.
    """
    head = tle.program_id(0)
    # Compute both candidate row bases, then pick the one this program owns.
    q_row = q_ptr + head * head_dim
    k_row = k_ptr + (head - n_heads_q) * head_dim
    base = tl.where(head < n_heads_q, q_row, k_row)
    for start in range(0, half, BLOCK_HALF):
        idx = start + tl.arange(0, BLOCK_HALF)
        in_range = idx < half
        # Masked-off lanes load 0.0 and are never stored.
        x_lo = tl.load(base + idx, mask=in_range, other=0.0).to(tl.float32)
        x_hi = tl.load(base + half + idx, mask=in_range, other=0.0).to(tl.float32)
        cos_v = tl.load(cos_ptr + idx, mask=in_range, other=0.0).to(tl.float32)
        sin_v = tl.load(sin_ptr + idx, mask=in_range, other=0.0).to(tl.float32)
        out_lo = x_lo * cos_v - x_hi * sin_v
        out_hi = x_lo * sin_v + x_hi * cos_v
        tl.store(base + idx, out_lo.to(q_ptr.dtype.element_ty), mask=in_range)
        tl.store(base + half + idx, out_hi.to(q_ptr.dtype.element_ty), mask=in_range)

57 

58 

# Set to True after the first _prewarm() attempt, whether or not it succeeded.
_PREWARM_DONE = False


def _prewarm():
    """Launch the RoPE kernel once on dummy data; later calls are no-ops.

    The launch uses zero-filled bfloat16 tensors created on the default
    device (2 q heads, 2 k heads, head_dim 128) with the same meta-parameters
    as the real launch -- presumably so that JIT compilation happens here
    rather than on a real call (TODO confirm the compiled variant is reused).

    Any ``Exception`` raised by the launch is logged at debug level and
    swallowed; the attempt is not repeated.
    """
    global _PREWARM_DONE
    if not _PREWARM_DONE:
        try:
            for head_dim in (128,):
                half = head_dim // 2
                q_dummy = torch.zeros((2, head_dim), dtype=torch.bfloat16)
                k_dummy = torch.zeros((2, head_dim), dtype=torch.bfloat16)
                cos_dummy = torch.zeros(half, dtype=torch.bfloat16)
                sin_dummy = torch.zeros(half, dtype=torch.bfloat16)
                # Grid of 4 programs: 2 q heads followed by 2 k heads.
                _rope_qk_bf16_kernel[(4,)](
                    q_dummy,
                    k_dummy,
                    cos_dummy,
                    sin_dummy,
                    2,
                    head_dim,
                    half,
                    BLOCK_HALF=64,
                    num_warps=1,
                    num_stages=1,
                )
        except Exception:
            logger.debug("rope prewarm failed", exc_info=True)
        _PREWARM_DONE = True

87 

88 

def _rope_bf16_jit(q, k, cos_half, sin_half, n_heads_q, n_heads_kv, head_dim):
    """Rotate q and k in place with one fused @triton.jit kernel launch.

    Args:
        q, k: tensors handed to the kernel as the q / k base pointers; they
            are overwritten with the rotated values.
        cos_half, sin_half: tensors handed to the kernel as cos / sin; the
            kernel reads their first ``head_dim // 2`` entries.
        n_heads_q, n_heads_kv: head counts; their sum is the launch grid size.
        head_dim: number of elements per head.

    Returns:
        None.
    """
    _prewarm()
    # One program per head: q heads first, then k heads.
    grid = (n_heads_q + n_heads_kv,)
    _rope_qk_bf16_kernel[grid](
        q,
        k,
        cos_half,
        sin_half,
        n_heads_q,
        head_dim,
        head_dim // 2,
        BLOCK_HALF=64,
        num_warps=1,
        num_stages=1,
    )

106 

107 

def _patched_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply RoPE to q and k, using the Triton kernel for single-token BF16 decode.

    The fast path is taken only when every condition below holds; any other
    input is forwarded unchanged to the original implementation stored in
    ``_PATCHED["original"]``:

    * q and k are 4-D, contiguous, bfloat16, with size 1 in dims 0 and 2
      (presumably (batch, heads, seq, head_dim) -- confirm against the caller);
    * q and k have the same last-dimension size (head_dim), and it is a
      positive even number, so the kernel's two halves cover the whole head;
    * cos and sin are 2-, 3- or 4-D and hold exactly one row of head_dim
      values, so taking the first row cannot pick the wrong position;
    * autograd is not tracking q or k. The kernel writes into the cloned
      buffers outside of autograd, so gradients through the fast path would
      not reflect the rotation.

    Args:
        q, k: query / key tensors.
        cos, sin: rotary tables for the current position.
        unsqueeze_dim: only used by the fallback implementation.

    Returns:
        Tuple ``(q_rotated, k_rotated)``. On the fast path these are new
        tensors (clones of q and k rotated in place); q and k themselves are
        not modified.
    """
    layout_ok = (
        q.dim() == 4
        and k.dim() == 4
        and q.shape[0] == 1
        and k.shape[0] == 1
        and q.shape[2] == 1
        and k.shape[2] == 1
        and q.dtype == torch.bfloat16
        and k.dtype == torch.bfloat16
        and cos.dim() in (2, 3, 4)
        and sin.dim() in (2, 3, 4)
        and q.is_contiguous()
        and k.is_contiguous()
    )
    if layout_ok:
        n_heads = q.shape[1]
        n_kv_heads = k.shape[1]
        head_dim = q.shape[3]
        autograd_tracked = torch.is_grad_enabled() and (
            q.requires_grad or k.requires_grad
        )
        kernel_safe = (
            head_dim > 0
            and head_dim % 2 == 0
            # The kernel strides both q and k rows by the same head_dim.
            and k.shape[3] == head_dim
            # Exactly one position's worth of full-width cos/sin values.
            and cos.shape[-1] == head_dim
            and sin.shape[-1] == head_dim
            and cos.numel() == head_dim
            and sin.numel() == head_dim
            and not autograd_tracked
        )
        if kernel_safe:
            # The kernel pairs element d with d + head_dim // 2 and reads only
            # the first half of cos/sin; this relies on the second half
            # repeating the first (HF's standard repeat-half RoPE tables).
            cos_half = cos.reshape(-1, head_dim)[0, : head_dim // 2].contiguous()
            sin_half = sin.reshape(-1, head_dim)[0, : head_dim // 2].contiguous()
            # Kernel writes in-place; clone q/k to preserve the inputs
            # (HF semantics: apply_rotary_pos_emb returns new tensors).
            q_buf = q.clone()
            k_buf = k.clone()
            _rope_bf16_jit(
                q_buf, k_buf, cos_half, sin_half, n_heads, n_kv_heads, head_dim
            )
            return q_buf, k_buf

    # Fallback: original PyTorch implementation
    return _PATCHED["original"](q, k, cos, sin, unsqueeze_dim)

140 

141 

def patch_qwen3_rope() -> int:
    """Replace ``apply_rotary_pos_emb`` in the Qwen3 modeling module.

    The replaced callable is saved under ``_PATCHED["original"]`` and the
    module name is recorded in ``_PATCHED`` so that a second call does not
    patch the same module again. Modules that cannot be imported, or that
    lack the attribute, are skipped.

    Returns:
        The number of modules patched by this call.
    """
    # Targets regular Qwen3 only. Qwen3.5 supports partial rotary
    # (q_rot vs q_pass split); needs separate handling.
    attr = "apply_rotary_pos_emb"
    patched_count = 0
    for modname in ("transformers.models.qwen3.modeling_qwen3",):
        try:
            module = __import__(modname, fromlist=[attr])
        except (ImportError, AttributeError):
            continue
        if not hasattr(module, attr) or modname in _PATCHED:
            continue
        # Save the original before overwriting the module attribute.
        _PATCHED["original"] = getattr(module, attr)
        _PATCHED[modname] = True
        setattr(module, attr, _patched_apply_rotary_pos_emb)
        patched_count += 1
        logger.info(f"Patched {modname}.apply_rotary_pos_emb")
    return patched_count

169 

170 

def unpatch_qwen3_rope() -> int:
    """Restore ``apply_rotary_pos_emb`` in every module recorded in ``_PATCHED``.

    For each recorded module that can still be imported, the saved original
    (if any) is written back and the module's entry is removed from
    ``_PATCHED``. The ``"original"`` entry itself is left in place.

    Returns:
        The number of module entries removed.
    """
    restored = 0
    # Snapshot the module names first: entries are deleted during the loop.
    module_names = [key for key in _PATCHED if key != "original"]
    for modname in module_names:
        try:
            module = __import__(modname, fromlist=["apply_rotary_pos_emb"])
        except (ImportError, AttributeError):
            continue
        if "original" in _PATCHED:
            setattr(module, "apply_rotary_pos_emb", _PATCHED["original"])
        del _PATCHED[modname]
        restored += 1
    return restored