Coverage for src/flag_gems/runtime/backend/_sunrise/ops/rms_norm.py: 0%

224 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

# Module-level logger; the forward/backward entry points emit debug traces through it.
logger = logging.getLogger(name=__name__)

14 

15 

@triton.jit
def prev_multiple_of(a, b):
    """Return the start offset of the last ``b``-wide tile covering ``[0, a)``.

    For positive ``a`` and ``b`` this is the largest multiple of ``b`` that is
    strictly less than ``a`` (e.g. a=10, b=4 -> 8; a=8, b=4 -> 4).
    """
    return (tl.cdiv(a, b) - 1) * b

19 

20 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Single-tile RMSNorm: each program normalizes one whole row.

    The row (``N`` columns, ``N <= BLOCK_SIZE``) is loaded in one masked tile,
    ``rrms = 1 / sqrt(mean(x^2) + eps)`` is computed, ``x * rrms * w`` is
    written to the output row and ``rrms`` to ``INV_RMS[row]``.
    """
    # fp16/bf16 inputs are promoted to fp32 for the reduction and scaling.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    x_row = in_ptr + row * x_stride_r
    y_row = out_ptr + row * y_stride_r

    x = tl.load(x_row + offs * x_stride_c, in_bounds, other=0.0).to(cdtype)
    mean_sq = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)

    w = tl.load(w_ptr + offs, mask=in_bounds, other=0.0)
    normed = (x * rrms * w).to(cdtype)
    tl.store(y_row + offs * y_stride_c, normed, mask=in_bounds)
    tl.store(INV_RMS + row, rrms)

58 

59 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_loop_kernel(
    out_ptr,
    INV_RMS,
    in_ptr,
    w_ptr,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    """Two-pass tiled RMSNorm, one contiguous row (row stride ``N``) per program.

    Pass 1 accumulates ``x^2`` in fp32 over ``TILE_N``-wide tiles (full tiles
    unmasked, the last tile masked) and stores ``rrms`` to ``INV_RMS[row]``.
    Pass 2 writes ``x * rrms * w`` walking the tiles from last to first, starting
    with the masked tail tile; inputs are loaded with ``evict_first``.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = ext.program_id(0)
    row_in = in_ptr + pid * N
    row_out = out_ptr + pid * N
    lanes = tl.arange(0, TILE_N)

    num_tiles = tl.cdiv(N, TILE_N)
    # Start of the last (possibly partial) tile; shared by both passes.
    tail_start = prev_multiple_of(N, TILE_N)
    tail_offs = tail_start + lanes
    tail_mask = tail_offs < N

    # ---- Pass 1: sum of squares -------------------------------------------
    sq_sum = tl.zeros((TILE_N,), dtype=tl.float32)
    for tile in range(0, num_tiles - 1):
        xa = tl.load(row_in + (tile * TILE_N + lanes)).to(tl.float32)
        sq_sum += xa * xa
    x_last = tl.load(row_in + tail_offs, mask=tail_mask, other=0.0).to(tl.float32)
    sq_sum += x_last * x_last

    rrms = 1 / tl.sqrt(tl.sum(sq_sum) / N + eps)
    tl.store(INV_RMS + pid, rrms)

    # ---- Pass 2: normalize, last tile first (better L2 cache reuse) --------
    xt = tl.load(
        row_in + tail_offs,
        mask=tail_mask,
        other=0.0,
        eviction_policy="evict_first",
    ).to(cdtype)
    wt = tl.load(w_ptr + tail_offs, mask=tail_mask, other=0.0)
    tl.store(row_out + tail_offs, (xt * rrms * wt).to(cdtype), mask=tail_mask)

    # Remaining full tiles, walking backwards from the tail.
    for back in range(TILE_N, N, TILE_N):
        offs = (tail_start - back) + lanes
        xb = tl.load(row_in + offs, eviction_policy="evict_first").to(cdtype)
        wb = tl.load(w_ptr + offs)
        tl.store(row_out + offs, (xb * rrms * wb).to(cdtype))

131 

132 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_2d_kernel(
    out_ptr,
    INV_RMS,
    in_ptr,
    w_ptr,
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Multi-row RMSNorm: each program handles ``TILE_M`` rows in one 2-D tile.

    The input is addressed as contiguous ``(M, N)`` (row stride ``N``, column
    stride 1) and all ``N`` columns must fit in ``BLOCK_N``. Writes
    ``x * rrms * w`` per row and ``rrms`` to ``INV_RMS`` for valid rows.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    rows = tl.program_id(0) * TILE_M + tl.arange(0, TILE_M)
    cols = tl.arange(0, BLOCK_N)
    row_ok = rows < M
    col_ok = cols < N
    tile_ok = row_ok[:, None] & col_ok[None, :]
    tile_offsets = rows[:, None] * N + cols[None, :]

    x = tl.load(in_ptr + tile_offsets, tile_ok, other=0.0).to(cdtype)
    # Reduce across columns: one rrms per row.
    rrms = 1 / tl.sqrt(tl.sum(x * x, axis=1) / N + eps)

    w = tl.load(w_ptr + cols, mask=col_ok, other=0.0)
    y = (x * rrms[:, None] * w[None, :]).to(cdtype)
    tl.store(out_ptr + tile_offsets, y, mask=tile_ok)
    tl.store(INV_RMS + rows, rrms, mask=row_ok)

169 

170 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_c_split_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Column-split RMSNorm for rows wider than one tile, one row per program.

    Pass 1 accumulates ``x^2`` chunk by chunk (``BLOCK_SIZE`` columns at a
    time). Pass 2 re-reads the row and writes ``x * rrms * w`` where
    ``rrms = 1 / sqrt(mean(x^2) + eps)``. ``rrms`` is stored to ``INV_RMS[row]``.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = tl.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    # The accumulator is loop-carried, so its dtype must match ``x * x``.
    # A hard-coded fp32 accumulator is promoted to fp64 inside the loop for
    # fp64 inputs, a type change Triton rejects at compile time. For
    # fp16/bf16/fp32 inputs cdtype is fp32, so those paths are unchanged.
    sq_sum = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        sq_sum += x * x

    var = tl.sum(sq_sum, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)

215 

216 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # unused here; the saved inverse rms is expected to already include it
    BLOCK_SIZE: tl.constexpr,
):
    """Input gradient of RMSNorm, one program per row, computed in fp32.

    With ``r = INV_RMS[row]`` and ``g = dy * w`` this writes
    ``dx = r * (g - (x * r) * sum(x * r * g) / N)``.
    Two passes over the row in ``BLOCK_SIZE``-wide chunks: the first
    accumulates ``sum(x * r * g)``, the second writes ``dx``.
    """
    pid = ext.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: row_sum_stats = sum((x * r) * (dy * w)) over the row.
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Loads honour x_stride_c the same way the store honours dx_stride_c.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Pass 2: dx = (dy * w - (x * r / N) * row_sum_stats) * r.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)

263 

264 

@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient of RMSNorm over a (row block, column block) tile.

    For its tile, each program sums ``x * dy * INV_RMS[row]`` over the rows in
    fp32 and stores the ``COL_BLOCK_SIZE`` partial sums at
    ``DW[row_block * N + col]``. The caller reduces these partials over row blocks.
    ``dx_stride_r``/``dx_stride_c`` are accepted but not used.
    """
    row_blk = tl.program_id(0)
    col_blk = tl.program_id(1)
    row0 = row_blk * ROW_BLOCK_SIZE
    col0 = col_blk * COL_BLOCK_SIZE

    # Move the X/DY base pointers to the tile origin (DY shares X's strides).
    base = row0 * x_stride_r + col0 * x_stride_c
    x_tile_ptr = X + base
    dy_tile_ptr = DY + base

    r = tl.arange(0, ROW_BLOCK_SIZE)
    c = tl.arange(0, COL_BLOCK_SIZE)
    row_ok = (row0 + r) < M
    col_ok = (col0 + c) < N
    tile_ok = row_ok[:, None] & col_ok[None, :]
    tile_offs = r[:, None] * x_stride_r + c[None, :] * x_stride_c

    x = tl.load(x_tile_ptr + tile_offs, tile_ok, other=0.0).to(tl.float32)
    dy = tl.load(dy_tile_ptr + tile_offs, tile_ok, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + row0 + r, row_ok, other=0.0).to(tl.float32)

    # Out-of-range rows load as 0, so they add nothing to the column sums.
    partial = tl.sum(x * dy * inv_rms[:, None], axis=0)

    tl.store(DW + row_blk * N + col0 + c, partial, mask=col_ok)

320 

321 

def rms_norm_out(result, x, normalized_shape, weight, eps=1e-5):
    """Out-variant of RMSNorm: compute the normalized tensor into ``result``.

    The inverse-RMS side output of the forward pass is discarded. Returns
    ``result`` itself (``Tensor.copy_`` returns its receiver).
    """
    normalized, _unused_inv_rms = rms_norm_forward(
        x, normalized_shape, weight, eps=eps
    )
    return result.copy_(normalized)

326 

327 

def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """RMSNorm forward over the trailing ``normalized_shape`` dimensions of ``x``.

    ``x`` is treated as ``M`` rows of ``N = prod(normalized_shape)`` columns.

    Returns:
        ``(y, inv_rms)``: ``y`` is a contiguous tensor with ``x``'s shape and
        dtype; ``inv_rms`` is a float32 tensor of shape ``(M,)`` holding the
        per-row ``1 / sqrt(mean(x^2) + eps)``.

    Kernel selection by the padded row width ``next_power_of_2(N)``:
      * ``<= 512``  -> ``rms_norm_2d_kernel`` (several rows per program)
      * ``<= 1024`` -> ``rms_norm_kernel`` (one row per program, single tile)
      * otherwise   -> ``rms_norm_c_split_kernel`` (one row, 1024-wide chunks)
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    lead_ndim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:lead_ndim])
    N = math.prod(normalized_shape)
    padded_n = triton.next_power_of_2(N)

    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    with torch_device_fn.device(x.device):
        if padded_n > 1024:
            # Rows too wide for one tile: loop over 1024-column chunks.
            rms_norm_c_split_kernel[M,](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, 1024, num_warps=16
            )
        elif padded_n > 512:
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, padded_n)
        else:
            # [Sunrise] The 2-D tile load is only used for padded widths <= 512.
            # Pack rows so each program covers about 1024 elements.
            rows_per_program = triton.cdiv(1024, padded_n)
            grid = (triton.cdiv(M, rows_per_program),)
            rms_norm_2d_kernel[grid](
                y, inv_rms, x, weight, M, N, eps, rows_per_program, padded_n
            )
    return y, inv_rms

355 

356 

def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """RMSNorm backward: gradients with respect to the input and the weight.

    Args:
        dy: upstream gradient, same shape as ``x``.
        x: forward input; treated as ``M`` rows of ``N = prod(normalized_shape)``.
        inv_rms: per-row inverse RMS as produced by ``rms_norm_forward``.
        normalized_shape: trailing dimensions that were normalized.
        weight: scale tensor with ``N`` elements.
        eps: passed through to the dx kernel.

    Returns:
        ``(dx, dw)``: ``dx`` is contiguous with ``x``'s shape and dtype; ``dw``
        has ``weight``'s shape and dtype, as autograd requires for the gradient
        of ``weight``.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    BLOCK_SIZE = min(triton.next_power_of_2(N), 1024)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    # One row of fp32 partial column sums per block of ROW_BLOCK_SIZE input rows.
    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[M,](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    # Reduce the partials in fp32, then match the weight's dtype and shape.
    # A flat (N,) gradient is rejected by autograd when normalized_shape (and
    # hence weight) has more than one dimension.
    dw = (
        torch.sum(partial_buffer, dim=0, dtype=torch.float32)
        .to(weight.dtype)
        .reshape(weight.shape)
    )

    return dx, dw

405 

406 

class RmsNorm(torch.autograd.Function):
    """Autograd bridge for the Triton RMSNorm forward and backward passes.

    Forward saves the input, the per-row inverse RMS and the weight. Backward
    returns gradients for ``x`` and ``weight``. ``normalized_shape`` and ``eps``
    are non-tensor arguments and get ``None``.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        ctx.save_for_backward(x, inv_rms, weight)
        # Non-tensor arguments live on ctx instead of in save_for_backward.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        grad_x, grad_weight = rms_norm_backward(
            dy,
            saved_x,
            saved_inv_rms,
            ctx.normalized_shape,
            saved_weight,
            ctx.eps,
        )
        # One entry per forward input: x, normalized_shape, weight, eps.
        return grad_x, None, grad_weight, None

424 

425 

def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Differentiable RMSNorm over the trailing ``normalized_shape`` dims of ``x``.

    Dispatches through ``RmsNorm.apply`` so gradients flow to ``x`` and ``weight``.
    """
    output = RmsNorm.apply(x, normalized_shape, weight, eps)
    return output