Coverage for src/flag_gems/runtime/backend/_sunrise/ops/rms_norm.py: 0%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Single-pass RMS norm: each program normalizes one row in one block.

    The program id selects the row. The whole row is loaded with a single
    BLOCK_SIZE-wide masked load, so the launcher is expected to pass
    BLOCK_SIZE >= N. Writes the scaled row to ``out_ptr`` and the row's
    inverse RMS to ``INV_RMS[row]``.
    """
    # fp16 / bf16 inputs are computed in float32; every other element type
    # is computed in its own precision.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    row = tl.program_id(0)
    out_ptr += row * y_stride_r
    in_ptr += row * x_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < N
    # Lanes past N read 0.0 so they do not contribute to the sum of squares.
    x = tl.load(in_ptr + cols * x_stride_c, in_bounds, other=0.0).to(cdtype)

    mean_sq = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)

    # The weight vector is addressed with unit stride.
    w = tl.load(w_ptr + cols, mask=in_bounds, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=in_bounds)
    tl.store(INV_RMS + row, rrms)

52 

53 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_2d_kernel(
    out_ptr,
    INV_RMS,
    in_ptr,
    w_ptr,
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tiled RMS norm: each program normalizes a TILE_M x BLOCK_N tile of rows.

    Input and output are addressed as ``row * N + col``, i.e. rows are taken
    to be packed back to back with a pitch of N elements. Each row is reduced
    with a single BLOCK_N-wide masked load. One inverse RMS per valid row is
    written to ``INV_RMS``.
    """
    # fp16 / bf16 inputs are computed in float32; every other element type
    # is computed in its own precision.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = tl.program_id(0)
    rows = pid * TILE_M + tl.arange(0, TILE_M)
    cols = tl.arange(0, BLOCK_N)
    row_ok = rows < M
    col_ok = cols < N
    tile_ok = row_ok[:, None] & col_ok[None, :]
    tile_offs = rows[:, None] * N + cols[None, :]

    # Out-of-range elements read 0.0 and so add nothing to the row sums.
    x = tl.load(in_ptr + tile_offs, tile_ok, other=0.0).to(cdtype)
    mean_sq = tl.sum(x * x, axis=1) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)

    w = tl.load(w_ptr + cols, mask=col_ok, other=0.0)
    y = (x * rrms[:, None] * w[None, :]).to(cdtype)
    tl.store(out_ptr + tile_offs, y, mask=tile_ok)
    tl.store(INV_RMS + rows, rrms, mask=row_ok)

90 

91 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_c_split_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Two-pass RMS norm for rows wider than one block.

    Each program handles one row. The first pass walks the row in
    BLOCK_SIZE-wide chunks and accumulates the sum of squares; the second
    pass re-reads the row, scales it by the inverse RMS and the weight, and
    stores the result. The row's inverse RMS is written to ``INV_RMS[row]``.
    """
    # fp16 / bf16 inputs are computed in float32; every other element type
    # is computed in its own precision.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = tl.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    # The accumulator must have the same dtype as the values added to it:
    # a loop-carried variable may not change type inside the loop. Using
    # cdtype (instead of a hard-coded float32) keeps this kernel compilable
    # when cdtype is not float32, and is identical for fp16/bf16/fp32 inputs.
    sq_acc = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Lanes past N read 0.0 and add nothing to the sum of squares.
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        sq_acc += x * x

    var = tl.sum(sq_acc, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # The weight vector is addressed with unit stride.
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)

136 

137 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Input gradient of RMS norm; each program handles one row.

    With ``x_hat = x * inv_rms`` and ``g = dy * w`` the kernel computes

        dx = (g - x_hat * sum(x_hat * g) / N) * inv_rms

    in two passes over the row: the first accumulates ``sum(x_hat * g)``,
    the second forms and stores ``dx``. All arithmetic is done in float32.
    ``DY`` is addressed with the same row/column strides as ``X``; ``W`` is
    addressed with unit stride. ``eps`` is accepted but not used here, since
    the inverse RMS is read from ``INV_RMS``.
    """
    pid = ext.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: row-wide reduction of x_hat * (dy * w).
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Honor the column stride on loads, matching the strided DX store
        # below; with a column stride of 1 this is the same address.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Pass 2: recompute the per-element terms and write dx.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)

184 

185 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient of RMS norm over one row-block x col-block tile.

    Launched on a 2D grid (row blocks, column blocks). Each program computes
    ``sum_over_rows(x * dy * inv_rms)`` for its tile in float32 and writes
    the per-column partial sums to row ``row_pid`` of ``DW``, which is
    addressed as ``row_pid * N + col``. The partial rows still have to be
    summed by the caller. ``DY`` is addressed with the same strides as ``X``;
    ``dx_stride_r`` / ``dx_stride_c`` are accepted but not used.
    """
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    # Move the base pointers to the top-left corner of this tile.
    tile_base = row_start * x_stride_r + col_start * x_stride_c
    X += tile_base
    DY += tile_base
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N
    tile_mask = row_mask[:, None] & col_mask[None, :]
    tile_offs = rows[:, None] * x_stride_r + cols[None, :] * x_stride_c

    x = tl.load(X + tile_offs, tile_mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + tile_offs, tile_mask, other=0.0).to(tl.float32)

    d_weight = x * dy * inv_rms[:, None]
    # Masked-out rows were loaded as 0.0, so they add nothing to the
    # reduction over the row axis.
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )

241 

242 

def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Run the RMS norm forward pass and return ``(y, inv_rms)``.

    The trailing ``len(normalized_shape)`` dimensions of ``x`` are flattened
    into N columns and the leading dimensions into M rows. ``x`` and
    ``weight`` are made contiguous before launch.

    Returns:
        y: tensor shaped like ``x`` holding the normalized, weighted output.
        inv_rms: float32 tensor of shape ``(M,)`` with one inverse RMS per row.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    lead_ndim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:lead_ndim])
    N = math.prod(normalized_shape)

    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    BLOCK_SIZE = triton.next_power_of_2(N)

    with torch_device_fn.device(x.device):
        if BLOCK_SIZE > 1024:
            # Row does not fit in one block: sweep it in 1024-wide chunks.
            BLOCK_SIZE = 1024
            rms_norm_c_split_kernel[(M,)](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE, num_warps=16
            )
        elif BLOCK_SIZE > 512:
            # One program per row, whole row in a single block.
            rms_norm_kernel[(M,)](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        else:
            # [Sunrise] the 2d load path is used for block sizes up to 512;
            # several rows are packed into each program.
            TILE_M = triton.cdiv(1024, BLOCK_SIZE)
            grid = (triton.cdiv(M, TILE_M),)
            rms_norm_2d_kernel[grid](
                y, inv_rms, x, weight, M, N, eps, TILE_M, BLOCK_SIZE
            )
    return y, inv_rms

270 

271 

def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Run the RMS norm backward pass and return ``(dx, dw)``.

    The trailing ``len(normalized_shape)`` dimensions of ``x`` are flattened
    into N columns and the leading dimensions into M rows. ``x``, ``dy`` and
    ``weight`` are made contiguous before launch.

    Args:
        dy: gradient of the output, laid out like ``x``.
        x: forward input.
        inv_rms: per-row inverse RMS produced by the forward pass.
        normalized_shape: the trailing dimensions that were normalized.
        weight: scale applied in the forward pass.
        eps: forwarded to the dx kernel.

    Returns:
        dx: gradient w.r.t. ``x``, shaped like ``x``.
        dw: gradient w.r.t. ``weight``, cast to ``x.dtype`` and shaped like
            ``weight``.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    BLOCK_SIZE = min(triton.next_power_of_2(N), 1024)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[(M,)](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )

    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    # One float32 row of partial column sums per row block; reduced below.
    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )
        # Give dw the shape of the weight rather than a flat (N,) vector, so
        # a multi-dimensional normalized_shape yields a gradient that matches
        # its parameter. For 1-D weights this is the same as reshape(-1).
        dw = (
            torch.sum(partial_buffer, dim=0, dtype=torch.float32)
            .to(x.dtype)
            .reshape(weight.shape)
        )

    return dx, dw

320 

321 

class RmsNorm(torch.autograd.Function):
    """Autograd wrapper tying the RMS norm forward and backward launchers."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        """Compute the normalized output and stash what backward needs."""
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # Non-tensor settings go on ctx as plain attributes; tensors go
        # through save_for_backward.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        ctx.save_for_backward(x, inv_rms, weight)
        return out

    @staticmethod
    def backward(ctx, dy):
        """Return gradients for (x, normalized_shape, weight, eps)."""
        x, inv_rms, weight = ctx.saved_tensors
        dx, dw = rms_norm_backward(
            dy, x, inv_rms, ctx.normalized_shape, weight, ctx.eps
        )
        # normalized_shape and eps are not differentiable.
        return dx, None, dw, None

339 

340 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalization to ``x`` through the ``RmsNorm`` autograd function."""
    y = RmsNorm.apply(x, normalized_shape, weight, eps)
    return y