Coverage for src/flag_gems/fused/add_rms_norm.py: 38%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.jit
def prev_multiple_of(a, b):
    """Return the start offset of the last ``b``-wide tile covering ``[0, a)``.

    The value is ``(ceil(a / b) - 1) * b``. For ``a > 0`` this is the largest
    multiple of ``b`` strictly below ``a``: ``a - b`` when ``b`` divides ``a``,
    otherwise the start of the trailing partial tile.
    """
    tile_count = tl.cdiv(a, b)
    return (tile_count - 1) * b

20 

21 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def add_rms_norm_kernel(
    out_ptr,
    in_ptr1,
    in_ptr2,
    w_ptr,
    y_stride_r,
    y_stride_c,
    x1_stride_r,
    x1_stride_c,
    x2_stride_r,
    x2_stride_c,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``(x1 + x2) * rsqrt(mean((x1 + x2) ** 2) + eps) * w`` per row.

    One program handles one whole row of ``N`` columns in a single
    ``BLOCK_SIZE``-wide tile, so the caller must choose ``BLOCK_SIZE >= N``;
    lanes with ``col >= N`` are masked off.

    Args:
        out_ptr: Output base pointer; row ``pid`` starts at ``pid * y_stride_r``.
        in_ptr1: First summand, addressed with ``x1_stride_r`` / ``x1_stride_c``.
        in_ptr2: Second summand, addressed with ``x2_stride_r`` / ``x2_stride_c``.
        w_ptr: Per-column weight, read with unit stride (``w_ptr + col``).
        y_stride_r, y_stride_c: Row / column strides of the output.
        x1_stride_r, x1_stride_c: Row / column strides of ``in_ptr1``.
        x2_stride_r, x2_stride_c: Row / column strides of ``in_ptr2``.
        N: Number of columns in a normalized row.
        eps: Added to the mean square before the square root.
        BLOCK_SIZE: Tile width (compile-time constant).
    """
    # Half-precision inputs are upcast to fp32 for the arithmetic; any other
    # dtype is computed in its own precision.
    if tl.constexpr(in_ptr1.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr1.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr1.dtype.element_ty

    # Use tle.program_id, as add_rms_norm_loop_kernel does. It is assumed
    # (flag_gems triton_lang_extension) to yield an int64 id, so the
    # ``pid * stride`` row offsets below are not computed in int32 and cannot
    # overflow for tensors with >= 2**31 elements -- confirm if tle changes.
    pid = tle.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr1 += pid * x1_stride_r
    in_ptr2 += pid * x2_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Masked lanes load 0.0, so they add nothing to the sum of squares.
    x1 = tl.load(in_ptr1 + cols * x1_stride_c, mask, other=0.0).to(cdtype)
    x2 = tl.load(in_ptr2 + cols * x2_stride_c, mask, other=0.0).to(cdtype)

    # Add the two inputs
    x = x1 + x2

    # Mean of squares over the N real columns (not over BLOCK_SIZE lanes).
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)

65 

66 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("add_rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def add_rms_norm_loop_kernel(
    out_ptr,
    in_ptr1,
    in_ptr2,
    w_ptr,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    """Tiled ``(x1 + x2) * rsqrt(mean((x1 + x2) ** 2) + eps) * w`` per row.

    For rows too long for one tile. One program handles one row in two
    passes over ``TILE_N``-wide tiles:

    1. accumulate the sum of squares of ``x1 + x2``;
    2. re-read the inputs, normalize, scale by ``w`` and store.

    Rows are addressed at ``pid * N`` with unit column stride, so all three
    tensors must use a contiguous row-major ``(rows, N)`` layout.

    Args:
        out_ptr: Output base pointer.
        in_ptr1: First summand base pointer.
        in_ptr2: Second summand base pointer.
        w_ptr: Per-column weight (``N`` elements, unit stride).
        N: Number of columns in a normalized row; also the row pitch.
        eps: Added to the mean square before the square root.
        TILE_N: Tile width, chosen by the autotune configs above.
    """
    # Half-precision inputs are upcast to fp32; other dtypes are computed (and
    # now also reduced) in their own precision, matching add_rms_norm_kernel.
    if tl.constexpr(in_ptr1.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr1.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr1.dtype.element_ty

    pid = tle.program_id(0)
    row_in1 = in_ptr1 + pid * N
    row_in2 = in_ptr2 + pid * N
    row_out = out_ptr + pid * N

    # Pass 1: compute sum(x^2) in chunks. The accumulator uses cdtype rather
    # than a fixed fp32, so fp64 inputs keep full precision in the variance.
    acc = tl.zeros((TILE_N,), dtype=cdtype)
    num_steps = tl.cdiv(N, TILE_N)

    # All tiles except the last are full, so they need no mask.
    for step in range(0, num_steps - 1):
        n_offsets = step * TILE_N + tl.arange(0, TILE_N)
        x1 = tl.load(row_in1 + n_offsets).to(cdtype)
        x2 = tl.load(row_in2 + n_offsets).to(cdtype)
        x = x1 + x2
        acc += x * x

    # The last tile may be partial; masked lanes load 0.0 and add nothing.
    n_offsets = (num_steps - 1) * TILE_N + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x1 = tl.load(row_in1 + n_offsets, mask=mask, other=0.0).to(cdtype)
    x2 = tl.load(row_in2 + n_offsets, mask=mask, other=0.0).to(cdtype)
    x = x1 + x2
    acc += x * x

    var = tl.sum(acc) / N
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: normalize, walking tiles from last to first. The tiles pass 1
    # touched most recently are revisited first, which is intended to improve
    # cache reuse. Inputs are read for the last time here, hence evict_first.
    prev_multiple = prev_multiple_of(N, TILE_N)

    # Last (possibly partial) tile first, with a mask.
    n_offsets = prev_multiple + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x1 = tl.load(
        row_in1 + n_offsets,
        mask=mask,
        other=0.0,
        eviction_policy="evict_first",
    ).to(cdtype)
    x2 = tl.load(
        row_in2 + n_offsets,
        mask=mask,
        other=0.0,
        eviction_policy="evict_first",
    ).to(cdtype)
    x = x1 + x2
    w = tl.load(w_ptr + n_offsets, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(row_out + n_offsets, y, mask=mask)

    # Remaining full tiles, from second-to-last down to offset 0, unmasked.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x1 = tl.load(
            row_in1 + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        x2 = tl.load(
            row_in2 + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        x = x1 + x2
        w = tl.load(w_ptr + n_offsets)
        y = (x * rrms * w).to(cdtype)
        tl.store(row_out + n_offsets, y)

153 

154 

def add_rms_norm(x1, x2, normalized_shape, weight, eps=1e-5):
    """
    Add_RMSNorm: Add two inputs element-wise and apply RMS normalization.

    Computes ``(x1 + x2) / sqrt(mean((x1 + x2) ** 2) + eps) * weight``, where
    the mean runs over the trailing ``normalized_shape`` dimensions.

    Args:
        x1: First input tensor.
        x2: Second input tensor; must have the same shape as ``x1``.
        normalized_shape: Trailing dimensions of ``x1`` to normalize over. A
            bare ``int`` is treated as a one-element shape.
        weight: Optional weight with ``prod(normalized_shape)`` elements.
            ``None`` means no scaling (a weight of ones).
        eps: Epsilon value for numerical stability.

    Returns:
        A new tensor with the shape and dtype of ``x1``. The inputs are not
        modified.

    Raises:
        AssertionError: If ``x1`` and ``x2`` have different shapes.
        ValueError: If ``normalized_shape`` does not match the trailing
            dimensions of ``x1``, or ``weight`` has the wrong number of
            elements.
    """
    logger.debug(
        "GEMS ADD_RMS_NORM FORWARD, [input1 shape]: %s, [input2 shape]: %s, [weight shape]: %s",
        x1.size(),
        x2.size(),
        weight.size() if weight is not None else None,
    )
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    normalized_shape = tuple(normalized_shape)

    # Verify shapes match
    assert x1.shape == x2.shape, f"Input shapes must match: {x1.shape} vs {x2.shape}"

    # The kernels treat the input as an (M, N) matrix. A normalized_shape that
    # is not the true trailing shape would silently normalize the wrong spans.
    dim = x1.ndim - len(normalized_shape)
    if dim < 0 or tuple(x1.shape[dim:]) != normalized_shape:
        raise ValueError(
            f"normalized_shape {normalized_shape} does not match the trailing "
            f"dimensions of input with shape {tuple(x1.shape)}"
        )
    M = math.prod(x1.shape[:dim])
    N = math.prod(normalized_shape)

    if weight is None:
        # Optional weight: scale by ones, i.e. no affine scaling.
        weight = torch.ones(N, dtype=x1.dtype, device=x1.device)
    elif weight.numel() != N:
        raise ValueError(
            f"weight must have {N} elements to match normalized_shape "
            f"{normalized_shape}, got {weight.numel()}"
        )

    # Both kernels are launched with unit column strides and a row pitch of N.
    x1 = x1.contiguous()
    x2 = x2.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x1)

    # Nothing to compute. Launching with N == 0 would also give
    # next_power_of_2(0) == 0, an invalid tile width.
    if y.numel() == 0:
        return y

    with torch_device_fn.device(x1.device):
        if N <= 4096:
            # Whole row fits in a single tile.
            BLOCK_SIZE = triton.next_power_of_2(N)
            add_rms_norm_kernel[M,](
                y, x1, x2, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        else:
            # Long rows: autotuned two-pass tiled kernel.
            add_rms_norm_loop_kernel[M,](y, x1, x2, weight, N, eps)

    return y