Coverage for src/flag_gems/ops/rms_norm.py: 33%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

# Module-level logger; the forward and backward entry points below emit
# debug traces through it.
logger = logging.getLogger(name=__name__)

14 

15 

@triton.jit
def prev_multiple_of(a, b):
    # Start offset of the last (possibly partial) tile of width `b` covering
    # `a` elements: (ceil(a / b) - 1) * b.
    return (tl.cdiv(a, b) - 1) * b

19 

20 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms (one value per row)
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,  # output stride between rows
    y_stride_c,  # output stride between columns
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,  # must be >= N; the whole row is one block
):
    """RMSNorm forward for one row, processed as a single block.

    Computes ``y = x * rsqrt(mean(x^2) + eps) * w`` and stores the per-row
    inverse RMS into ``INV_RMS[pid]``. Half-precision inputs are computed in
    float32; other dtypes are computed in their own dtype.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    # Use ext.program_id like the other kernels in this file. The row offset
    # is pid * stride, which can exceed the int32 range for large inputs when
    # formed from the raw 32-bit program id. NOTE(review): presumably
    # ext.program_id widens the index -- confirm against triton_lang_extension.
    pid = ext.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)

    # Masked lanes load 0.0, so they do not contribute to the sum of squares.
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    # The weight is indexed directly by column (the caller passes it contiguous).
    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)

58 

59 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_loop_kernel(
    out_ptr,  # pointer to the output (rows addressed as pid * N)
    INV_RMS,  # pointer to inverse rms (one value per row)
    in_ptr,  # pointer to the input (rows addressed as pid * N)
    w_ptr,  # pointer to the weights (N elements)
    N,  # number of columns per row
    eps,  # epsilon to avoid division by zero
    TILE_N: tl.constexpr,  # tile width, picked by autotuning
):
    """RMSNorm forward for one row, tiling over the columns.

    Used for rows too long to handle as a single block. Rows are addressed
    as ``pid * N``, so input and output must be contiguous row-major buffers.
    Pass 1 reduces ``sum(x^2)`` and stores the inverse RMS; pass 2 writes
    ``x * rrms * w``.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = ext.program_id(0)

    # Pass 1: compute sum(x^2) in chunks. Accumulate in the compute dtype
    # (float32 for half types, the input dtype otherwise) so float64 inputs
    # keep full precision, matching rms_norm_kernel.
    acc = tl.zeros((TILE_N,), dtype=cdtype)
    num_steps = tl.cdiv(N, TILE_N)

    # Every tile except the last one is full, so no mask is needed.
    for step in range(0, num_steps - 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(cdtype)
        acc += x * x

    # Last (possibly partial) tile, masked; padded lanes load 0.0.
    start_n = (num_steps - 1) * TILE_N
    n_offsets = start_n + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x = tl.load(in_ptr + pid * N + n_offsets, mask=mask, other=0.0).to(cdtype)
    acc += x * x

    var = tl.sum(acc) / N
    rrms = 1 / tl.sqrt(var + eps)
    tl.store(INV_RMS + pid, rrms)

    # Pass 2: normalize in reverse order. Presumably the tail of the row,
    # read last in pass 1, is the most likely to still be in cache.
    prev_multiple = prev_multiple_of(N, TILE_N)

    # First reverse step: the last tile, which may be partial -> masked.
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(cdtype)
        w = tl.load(w_ptr + n_offsets, mask=mask, other=0.0)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + pid * N + n_offsets, y, mask=mask)

    # Remaining tiles, walking backwards down to offset 0; all are full.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        w = tl.load(w_ptr + n_offsets)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + pid * N + n_offsets, y)

131 

132 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # input base pointer
    DY,  # upstream-gradient base pointer (addressed with X's strides)
    INV_RMS,  # per-row inverse RMS saved by the forward pass
    DX,  # output-gradient base pointer
    W,  # weight pointer
    dx_stride_r,  # DX stride between rows
    dx_stride_c,  # DX stride between columns
    x_stride_r,  # X / DY stride between rows
    x_stride_c,  # X / DY stride between columns
    N,  # number of columns per row
    eps,  # not read here: the inverse RMS arrives precomputed
    BLOCK_SIZE: tl.constexpr,  # must be >= N; the whole row is one block
):
    """Input gradient of RMSNorm for one row.

    With g = dy * w and x_hat = x * rrms, this computes
    dx = (g - (x_hat / N) * sum(x_hat * g)) * rrms, accumulating in float32.
    """
    row = ext.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    in_row = offs < N

    x_row = X + row * x_stride_r
    dy_row = DY + row * x_stride_r
    dx_row = DX + row * dx_stride_r

    x_val = tl.load(x_row + offs * x_stride_c, in_row, other=0.0).to(tl.float32)
    rrms = tl.load(INV_RMS + row).to(tl.float32)
    grad = tl.load(dy_row + offs * x_stride_c, in_row, other=0.0).to(tl.float32)
    w_val = tl.load(W + offs, mask=in_row, other=0.0)

    grad = grad * w_val
    x_hat = x_val * rrms
    # Padded lanes hold 0.0, so they contribute nothing to this reduction.
    dot = tl.sum(x_hat * grad, axis=0)
    x_hat_over_n = x_hat / N
    grad_in = (grad - x_hat_over_n * dot) * rrms

    tl.store(dx_row + offs * dx_stride_c, grad_in, mask=in_row)

171 

172 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to inverse rms (one value per row)
    DW,  # pointer to the partial-sum buffer, addressed as row_pid * N + col
    dx_stride_r,  # not read by this kernel
    dx_stride_c,  # not read by this kernel
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient for one (row block, column block) tile.

    Sums ``x * dy * inv_rms`` over up to ROW_BLOCK_SIZE rows and writes the
    COL_BLOCK_SIZE column partials into row ``row_pid`` of ``DW``. The caller
    reduces those partials over the row-block axis.
    """
    # The row axis uses ext.program_id like the other kernels in this file,
    # because row_start * x_stride_r can exceed the int32 range for large
    # inputs. NOTE(review): presumably ext.program_id widens the index --
    # confirm. Column offsets stay small, so plain tl.program_id is kept.
    row_pid = ext.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N

    x = tl.load(
        X + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(
        DY + rows[:, None] * x_stride_r + cols[None, :] * x_stride_c,
        row_mask[:, None] & col_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    d_weight = x * dy * inv_rms[:, None]
    # Out-of-range rows/cols were loaded as 0.0, so they add nothing to the
    # row-axis sum.
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )

228 

229 

def rms_norm_out(result, x, normalized_shape, weight, eps=1e-5):
    """Out-variant of RMSNorm: write the normalized ``x`` into ``result``.

    The inverse-RMS tensor produced by the forward pass is discarded.
    Returns ``result`` itself (``Tensor.copy_`` returns its receiver).
    """
    normalized = rms_norm_forward(x, normalized_shape, weight, eps=eps)[0]
    return result.copy_(normalized)

234 

235 

def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """RMSNorm forward over the trailing ``normalized_shape`` dimensions.

    ``x`` is viewed as ``rows x cols``, where ``cols`` is the product of
    ``normalized_shape`` and ``rows`` is the product of the leading dims.

    Returns:
        ``(y, inv_rms)``: ``y`` is a contiguous tensor shaped and typed like
        ``x``; ``inv_rms`` is a float32 tensor with one entry per row.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    lead_dims = x.ndim - len(normalized_shape)
    n_rows = math.prod(x.shape[:lead_dims])
    n_cols = math.prod(normalized_shape)

    # Both kernels index rows densely, so make the operands contiguous.
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((n_rows,), device=x.device, dtype=torch.float32)

    grid = (n_rows,)
    with torch_device_fn.device(x.device):
        if n_cols > 4096:
            # Long rows: autotuned kernel that tiles across each row.
            rms_norm_loop_kernel[grid](y, inv_rms, x, weight, n_cols, eps)
        else:
            # Short rows: one power-of-two block covers the whole row.
            block = triton.next_power_of_2(n_cols)
            rms_norm_kernel[grid](
                y, inv_rms, x, weight, n_cols, 1, n_cols, 1, n_cols, eps, block
            )

    return y, inv_rms

255 

256 

def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Gradients of RMSNorm with respect to the input and the weight.

    Args:
        dy: upstream gradient, same shape as ``x``.
        x: the forward input.
        inv_rms: per-row inverse RMS from ``rms_norm_forward`` (one entry per row).
        normalized_shape: the trailing dims that were normalized.
        weight: the scale parameter; ``dw`` takes its shape and dtype.
        eps: forwarded to the dx kernel, which does not read it.

    Returns:
        ``(dx, dw)``: ``dx`` is shaped like ``x``; ``dw`` has ``weight``'s
        shape and dtype, as autograd requires for a parameter gradient.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    # The dx kernel handles a whole row as one block. NOTE(review): unlike
    # the forward pass there is no tiled variant for large N.
    BLOCK_SIZE = triton.next_power_of_2(N)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[M,](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )

    # dw is reduced in two stages. The kernel writes one float32 row of
    # partial sums per block of ROW_BLOCK_SIZE input rows; torch.sum then
    # reduces over those rows.
    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )
    # Match the weight's dtype (not x's, which may be lower precision under
    # mixed precision) and its shape. A flat [N] gradient is rejected by
    # autograd when normalized_shape has more than one dimension.
    dw = (
        torch.sum(partial_buffer, dim=0, dtype=torch.float32)
        .to(weight.dtype)
        .reshape(weight.shape)
    )

    return dx, dw

305 

306 

class RmsNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton RMSNorm forward/backward passes.

    Inputs to ``apply`` are ``(x, normalized_shape, weight, eps)``. Only
    ``x`` and ``weight`` receive gradients.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        # Non-tensor configuration lives on ctx; tensors go through
        # save_for_backward.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        y, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        ctx.save_for_backward(x, inv_rms, weight)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, inv_rms, weight = ctx.saved_tensors
        grads = rms_norm_backward(
            dy, x, inv_rms, ctx.normalized_shape, weight, ctx.eps
        )
        dx, dw = grads
        # One slot per forward input; normalized_shape and eps get None.
        return dx, None, dw, None

324 

325 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Autograd-aware RMSNorm entry point; delegates to ``RmsNorm.apply``."""
    output = RmsNorm.apply(x, normalized_shape, weight, eps)
    return output