Coverage for src/flag_gems/ops/rms_norm.py: 33%
166 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the forward and backward host functions emit debug messages on it.
logger = logging.getLogger(name=__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Offset of the last (possibly partial) width-b tile covering a elements:
    # ceil(a / b) * b - b.
    num_tiles = tl.cdiv(a, b)
    return num_tiles * b - b
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Single-block RMSNorm: one program normalizes one whole row.

    Only columns 0..BLOCK_SIZE-1 are processed, so BLOCK_SIZE must be >= N.
    Writes y = x * rrms * w to the output row and rrms = 1 / sqrt(mean(x^2) + eps)
    to INV_RMS[row].
    """
    # fp16/bf16 inputs are computed in fp32; other dtypes use their own type.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    # Use the same ext.program_id helper as the loop/backward kernels instead of
    # tl.program_id (int32), so pid * stride does not overflow once M * N exceeds
    # 2**31. NOTE(review): assumes ext.program_id returns a 64-bit id - confirm.
    pid = ext.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)

    # Masked lanes load 0.0, so they add nothing to the sum of squares.
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def rms_norm_loop_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    N,  # number of columns; also the row stride (rows are addressed as row * N)
    eps,  # epsilon to avoid division by zero
    TILE_N: tl.constexpr,
):
    """Two-pass RMSNorm for rows that do not fit in one block; one program per row.

    Input and output rows are addressed as ``row * N + col``, so both must be
    contiguous with row stride N. Pass 1 accumulates sum(x^2) tile by tile;
    pass 2 walks the row from its last tile back to the first and writes
    y = x * rrms * w. rrms = 1 / sqrt(mean(x^2) + eps) is stored to INV_RMS[row].
    """
    # fp16/bf16 inputs are computed in fp32; other dtypes use their own type.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = ext.program_id(0)
    in_ptr += pid * N
    out_ptr += pid * N

    # Pass 1: compute sum(x^2) in chunks. Accumulate in cdtype rather than a
    # hard-coded fp32, so fp64 inputs keep fp64 precision (as in rms_norm_kernel).
    # This also gives `x` the same type here as in the pass-2 loop, which Triton
    # requires for variables that are reassigned inside a loop.
    acc = tl.zeros((TILE_N,), dtype=cdtype)
    num_steps = tl.cdiv(N, TILE_N)
    for step in range(0, num_steps - 1):
        n_offsets = step * TILE_N + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + n_offsets).to(cdtype)
        acc += x * x

    # Last (possibly partial) tile, masked.
    n_offsets = (num_steps - 1) * TILE_N + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x = tl.load(in_ptr + n_offsets, mask=mask, other=0.0).to(cdtype)
    acc += x * x

    var = tl.sum(acc) / N
    rrms = 1 / tl.sqrt(var + eps)
    tl.store(INV_RMS + pid, rrms)

    # Pass 2: normalize in reverse order, starting with the tile pass 1 read
    # last (intended for better L2 cache reuse).
    prev_multiple = prev_multiple_of(N, TILE_N)

    # The last tile is the only one that can be partial, so only it is masked.
    n_offsets = prev_multiple + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x = tl.load(
        in_ptr + n_offsets,
        mask=mask,
        other=0.0,
        eviction_policy="evict_first",
    ).to(cdtype)
    w = tl.load(w_ptr + n_offsets, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + n_offsets, y, mask=mask)

    # Remaining full tiles, stepping backwards down to offset 0; no mask needed.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(
            in_ptr + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        w = tl.load(w_ptr + n_offsets)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + n_offsets, y)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,
    dx_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # not read by this kernel
    BLOCK_SIZE: tl.constexpr,
):
    """Input gradient of RMSNorm; one program per row, whole row in one block.

    With g = dy * w and x_hat = x * inv_rms:
        dx = (g - x_hat / N * sum(x_hat * g)) * inv_rms
    All arithmetic is done in fp32.
    """
    row = ext.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N

    x_row = X + row * x_stride_r
    dy_row = DY + row * x_stride_r
    dx_row = DX + row * dx_stride_r

    rstd = tl.load(INV_RMS + row).to(tl.float32)
    x = tl.load(x_row + offs * x_stride_c, mask=in_bounds, other=0.0).to(tl.float32)
    grad_y = tl.load(dy_row + offs * x_stride_c, mask=in_bounds, other=0.0).to(
        tl.float32
    )
    weight = tl.load(W + offs, mask=in_bounds, other=0.0)

    # Fold the weight into the upstream gradient first.
    g = grad_y * weight
    x_hat = x * rstd
    # Masked lanes are 0.0 in both factors, so they do not affect the dot product.
    dot = tl.sum(x_hat * g, axis=0)
    dx = (g - (x_hat / N) * dot) * rstd

    tl.store(dx_row + offs * dx_stride_c, dx, mask=in_bounds)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient (addressed with X's strides)
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the output
    dx_stride_r,  # not read by this kernel
    dx_stride_c,  # not read by this kernel
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Partial weight gradient for one ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile.

    Program (row_pid, col_pid) writes sum over its rows of x * dy * inv_rms
    into DW[row_pid * N + col_start : ... + COL_BLOCK_SIZE], i.e. DW is treated
    as a contiguous (num_row_blocks, N) buffer. The caller is expected to
    reduce it over the row-block dimension.
    """
    # ext.program_id (as used by the other kernels in this file) instead of
    # tl.program_id, so row_start * x_stride_r and row_pid * N are not evaluated
    # in int32 for very large inputs.
    # NOTE(review): assumes ext.program_id returns int64 - confirm.
    row_pid = ext.program_id(0)
    col_pid = ext.program_id(1)

    row_start = row_pid * ROW_BLOCK_SIZE
    col_start = col_pid * COL_BLOCK_SIZE

    offset = row_start * x_stride_r + col_start * x_stride_c
    X += offset
    DY += offset
    INV_RMS += row_start

    rows = tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, COL_BLOCK_SIZE)

    row_mask = (row_start + rows) < M
    col_mask = (col_start + cols) < N
    tile_mask = row_mask[:, None] & col_mask[None, :]
    tile_offsets = rows[:, None] * x_stride_r + cols[None, :] * x_stride_c

    x = tl.load(X + tile_offsets, tile_mask, other=0.0).to(tl.float32)
    inv_rms = tl.load(INV_RMS + rows, row_mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + tile_offsets, tile_mask, other=0.0).to(tl.float32)

    # Out-of-range rows/columns load 0.0, so they contribute nothing to the sum.
    d_weight = x * dy * inv_rms[:, None]
    partial_dweight_sum = tl.sum(d_weight, axis=0)

    tl.store(
        DW + row_pid * N + col_start + cols,
        partial_dweight_sum,
        mask=col_mask,
    )
def rms_norm_out(result, x, normalized_shape, weight, eps=1e-5):
    """Out-variant of RMSNorm: run the forward pass and copy it into ``result``.

    The per-row inverse RMS returned by ``rms_norm_forward`` is discarded.
    Returns ``result`` itself (``Tensor.copy_`` returns the tensor it writes into).
    """
    normalized, _ = rms_norm_forward(x, normalized_shape, weight, eps=eps)
    return result.copy_(normalized)
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """RMSNorm forward over the trailing ``normalized_shape`` dims of ``x``.

    Args:
        x: input tensor; its trailing ``len(normalized_shape)`` dims must equal
            ``normalized_shape``. It is made contiguous before launch.
        normalized_shape: sequence of ints giving the normalized dims.
        weight: scale tensor with ``prod(normalized_shape)`` elements.
        eps: added to the mean square before the square root.

    Returns:
        ``(y, inv_rms)``: ``y`` is a new tensor shaped like ``x``. ``inv_rms``
        is a float32 tensor of shape ``(M,)``, holding one
        ``1 / sqrt(mean(x^2) + eps)`` per row, where M is the product of the
        leading dims.

    Raises:
        RuntimeError: if ``normalized_shape`` does not match the trailing dims
            of ``x``, or ``weight`` does not have ``prod(normalized_shape)``
            elements. The kernels index the weight with column offsets
            0..N-1, so a mismatch would otherwise read out of bounds silently.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    dim = x.ndim - len(normalized_shape)
    if dim < 0 or tuple(x.shape[dim:]) != tuple(normalized_shape):
        raise RuntimeError(
            f"rms_norm: normalized_shape={list(normalized_shape)} does not match "
            f"the trailing dimensions of input of size {list(x.shape)}"
        )
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)
    if weight.numel() != N:
        raise RuntimeError(
            f"rms_norm: expected weight with {N} elements, "
            f"got weight of size {list(weight.shape)}"
        )

    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    with torch_device_fn.device(x.device):
        if N <= 4096:
            # The whole row fits in one block.
            BLOCK_SIZE = triton.next_power_of_2(N)
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
        else:
            # Long rows: tiled two-pass kernel (TILE_N chosen by autotune).
            rms_norm_loop_kernel[M,](y, inv_rms, x, weight, N, eps)

    return y, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Gradients of RMSNorm with respect to the input and the weight.

    Args:
        dy: upstream gradient; expected to have ``x``'s shape (it is addressed
            with ``x``'s strides after both are made contiguous).
        x: the forward input.
        inv_rms: per-row inverse RMS, as produced by ``rms_norm_forward``.
        normalized_shape: trailing dims that were normalized.
        weight: the scale tensor used in the forward pass.
        eps: passed to the dx kernel, which does not read it.

    Returns:
        ``(dx, dw)``: ``dx`` is shaped like ``x`` and has ``x``'s dtype. ``dw``
        has ``weight``'s shape and dtype. It is accumulated in float32 and
        cast once at the end, so an fp32 weight used with low-precision
        activations does not have its gradient rounded through the
        activation dtype.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    BLOCK_SIZE = triton.next_power_of_2(N)
    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    # dw is reduced in two stages: the kernel writes one float32 partial sum
    # per block of ROW_BLOCK_SIZE rows, then torch.sum reduces the partials.
    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)
    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        rms_norm_grad_dx_kernel[M,](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    # Match the weight's dtype and shape; autograd rejects a gradient whose
    # shape differs from the input's, e.g. a flat (N,) for a multi-dim weight.
    dw = (
        torch.sum(partial_buffer, dim=0, dtype=torch.float32)
        .to(weight.dtype)
        .reshape(weight.shape)
    )

    return dx, dw
class RmsNorm(torch.autograd.Function):
    """Autograd function pairing ``rms_norm_forward`` with ``rms_norm_backward``.

    Gradients are produced for ``x`` and ``weight``. ``normalized_shape`` and
    ``eps`` are non-tensor arguments and receive ``None``.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        output, inv_rms = rms_norm_forward(x, normalized_shape, weight, eps)
        # inv_rms is kept so backward can reuse it instead of recomputing it.
        ctx.save_for_backward(x, inv_rms, weight)
        ctx.normalized_shape, ctx.eps = normalized_shape, eps
        return output

    @staticmethod
    def backward(ctx, dy):
        saved_x, saved_inv_rms, saved_weight = ctx.saved_tensors
        grad_x, grad_weight = rms_norm_backward(
            dy,
            saved_x,
            saved_inv_rms,
            ctx.normalized_shape,
            saved_weight,
            ctx.eps,
        )
        # One entry per forward argument: (x, normalized_shape, weight, eps).
        return grad_x, None, grad_weight, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Differentiable RMSNorm over the trailing ``normalized_shape`` dims of ``x``.

    Dispatches through the ``RmsNorm`` autograd function.
    """
    normalized = RmsNorm.apply(x, normalized_shape, weight, eps)
    return normalized