Coverage for src/flag_gems/runtime/backend/_sunrise/ops/rms_norm.py: 0%
224 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-scoped logger; the forward and backward entry points emit debug traces through it.
logger = logging.getLogger(name=__name__)
@triton.jit
def prev_multiple_of(a, b):
    """Return the start offset of the last ``b``-wide tile covering ``[0, a)``.

    For positive ``a`` and ``b`` this is the largest multiple of ``b`` that is
    strictly less than ``a``: ``a - b`` when ``b`` divides ``a``, otherwise
    ``floor(a / b) * b``.
    """
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to per-row inverse rms (one value per program)
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights (read with unit stride)
    y_stride_r,
    y_stride_c,
    x_stride_r,  # pointer increment when moving by 1 row
    x_stride_c,  # pointer increment when moving by 1 column
    N,  # number of columns in X
    eps,  # added to the mean square before the square root
    BLOCK_SIZE: tl.constexpr,
):
    """Single-tile RMSNorm: each program normalizes one complete row.

    The launch must use ``BLOCK_SIZE >= N`` so that one tile covers the row.
    fp16/bf16 inputs are upcast to float32 for the reduction. Other dtypes are
    computed in their own precision.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    row = tl.program_id(0)
    in_ptr += row * x_stride_r
    out_ptr += row * y_stride_r

    offs = tl.arange(0, BLOCK_SIZE)
    valid = offs < N
    x = tl.load(in_ptr + offs * x_stride_c, valid, other=0.0).to(cdtype)

    # Padded lanes load 0, so they add nothing to the sum; divide by the real width N.
    mean_sq = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)

    w = tl.load(w_ptr + offs, mask=valid, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + offs * y_stride_c, y, mask=valid)
    tl.store(INV_RMS + row, rrms)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_loop_kernel(
    out_ptr,
    INV_RMS,
    in_ptr,
    w_ptr,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    """Two-pass RMSNorm over ``TILE_N``-wide chunks, one program per row.

    Elements are addressed as ``row * N + col``, i.e. contiguous row-major input
    and output. Pass 1 accumulates the sum of squares in float32. Pass 2 writes
    the normalized row, starting from its last (possibly partial) tile and
    walking back to column 0. ``rms_norm_forward`` does not launch this kernel.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    row = ext.program_id(0)
    x_row = in_ptr + row * N
    y_row = out_ptr + row * N
    lanes = tl.arange(0, TILE_N)

    # Pass 1: every tile except the last is full, so only the final tile is masked.
    sq_acc = tl.zeros((TILE_N,), dtype=tl.float32)
    n_tiles = tl.cdiv(N, TILE_N)
    for tile in range(0, n_tiles - 1):
        offs = tile * TILE_N + lanes
        xv = tl.load(x_row + offs).to(tl.float32)
        sq_acc += xv * xv

    offs = (n_tiles - 1) * TILE_N + lanes
    valid = offs < N
    xv = tl.load(x_row + offs, mask=valid, other=0.0).to(tl.float32)
    sq_acc += xv * xv

    mean_sq = tl.sum(sq_acc) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)
    tl.store(INV_RMS + row, rrms)

    # Pass 2: process tiles in reverse order. The original author's note says this
    # gives better L2 cache reuse (not verified here).
    last_tile_start = prev_multiple_of(N, TILE_N)

    # The last tile may be partial and is the only masked one. This loop runs
    # exactly once (back == 0).
    for back in range(0, TILE_N, TILE_N):
        offs = (last_tile_start - back) + lanes
        valid = offs < N
        xv = tl.load(
            x_row + offs,
            mask=valid,
            other=0.0,
            eviction_policy="evict_first",
        ).to(cdtype)
        wv = tl.load(w_ptr + offs, mask=valid, other=0.0)
        yv = (xv * rrms * wv).to(cdtype)
        tl.store(y_row + offs, yv, mask=valid)

    # The remaining tiles are full, so they are read and written without masks.
    for back in range(TILE_N, N, TILE_N):
        offs = (last_tile_start - back) + lanes
        xv = tl.load(
            x_row + offs,
            eviction_policy="evict_first",
        ).to(cdtype)
        wv = tl.load(w_ptr + offs)
        yv = (xv * rrms * wv).to(cdtype)
        tl.store(y_row + offs, yv)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_2d_kernel(
    out_ptr,
    INV_RMS,
    in_ptr,
    w_ptr,
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """RMSNorm on a ``TILE_M x BLOCK_N`` tile: each program normalizes TILE_M rows.

    Rows are addressed as ``row * N + col`` (contiguous row-major), and
    ``BLOCK_N >= N`` is required so that one tile spans each full row.
    Out-of-range rows and columns are masked. For fp16/bf16 inputs the
    reduction runs in float32.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    tile = tl.program_id(0)
    rows = tile * TILE_M + tl.arange(0, TILE_M)
    cols = tl.arange(0, BLOCK_N)
    row_ok = rows < M
    col_ok = cols < N
    tile_mask = row_ok[:, None] & col_ok[None, :]

    x = tl.load(
        in_ptr + rows[:, None] * N + cols[None, :], tile_mask, other=0.0
    ).to(cdtype)
    # Per-row mean of squares. Masked lanes are 0 and do not contribute.
    mean_sq = tl.sum(x * x, axis=1) / N
    rrms = 1 / tl.sqrt(mean_sq + eps)

    w = tl.load(w_ptr + cols, mask=col_ok, other=0.0)
    y = (x * rrms[:, None] * w[None, :]).to(cdtype)
    tl.store(out_ptr + rows[:, None] * N + cols[None, :], y, mask=tile_mask)
    tl.store(INV_RMS + rows, rrms, mask=row_ok)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_c_split_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to per-row inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights (read with unit stride)
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Two-pass RMSNorm for rows wider than one tile; one program per row.

    Pass 1 sums x*x over ``BLOCK_SIZE``-wide chunks. Pass 2 reloads the row and
    writes ``x * rrms * w``. fp16/bf16 inputs are computed in float32, and other
    dtypes in their own precision.
    """
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = tl.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    # The accumulator uses cdtype (the dtype of x*x below) so the loop-carried
    # value keeps one dtype. A fixed float32 accumulator was silently promoted
    # for float64 inputs, and Triton rejects that.
    acc = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        acc += x * x

    var = tl.sum(acc, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to per-row inverse rms saved by the forward pass
    DX,  # pointer to the output gradient
    W,  # pointer to the weights (read with unit stride)
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # not read by this kernel
    BLOCK_SIZE: tl.constexpr,
):
    """Input gradient of RMSNorm, one program per row.

    With ``r = INV_RMS[row]`` and ``g = dy * w``, this computes
    ``dx = (g - (x * r / N) * sum(x * r * g)) * r``.
    It makes two sweeps over the row: the first reduces ``sum(x * r * g)`` and
    the second writes ``dx``. All arithmetic is in float32.
    """
    pid = ext.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    # NOTE(review): assumes DY has the same layout as X; the visible caller
    # passes contiguous tensors of identical shape.
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Sweep 1: row_sum_stats = sum((x * r) * (dy * w)) over the row.
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Loads honour x_stride_c, consistent with the DX store below.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Sweep 2: dx = (dy * w - (x * r / N) * row_sum_stats) * r.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to per-row inverse rms
    DW,  # pointer to the partial-sum buffer, row-major with N columns
    dx_stride_r,  # not read by this kernel
    dx_stride_c,  # not read by this kernel
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient for one (row block, column block) tile.

    Writes ``sum_rows(x * dy * inv_rms)`` for this tile's columns into row
    ``program_id(0)`` of DW. The caller reduces DW over its rows. Masked rows
    load ``inv_rms = 0`` and zeros for x and dy, so they add nothing.
    """
    rb = tl.program_id(0)
    cb = tl.program_id(1)
    r0 = rb * ROW_BLOCK_SIZE
    c0 = cb * COL_BLOCK_SIZE

    r = tl.arange(0, ROW_BLOCK_SIZE)
    c = tl.arange(0, COL_BLOCK_SIZE)
    r_ok = (r0 + r) < M
    c_ok = (c0 + c) < N
    tile_ok = r_ok[:, None] & c_ok[None, :]

    base = r0 * x_stride_r + c0 * x_stride_c
    tile_off = r[:, None] * x_stride_r + c[None, :] * x_stride_c

    x = tl.load(X + base + tile_off, tile_ok, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + r0 + r, r_ok, other=0.0).to(tl.float32)
    dy = tl.load(DY + base + tile_off, tile_ok, other=0.0).to(tl.float32)

    col_sums = tl.sum(x * dy * inv_rms[:, None], axis=0)
    tl.store(DW + rb * N + c0 + c, col_sums, mask=c_ok)
def rms_norm_out(result, x, normalized_shape, weight, eps=1e-5):
    """Out-variant of RMSNorm: write the normalized ``x`` into ``result``.

    The inverse-RMS statistics produced by the forward pass are discarded.
    Returns ``result`` after the in-place copy.
    """
    normalized, _inv_rms = rms_norm_forward(x, normalized_shape, weight, eps=eps)
    result.copy_(normalized)
    return result
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Run RMSNorm over the trailing ``normalized_shape`` dims of ``x``.

    Args:
        x: input tensor. Its trailing ``len(normalized_shape)`` dims form one
            row of ``N = prod(normalized_shape)`` elements, and the leading dims
            give ``M`` rows.
        normalized_shape: trailing shape that is normalized over.
        weight: per-element scale with ``N`` elements (made contiguous first).
        eps: added to the mean square before the square root. ``None`` selects
            ``torch.finfo(x.dtype).eps``, matching ``torch.nn.functional.rms_norm``.

    Returns:
        ``(y, inv_rms)``. ``y`` has the shape and dtype of ``x``. ``inv_rms`` is a
        float32 tensor of shape ``(M,)`` holding ``1 / rms`` for each row.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    if eps is None:
        # aten::rms_norm / F.rms_norm treat a missing eps as the dtype's machine
        # epsilon. Passing None straight to the kernels would fail.
        eps = torch.finfo(x.dtype).eps
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    BLOCK_SIZE = triton.next_power_of_2(N)
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    with torch_device_fn.device(x.device):
        if BLOCK_SIZE <= 512:
            # [Sunrise] the 2-D tile kernel is used up to BLOCK_SIZE == 512. The
            # original note read "2d load works for block_size < 512", which does
            # not match the <= test here. Each program covers TILE_M rows, so a
            # tile spans TILE_M * BLOCK_SIZE == 1024 elements.
            TILE_M = triton.cdiv(1024, BLOCK_SIZE)
            grid = (triton.cdiv(M, TILE_M),)
            rms_norm_2d_kernel[grid](
                y, inv_rms, x, weight, M, N, eps, TILE_M, BLOCK_SIZE
            )
        elif BLOCK_SIZE <= 1024:
            # One program per row; the whole row fits in a single tile.
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
        else:
            # Wide rows: one program per row, looping over 1024-wide chunks.
            BLOCK_SIZE = 1024
            rms_norm_c_split_kernel[M,](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE, num_warps=16
            )
    return y, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Compute the RMSNorm gradients with respect to ``x`` and ``weight``.

    Args:
        dy: upstream gradient, with the same shape as ``x``.
        x: forward input. Its trailing ``len(normalized_shape)`` dims form one row.
        inv_rms: per-row ``1 / rms`` from ``rms_norm_forward`` (``M`` values).
        normalized_shape: trailing shape that was normalized over.
        weight: forward weight, with ``N = prod(normalized_shape)`` elements.
        eps: forwarded to the dx kernel, which does not read it. ``None`` is
            resolved the same way as in the forward pass.

    Returns:
        ``(dx, dw)``. ``dx`` has the shape and dtype of ``x``. ``dw`` has the
        shape and dtype of ``weight``, as autograd expects for that input.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    if eps is None:
        eps = torch.finfo(x.dtype).eps
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    BLOCK_SIZE = min(triton.next_power_of_2(N), 1024)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[M,](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )

    # dw is reduced in two stages: each row block writes one float32 partial
    # row, and the partial rows are then summed on the host side.
    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )
    # Match the weight's shape and dtype. The old `.reshape(-1)` broke
    # multi-dim normalized_shape, and casting to x.dtype lost precision when the
    # weight was stored in a wider dtype.
    dw = (
        torch.sum(partial_buffer, dim=0, dtype=torch.float32)
        .to(weight.dtype)
        .reshape(weight.shape)
    )

    return dx, dw
class RmsNorm(torch.autograd.Function):
    """Autograd bridge between ``rms_norm_forward`` and ``rms_norm_backward``."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        out, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # inv_rms is saved so backward can reuse the per-row statistics.
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape, ctx.eps = normalized_shape, eps
        return out

    @staticmethod
    def backward(ctx, dy):
        x, inv_rms, weight = ctx.saved_tensors
        grad_x, grad_w = rms_norm_backward(
            dy, x, inv_rms, ctx.normalized_shape, weight, ctx.eps
        )
        # normalized_shape and eps are non-tensor arguments, so they get no gradient.
        return grad_x, None, grad_w, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Differentiable RMSNorm over the trailing ``normalized_shape`` dims.

    Computes ``x / sqrt(mean(x**2) + eps) * weight`` through ``RmsNorm``, so
    gradients flow to both ``x`` and ``weight``.
    """
    normalized = RmsNorm.apply(x, normalized_shape, weight, eps)
    return normalized