Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_rmsnorm.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""Patch Qwen3RMSNorm.forward to use a single Triton kernel call,
2replacing the 5-6 ATen dispatches in the eager implementation.
4Each Qwen3 decode token does ~113 RMSNorm calls:
5 - 28 input_layernorm (hidden_size, e.g. 2048)
6 - 1 final norm (hidden_size)
7 - 28 post_attention_layernorm (already fused via patch_qwen3_layer_norm)
8 - 28 q_norm (head_dim=128, batched across heads)
9 - 28 k_norm (head_dim=128, batched across heads)
11This patch targets the input/q/k/final norms (~85 calls/token).
12"""
13import logging
15import torch
16import triton
17import triton.language as tl
19from flag_gems.utils import triton_lang_extension as tle
logger = logging.getLogger(__name__)

# Column tile width passed as BLOCK_SIZE to every _rms_norm_kernel launch.
_TILE = 128

# Registry of classes whose forward has been replaced, so they can be restored:
# (module name, class name) -> (class, original forward).
_PATCHED: dict = dict()
# RMSNorm over one row per program.
# Pass 1 accumulates the row's sum of squares in fp32, tile by tile.
# Pass 2 reloads the row, scales it by 1/sqrt(mean(x^2) + eps), casts to the
# output dtype, and multiplies by the weight.
# Columns are walked in BLOCK_SIZE-wide tiles, and lanes at or past N are masked.
@triton.jit(do_not_specialize=["eps"])
def _rms_norm_kernel(
    x_ptr,  # input rows
    w_ptr,  # weight, indexed as a dense length-N vector (w_ptr + cols)
    out_ptr,  # output rows; addressed with the same row stride as the input
    stride_r,  # elements between the starts of consecutive rows (x and out)
    N,  # row length in elements
    eps,  # added to the mean of squares before the square root
    BLOCK_SIZE: tl.constexpr,  # tile width for both column loops
):
    # NOTE(review): presumably tle.program_id wraps tl.program_id; the program
    # id selects which row this program normalizes.
    pid = tle.program_id(0)
    # NOTE(review): confirm pid * stride_r cannot overflow the kernel's integer
    # width for very large (rows * N) inputs.
    x_row = x_ptr + pid * stride_r
    out_row = out_ptr + pid * stride_r

    # Pass 1: fp32 sum of squares across the whole row.
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Masked lanes load 0.0, so they add nothing to the sum.
        x = tl.load(x_row + cols, mask=mask, other=0.0).to(tl.float32)
        sum_sq += tl.sum(x * x, axis=0)
    # Reciprocal RMS, still in fp32.
    rrms = 1.0 / tl.sqrt(sum_sq / N + eps)

    # Pass 2: reload the row (it is not kept from pass 1), normalize, scale.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_row + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        # The normalized value is cast to the output element type BEFORE the
        # weight multiply; the product is then stored into out_ptr.
        y = (x * rrms).to(out_ptr.dtype.element_ty) * w
        tl.store(out_row + cols, y, mask=mask)
# Flipped to True after the first warm-up attempt, whether it succeeded or not.
_PREWARM_DONE = False


def _prewarm():
    """Launch the RMSNorm kernel once for a few row widths, one time only.

    Presumably this triggers Triton JIT compilation before the first real
    forward call. The widths match the head_dim (128) and hidden sizes
    mentioned in the module docstring; 2560 is assumed to be another
    hidden size (TODO confirm).

    Best-effort: any exception is logged at DEBUG level and swallowed, and
    the warm-up is never retried.
    """
    global _PREWARM_DONE
    if not _PREWARM_DONE:
        try:
            for width in (128, 2048, 2560):
                src = torch.ones((1, width), dtype=torch.bfloat16)
                gamma = torch.ones(width, dtype=torch.bfloat16)
                dst = torch.empty_like(src)
                launch = _rms_norm_kernel[(1,)]
                launch(
                    src,
                    gamma,
                    dst,
                    width,
                    width,
                    1e-6,
                    BLOCK_SIZE=_TILE,
                    num_warps=1,
                    num_stages=1,
                )
        except Exception:
            logger.debug("rmsnorm prewarm failed", exc_info=True)
        _PREWARM_DONE = True
def _rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Standalone RMSNorm over the last dim of ``x`` using the Triton kernel.

    Supports any leading dim shape: all leading dims are flattened into rows
    and the kernel is launched with one program per row.

    Args:
        x: input tensor; its last dim must equal ``weight.shape[0]``.
        weight: 1-D per-column scale, read by the kernel as a dense vector.
        eps: added to the mean of squares before the square root.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``x.shape[-1]`` does not match ``weight.shape[0]``.
    """
    N = weight.shape[0]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.shape[-1] != N:
        raise ValueError(
            f"_rms_norm: last dim of x ({x.shape[-1]}) must match weight size ({N})"
        )
    # Nothing to normalize. This also avoids the `numel // N` division when
    # N == 0 and an empty-grid kernel launch.
    if x.numel() == 0:
        return x.new_empty(x.shape)

    _prewarm()
    M = x.numel() // N
    x_2d = x.reshape(M, N).contiguous()
    # The kernel indexes weight as w_ptr + col; contiguous() is a no-op
    # (returns the same tensor) when it is already dense.
    weight = weight.contiguous()
    out = torch.empty_like(x_2d)
    _rms_norm_kernel[(M,)](
        x_2d,
        weight,
        out,
        N,  # stride_r: rows of the contiguous (M, N) view are N apart
        N,
        eps,
        BLOCK_SIZE=_TILE,
        num_warps=1,
        num_stages=1,
    )
    return out.reshape(x.shape)
def _make_patched_forward(orig):
    """Build a replacement ``forward`` that routes BF16 work to the Triton kernel.

    Args:
        orig: the class's original ``forward``; it is called as
            ``orig(self, hidden_states)`` for every input the fast path does
            not handle.

    Returns:
        A function with the same ``(self, hidden_states)`` signature as ``orig``.
        It reads ``self.weight`` and ``self.variance_epsilon`` on the fast path.
    """

    def patched(self, hidden_states):
        # Fast path: BF16 activations AND a BF16 weight. _rms_norm always
        # returns the input's dtype, so a weight in another dtype (e.g. fp32)
        # must go through the eager path. NOTE(review): presumably the eager
        # forward promotes to the weight's dtype there; verify against the
        # installed transformers version.
        fused_ok = (
            hidden_states.dtype == torch.bfloat16
            and self.weight.dtype == torch.bfloat16
        )
        if not fused_ok:
            return orig(self, hidden_states)
        return _rms_norm(hidden_states, self.weight, self.variance_epsilon)

    return patched
def patch_qwen3_rmsnorm() -> int:
    """Replace ``Qwen3RMSNorm.forward`` with the fused-kernel forward.

    Targets regular Qwen3 only. Qwen3.5 has different math
    (``output * (1.0 + weight)``, and the attribute is ``eps``, not
    ``variance_epsilon``); use patch_qwen3_5_rmsnorm.py for that.

    Target modules that fail to import, or that do not define
    ``Qwen3RMSNorm``, are skipped. Classes already recorded in ``_PATCHED``
    are left alone, so repeated calls are harmless.

    Returns:
        The number of classes patched by this call.
    """
    import importlib

    # Only the Qwen3 class is ever selected. The forward built here implements
    # the Qwen3 math, so a Qwen3.5 class name must never be routed to it.
    cls_name = "Qwen3RMSNorm"
    targets = ("transformers.models.qwen3.modeling_qwen3",)
    n = 0
    for modname in targets:
        try:
            mod = importlib.import_module(modname)
        except (ImportError, AttributeError):
            continue
        cls = getattr(mod, cls_name, None)
        if cls is None:
            continue
        key = (modname, cls_name)
        if key in _PATCHED:
            continue
        orig = cls.forward
        # Record the original before swapping so unpatch can restore it.
        _PATCHED[key] = (cls, orig)
        cls.forward = _make_patched_forward(orig)
        n += 1
        logger.info("Patched %s.%s.forward", modname, cls_name)
    return n
def unpatch_qwen3_rmsnorm() -> int:
    """Undo every forward replacement recorded in ``_PATCHED``.

    Entries are processed in insertion order. Each recorded class gets its
    original ``forward`` back, and the entry is then removed from the
    registry.

    Returns:
        How many classes were restored.
    """
    pending = list(_PATCHED.items())
    for key, (owner, original_forward) in pending:
        setattr(owner, "forward", original_forward)
        _PATCHED.pop(key)
    return len(pending)
146 return n