Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_rmsnorm.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

"""Patch Qwen3RMSNorm.forward to use a single Triton kernel call,
replacing the 5-6 ATen dispatches in the eager implementation.

Each decode token of a 28-layer Qwen3 model issues ~113 RMSNorm calls:
  - 28 input_layernorm (hidden_size, e.g. 2048)
  - 1 final norm (hidden_size)
  - 28 post_attention_layernorm (already fused via patch_qwen3_layer_norm)
  - 28 q_norm (head_dim=128, batched across heads)
  - 28 k_norm (head_dim=128, batched across heads)

This patch targets the input/q/k/final norms (28 + 28 + 28 + 1 = ~85 calls/token).
"""

13import logging 

14 

15import torch 

16import triton 

17import triton.language as tl 

18 

19from flag_gems.utils import triton_lang_extension as tle 

20 

21logger = logging.getLogger(__name__) 

22 

23_TILE = 128 

24_PATCHED: dict = {} 

25 

26 

# Row-wise RMSNorm kernel: each program instance normalizes one row of N values.
#   out[r, :] = (x[r, :] * 1 / sqrt(mean(x[r, :] ** 2) + eps)).to(out dtype) * w
# Arguments:
#   x_ptr / out_ptr: row r starts at offset r * stride_r in BOTH buffers (a single
#     stride is shared by input and output); columns are addressed with unit stride.
#   w_ptr: N weight values, addressed with unit stride.
#   stride_r: distance, in elements, between consecutive rows.
#   N: number of columns per row.
#   eps: added to the mean square; excluded from Triton argument specialization
#     via do_not_specialize.
#   BLOCK_SIZE: columns processed per loop iteration (compile-time constant).
# Launch with a 1-D grid of one program per row.
@triton.jit(do_not_specialize=["eps"])
def _rms_norm_kernel(
    x_ptr,
    w_ptr,
    out_ptr,
    stride_r,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    # tle.program_id is a flag_gems helper; presumably tl.program_id possibly
    # widened to int64 to keep row offsets from overflowing -- confirm in tle.
    pid = tle.program_id(0)
    x_row = x_ptr + pid * stride_r
    out_row = out_ptr + pid * stride_r

    # Pass 1: accumulate the sum of squares in fp32, BLOCK_SIZE columns at a time.
    # Lanes past N are masked and load 0.0, so they contribute nothing.
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_row + cols, mask=mask, other=0.0).to(tl.float32)
        sum_sq += tl.sum(x * x, axis=0)
    # Reciprocal RMS of the row (shape [1], broadcast over each tile below).
    rrms = 1.0 / tl.sqrt(sum_sq / N + eps)

    # Pass 2: re-read the row, normalize in fp32, cast to the output element type
    # BEFORE multiplying by the weight, then store. The cast placement affects
    # rounding; presumably chosen to mirror the eager Qwen3RMSNorm -- confirm.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_row + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        y = (x * rrms).to(out_ptr.dtype.element_ty) * w
        tl.store(out_row + cols, y, mask=mask)

56 

57 

# Set once _prewarm has run (successfully or not) so it never runs twice.
_PREWARM_DONE = False


def _prewarm():
    """Launch the RMSNorm kernel once per common row width, at most once per process.

    Each launch uses a tiny (1, width) bfloat16 tensor on the default device, for
    width in (128, 2048, 2560). The point is that the Triton JIT compilation
    presumably happens here instead of on the first real call. Any exception is
    logged at DEBUG level and swallowed. The done-flag is set either way, so a
    failed attempt is not retried.
    """
    global _PREWARM_DONE
    if not _PREWARM_DONE:
        try:
            for width in (128, 2048, 2560):
                src = torch.ones((1, width), dtype=torch.bfloat16)
                gamma = torch.ones(width, dtype=torch.bfloat16)
                dst = torch.empty_like(src)
                launch = _rms_norm_kernel[(1,)]
                # Row stride equals the row width because each row is contiguous.
                launch(
                    src,
                    gamma,
                    dst,
                    width,
                    width,
                    1e-6,
                    BLOCK_SIZE=_TILE,
                    num_warps=1,
                    num_stages=1,
                )
        except Exception:
            logger.debug("rmsnorm prewarm failed", exc_info=True)
        _PREWARM_DONE = True

76 

77 

78def _rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: 

79 """Standalone RMSNorm. Operates on last dim. Supports any leading dim shape.""" 

80 _prewarm() 

81 N = weight.shape[0] 

82 assert x.shape[-1] == N 

83 M = x.numel() // N 

84 x_2d = x.reshape(M, N).contiguous() 

85 out = torch.empty_like(x_2d) 

86 _rms_norm_kernel[(M,)]( 

87 x_2d, 

88 weight, 

89 out, 

90 N, # stride_r 

91 N, 

92 eps, 

93 BLOCK_SIZE=_TILE, 

94 num_warps=1, 

95 num_stages=1, 

96 ) 

97 return out.reshape(x.shape) 

98 

99 

100def _make_patched_forward(orig): 

101 def patched(self, hidden_states): 

102 # Fast path: BF16 only 

103 if hidden_states.dtype != torch.bfloat16: 

104 return orig(self, hidden_states) 

105 return _rms_norm(hidden_states, self.weight, self.variance_epsilon) 

106 

107 return patched 

108 

109 

def patch_qwen3_rmsnorm() -> int:
    """Replace ``Qwen3RMSNorm.forward`` with the Triton-backed fast path.

    Targets regular Qwen3 only. Qwen3.5 has different math
    (``output * (1.0 + weight)``, and the attribute is ``eps`` rather than
    ``variance_epsilon``); use patch_qwen3_5_rmsnorm.py for that.

    The replacement is assigned on the class, so it affects every instance.
    Classes already recorded in the patch registry are skipped, which makes
    repeated calls no-ops.

    Returns:
        The number of classes newly patched by this call. The result is 0 when
        the target module cannot be imported or everything is already patched.
    """
    import importlib

    targets = [
        "transformers.models.qwen3.modeling_qwen3",
    ]
    n = 0
    for modname in targets:
        try:
            mod = importlib.import_module(modname)
        except (ImportError, AttributeError):
            continue
        cls_name = "Qwen3RMSNorm" if "qwen3_5" not in modname else "Qwen3_5RMSNorm"
        if not hasattr(mod, cls_name):
            cls_name = "Qwen3RMSNorm"
        if not hasattr(mod, cls_name):
            continue
        cls = getattr(mod, cls_name)
        key = (modname, cls_name)
        if key in _PATCHED:
            continue
        # Two target modules may expose the same class object (a re-export).
        # Wrapping it twice would capture the already-patched forward as
        # "orig" and stack wrappers, so patch each class at most once.
        if any(patched_cls is cls for patched_cls, _ in _PATCHED.values()):
            continue
        orig = cls.forward
        _PATCHED[key] = (cls, orig)
        cls.forward = _make_patched_forward(orig)
        n += 1
        logger.info("Patched %s.%s.forward", modname, cls_name)
    return n

138 

139 

def unpatch_qwen3_rmsnorm() -> int:
    """Restore every ``forward`` recorded by ``patch_qwen3_rmsnorm`` and empty the registry.

    Entries are restored newest-first. If the same class was ever wrapped more
    than once, this order guarantees that the forward captured first (the
    genuine original) is the one left in place. Oldest-first would end by
    reinstalling an intermediate patched forward.

    Returns:
        The number of registry entries restored.
    """
    n = 0
    for key in reversed(list(_PATCHED)):
        cls, orig = _PATCHED.pop(key)
        cls.forward = orig
        n += 1
    return n