Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_rope.py: 0%
87 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""Monkey-patch transformers.models.qwen3.modeling_qwen3.apply_rotary_pos_emb
2to use a TLE @triton.jit RoPE kernel for the decode (T=1) BF16 hot path.
Replaces 4 ATen muls + 2 adds + slice/cat of rotate_half with a single fused
Triton kernel launch that rotates both the q heads and the k heads. Goes
through the @triton.jit TLE path (NOT ctypes) per project requirement.
8Decode (B=1, T=1, BF16 q/k contiguous) hits the fast path. Other shapes fall
9back to the original PyTorch implementation.
10"""
11import logging
13import torch
14import triton
15import triton.language as tl
17from flag_gems.utils import triton_lang_extension as tle
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Patch bookkeeping shared by patch/unpatch: the "original" key holds the
# captured upstream apply_rotary_pos_emb; every other key is a module name
# mapped to True while that module is patched.
_PATCHED: dict = dict()
@triton.jit
def _rope_qk_bf16_kernel(
    q_ptr,
    k_ptr,
    cos_ptr,
    sin_ptr,
    n_heads_q,
    head_dim,
    half,
    BLOCK_HALF: tl.constexpr,
):
    """Apply RoPE in-place to q (heads 0..n_heads_q) and k (heads n_heads_q..total)
    in a single kernel launch. Grid is (n_heads_q + n_heads_kv,).

    Each program rotates one head row: program ``pid < n_heads_q`` handles q
    head ``pid``; any other program handles k head ``pid - n_heads_q``.

    Layout: q_ptr / k_ptr point to flat memory [n_heads * head_dim] each, one
    ``head_dim``-long row per head. cos_ptr / sin_ptr point to ``half``
    elements, i.e. only the first half of the cos/sin vectors.

    Pairing follows the rotate_half (split-half) convention, not the
    interleaved one: element ``d`` is rotated together with element
    ``d + half``. Reading only the first half of cos/sin is therefore correct
    only when cos[half:] == cos[:half] and sin[half:] == sin[:half].

    All loads are upcast to float32 for the math. Results are cast back to
    q_ptr's element type before storing, for k rows as well as q rows.

    Args:
        n_heads_q: number of q heads; programs with pid >= n_heads_q work on k.
        head_dim: elements per head row.
        half: half of head_dim, as passed by the callers in this file.
        BLOCK_HALF: tile width over the half dimension. The loop walks
            ``half`` in BLOCK_HALF-sized chunks, and a tail mask covers the
            last partial chunk.
    """
    pid = tle.program_id(0)
    is_q = pid < n_heads_q
    # Branch on q vs k by selecting the right base + index
    row = tl.where(is_q, q_ptr + pid * head_dim, k_ptr + (pid - n_heads_q) * head_dim)
    for off in range(0, half, BLOCK_HALF):
        d = off + tl.arange(0, BLOCK_HALF)
        # Tail mask for when half is not a multiple of BLOCK_HALF.
        mask = d < half
        q0 = tl.load(row + d, mask=mask, other=0.0).to(tl.float32)
        q1 = tl.load(row + half + d, mask=mask, other=0.0).to(tl.float32)
        c = tl.load(cos_ptr + d, mask=mask, other=0.0).to(tl.float32)
        s = tl.load(sin_ptr + d, mask=mask, other=0.0).to(tl.float32)
        # x * cos + rotate_half(x) * sin, where rotate_half([x0, x1]) = [-x1, x0].
        r0 = q0 * c - q1 * s
        r1 = q0 * s + q1 * c
        # k rows are also cast to q's element type. The callers in this file
        # only launch with BF16 q and k, so the two types match.
        tl.store(row + d, r0.to(q_ptr.dtype.element_ty), mask=mask)
        tl.store(row + half + d, r1.to(q_ptr.dtype.element_ty), mask=mask)
# Flipped to True by _prewarm() after its single warm-up attempt
# (whether or not it succeeded), so later calls return immediately.
_PREWARM_DONE: bool = False
def _prewarm():
    """Launch the RoPE kernel once on throwaway tensors to trigger JIT compilation.

    Uses zero-filled BF16 buffers on the default torch device: two q heads
    plus two k heads with head_dim 128, i.e. four programs. Any exception is
    logged at DEBUG level and swallowed. The module flag is set after the
    attempt either way, so subsequent calls are no-ops.

    NOTE(review): the real launch in _rope_bf16_jit passes different integer
    arguments (e.g. n_heads_q). Triton may compile a separate specialization
    for those, so confirm this warm-up actually removes first-call latency.
    """
    global _PREWARM_DONE
    if not _PREWARM_DONE:
        try:
            for hd in (128,):
                half_dim = hd // 2
                q_tmp = torch.zeros((2, hd), dtype=torch.bfloat16)
                k_tmp = torch.zeros_like(q_tmp)
                cos_tmp = torch.zeros(half_dim, dtype=torch.bfloat16)
                sin_tmp = torch.zeros_like(cos_tmp)
                launch_grid = (4,)
                _rope_qk_bf16_kernel[launch_grid](
                    q_tmp, k_tmp, cos_tmp, sin_tmp,
                    2, hd, half_dim,
                    BLOCK_HALF=64, num_warps=1, num_stages=1,
                )
        except Exception:
            logger.debug("rope prewarm failed", exc_info=True)
        _PREWARM_DONE = True
def _rope_bf16_jit(q, k, cos_half, sin_half, n_heads_q, n_heads_kv, head_dim):
    """Rotate q and k in place with a single fused ``_rope_qk_bf16_kernel`` launch.

    One program is launched per head, q heads first and then k heads.
    ``_prewarm()`` is invoked before the launch.

    Args:
        q, k: buffers that the kernel overwrites in place. Presumably
            contiguous ``[n_heads * head_dim]`` rows, as the kernel's flat
            indexing expects; the caller is responsible for this.
        cos_half, sin_half: the first ``head_dim // 2`` cos / sin values.
        n_heads_q, n_heads_kv: number of q heads and k heads.
        head_dim: width of one head row.
    """
    _prewarm()
    num_programs = n_heads_q + n_heads_kv
    launch_options = dict(BLOCK_HALF=64, num_warps=1, num_stages=1)
    _rope_qk_bf16_kernel[(num_programs,)](
        q, k, cos_half, sin_half, n_heads_q, head_dim, head_dim // 2, **launch_options
    )
def _can_use_fast_path(q, k, cos, sin, unsqueeze_dim) -> bool:
    """Return True when the fused Triton RoPE kernel can handle these inputs.

    The kernel rotates q/k as flat ``[n_heads * head_dim]`` rows. It reads
    only the first ``head_dim // 2`` values of a single cos/sin row, so the
    fast path requires all of the following:

    * the ``(B, H, T, D)`` layout implied by ``unsqueeze_dim == 1``, since the
      dims below are indexed under that assumption;
    * a single decode token (B == 1, T == 1) for both q and k;
    * contiguous BF16 q/k sharing one head_dim that is positive and even;
    * cos/sin that each hold exactly one ``head_dim``-wide row, because extra
      rows would otherwise be silently truncated to the first one;
    * all four tensors on the same device.
    """
    if unsqueeze_dim != 1:
        return False
    if q.dim() != 4 or k.dim() != 4:
        return False
    if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16:
        return False
    if q.shape[0] != 1 or k.shape[0] != 1 or q.shape[2] != 1 or k.shape[2] != 1:
        return False
    head_dim = q.shape[3]
    # The kernel uses one head_dim for q and k and splits it into two halves.
    if head_dim <= 0 or head_dim % 2 != 0 or k.shape[3] != head_dim:
        return False
    for table in (cos, sin):
        if table.dim() not in (2, 3, 4):
            return False
        if table.shape[-1] != head_dim or table.numel() != head_dim:
            return False
    if not (q.device == k.device == cos.device == sin.device):
        return False
    return q.is_contiguous() and k.is_contiguous()


def _patched_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Drop-in replacement for HF ``apply_rotary_pos_emb``.

    Single-token BF16 decode inputs (see ``_can_use_fast_path``) go through
    the fused @triton.jit kernel. Every other call is forwarded to the
    original implementation captured in ``_PATCHED["original"]``.

    Returns:
        ``(q_embed, k_embed)``. On the fast path these are fresh clones of q
        and k, rotated in place. The caller's tensors are left unmodified,
        matching HF's out-of-place semantics.
    """
    if not _can_use_fast_path(q, k, cos, sin, unsqueeze_dim):
        # Pass unsqueeze_dim by keyword. Some transformers releases declare
        # apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1),
        # and a positional 5th argument would be absorbed by position_ids.
        return _PATCHED["original"](q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)

    n_heads = q.shape[1]
    n_kv_heads = k.shape[1]
    head_dim = q.shape[3]
    half = head_dim // 2
    # The kernel uses the rotate_half (split-half) convention and reads only
    # the first half of cos/sin. This assumes cos[half:] == cos[:half], which
    # is HF's standard repeat-half RoPE table.
    cos_half = cos.reshape(-1, head_dim)[0, :half].contiguous()
    sin_half = sin.reshape(-1, head_dim)[0, :half].contiguous()
    # The kernel writes in place, so clone q/k to preserve the caller's tensors.
    q_buf = q.clone()
    k_buf = k.clone()
    _rope_bf16_jit(q_buf, k_buf, cos_half, sin_half, n_heads, n_kv_heads, head_dim)
    return q_buf, k_buf
def patch_qwen3_rope() -> int:
    """Monkey-patch apply_rotary_pos_emb in transformers.models.qwen3.

    The upstream function is saved in ``_PATCHED["original"]`` so the
    wrapper's fallback path and ``unpatch_qwen3_rope`` can reach it. Each
    patched module is recorded as ``_PATCHED[modname] = True``.

    Returns:
        Number of modules patched by this call. Modules are skipped if they
        cannot be imported, are already recorded, lack the attribute, or
        already expose our wrapper.
    """
    # Targets regular Qwen3 only. Qwen3.5 supports partial rotary
    # (q_rot vs q_pass split); needs separate handling.
    targets = [
        "transformers.models.qwen3.modeling_qwen3",
    ]
    n = 0
    for modname in targets:
        try:
            mod = __import__(modname, fromlist=["apply_rotary_pos_emb"])
        except (ImportError, AttributeError):
            continue
        if modname in _PATCHED:
            continue
        original = getattr(mod, "apply_rotary_pos_emb", None)
        if original is None:
            continue
        # Never record our own wrapper as the "original". Its fallback path
        # would then call itself and recurse forever.
        if original is _patched_apply_rotary_pos_emb:
            continue
        _PATCHED["original"] = original
        _PATCHED[modname] = True
        setattr(mod, "apply_rotary_pos_emb", _patched_apply_rotary_pos_emb)
        n += 1
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Patched %s.apply_rotary_pos_emb", modname)
    return n
def unpatch_qwen3_rope() -> int:
    """Undo ``patch_qwen3_rope``.

    For every module name recorded in ``_PATCHED`` (every key except
    "original"), restore the module's ``apply_rotary_pos_emb`` to the captured
    original, if one was captured, and remove the module's entry. Modules
    that fail to import are skipped and stay recorded.

    Returns:
        Number of module entries removed from ``_PATCHED``.
    """
    restored = 0
    module_names = [name for name in _PATCHED if name != "original"]
    for module_name in module_names:
        try:
            module = __import__(module_name, fromlist=["apply_rotary_pos_emb"])
        except (ImportError, AttributeError):
            continue
        if "original" in _PATCHED:
            module.apply_rotary_pos_emb = _PATCHED["original"]
        del _PATCHED[module_name]
        restored += 1
    return restored
184 return n