Coverage for src/flag_gems/runtime/backend/_sunrise/ops/rms_norm.py: 0%
179 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to the per-row inverse RMS buffer
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,  # pointer step between consecutive output rows
    y_stride_c,  # pointer step between consecutive output columns
    x_stride_r,  # pointer step between consecutive input rows
    x_stride_c,  # pointer step between consecutive input columns
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Normalise one row per program using a single block of columns.

    Program ``p`` reads row ``p`` of the input, computes
    ``rrms = 1 / sqrt(sum(x * x) / N + eps)``, writes ``x * rrms * w`` to row
    ``p`` of the output and stores ``rrms`` at ``INV_RMS[p]``.

    Only columns ``[0, BLOCK_SIZE)`` are visited, so the caller must pick
    ``BLOCK_SIZE >= N`` for the whole row to be covered.
    """
    # float16 / bfloat16 inputs are computed in float32; every other dtype is
    # computed in the input's own element type.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    row = tl.program_id(0)
    x_row = in_ptr + row * x_stride_r
    y_row = out_ptr + row * y_stride_r

    col_idx = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_idx < N

    # Masked-out lanes load 0.0 and therefore add nothing to the sum of squares.
    vals = tl.load(x_row + col_idx * x_stride_c, mask=in_bounds, other=0.0).to(cdtype)
    mean_sq = tl.sum(vals * vals, axis=0) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    # The weight vector is indexed with unit stride.
    weight = tl.load(w_ptr + col_idx, mask=in_bounds, other=0.0)
    normed = (vals * inv_rms * weight).to(cdtype)

    tl.store(y_row + col_idx * y_stride_c, normed, mask=in_bounds)
    tl.store(INV_RMS + row, inv_rms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_2d_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to the per-row inverse RMS buffer
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    M,  # number of rows
    N,  # number of columns (also used as the row pitch)
    eps,  # epsilon to avoid division by zero
    TILE_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Normalise ``TILE_M`` rows per program with one 2-D tile load.

    Element ``(r, c)`` is addressed as ``r * N + c`` for both input and
    output, i.e. rows are taken to be densely packed. Only columns
    ``[0, BLOCK_N)`` are visited, so the caller must pick ``BLOCK_N >= N``.
    The inverse RMS of every valid row is written to ``INV_RMS[r]``.
    """
    # float16 / bfloat16 inputs are computed in float32; every other dtype is
    # computed in the input's own element type.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    tile = tl.program_id(0)
    row_idx = tile * TILE_M + tl.arange(0, TILE_M)
    col_idx = tl.arange(0, BLOCK_N)

    row_ok = row_idx < M
    col_ok = col_idx < N
    tile_ok = row_ok[:, None] & col_ok[None, :]

    # Out-of-range lanes read 0.0, so they do not disturb the row sums.
    vals = tl.load(
        in_ptr + row_idx[:, None] * N + col_idx[None, :], mask=tile_ok, other=0.0
    ).to(cdtype)

    # One mean-square / inverse-RMS value per row of the tile.
    mean_sq = tl.sum(vals * vals, axis=1) / N
    inv_rms = 1 / tl.sqrt(mean_sq + eps)

    weight = tl.load(w_ptr + col_idx, mask=col_ok, other=0.0)
    normed = (vals * inv_rms[:, None] * weight[None, :]).to(cdtype)

    tl.store(out_ptr + row_idx[:, None] * N + col_idx[None, :], normed, mask=tile_ok)
    tl.store(INV_RMS + row_idx, inv_rms, mask=row_ok)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_c_split_kernel(
    out_ptr,  # pointer to the output
    INV_RMS,  # pointer to inverse rms
    in_ptr,  # pointer to the input
    w_ptr,  # pointer to the weights
    y_stride_r,  # how much to increase the output pointer when moving by 1 row
    y_stride_c,  # how much to increase the output pointer when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Normalise one row per program, walking the row in BLOCK_SIZE chunks.

    Two passes are made over the row: the first accumulates the sum of
    squares, the second re-reads the input, scales it by the inverse RMS and
    the weight, and writes the output. The inverse RMS is stored at
    ``INV_RMS[pid]``.
    """
    # float16 / bfloat16 inputs are computed in float32; every other dtype is
    # computed in the input's own element type.
    if tl.constexpr(in_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr.dtype.element_ty

    pid = tl.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr += pid * x_stride_r

    # Pass 1: per-lane partial sums of squares.
    # The accumulator uses the compute dtype so that its type stays the same
    # across loop iterations (x is cast to cdtype below); a hard-coded float32
    # accumulator would change type on the first `+=` when cdtype is wider.
    sq_acc = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Masked-out lanes load 0.0 and add nothing to the sum.
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        sq_acc += x * x

    var = tl.sum(sq_acc, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: re-read the row, apply the scale and the weight, write the output.
    for n_idx in range(0, N, BLOCK_SIZE):
        cols = n_idx + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(w_ptr + cols, mask=mask, other=0.0)
        x = tl.load(in_ptr + cols * x_stride_c, mask, other=0.0).to(cdtype)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + cols * y_stride_c, y, mask=mask)

    tl.store(INV_RMS + pid, rrms)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def rms_norm_grad_dx_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms
    DX,  # pointer to the output
    W,  # pointer to the weights
    dx_stride_r,  # how much to increase the DX pointer when moving by 1 row
    dx_stride_c,  # how much to increase the DX pointer when moving by 1 col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the input gradient of RMS norm for one row per program.

    With ``x_hat = x * inv_rms`` and ``g = dy * w`` the kernel writes
    ``dx = (g - x_hat / N * sum(x_hat * g)) * inv_rms``, all in float32.

    ``DY`` has no stride parameters of its own and is addressed with the
    ``X`` strides, so both tensors must share the same layout. ``W`` is
    indexed with unit stride. ``eps`` is not read here; the inverse RMS is
    loaded from ``INV_RMS`` instead of being recomputed.
    """
    pid = ext.program_id(0)
    DX += pid * dx_stride_r
    X += pid * x_stride_r
    DY += pid * x_stride_r
    INV_RMS += pid

    inv_rms = tl.load(INV_RMS).to(tl.float32)

    # Pass 1: row-wide reduction sum(x_hat * dy * w).
    row_sum_stats = 0.0
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Honour the column stride on the loads, as the DX store below does.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        row_sum_stats += tl.sum(normalized_buf * dy)

    # Pass 2: re-read the row and emit dx chunk by chunk.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask, other=0.0).to(tl.float32)
        dy = dy * w
        normalized_buf = x * inv_rms
        norm_val = normalized_buf / N
        dx = (dy - norm_val * row_sum_stats) * inv_rms
        tl.store(DX + cols * dx_stride_c, dx, mask=mask)
@libentry()
@triton.jit
def rms_norm_grad_dw_kernel(
    X,  # pointer to the input
    DY,  # pointer to the upstream gradient
    INV_RMS,  # pointer to inverse rms
    DW,  # pointer to the partial-sum output
    dx_stride_r,  # not read by this kernel
    dx_stride_c,  # not read by this kernel
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    M,  # number of rows in X
    N,  # number of columns in X
    ROW_BLOCK_SIZE: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
):
    """Write per-row-block partial sums of the weight gradient.

    Program ``(i, j)`` handles a ``ROW_BLOCK_SIZE x COL_BLOCK_SIZE`` tile,
    sums ``x * dy * inv_rms`` over the tile's rows and stores the result at
    ``DW[i * N + col]``. ``DW`` is therefore laid out as one row of ``N``
    values per row block; adding those rows together gives the full weight
    gradient. ``DY`` is addressed with the ``X`` strides.
    """
    row_blk = tl.program_id(0)
    col_blk = tl.program_id(1)

    first_row = row_blk * ROW_BLOCK_SIZE
    first_col = col_blk * COL_BLOCK_SIZE

    # Move the base pointers to the top-left element of this tile.
    tile_origin = first_row * x_stride_r + first_col * x_stride_c
    x_base = X + tile_origin
    dy_base = DY + tile_origin
    inv_rms_base = INV_RMS + first_row

    r = tl.arange(0, ROW_BLOCK_SIZE)
    c = tl.arange(0, COL_BLOCK_SIZE)

    row_ok = (first_row + r) < M
    col_ok = (first_col + c) < N
    tile_ok = row_ok[:, None] & col_ok[None, :]

    x_tile = tl.load(
        x_base + r[:, None] * x_stride_r + c[None, :] * x_stride_c,
        mask=tile_ok,
        other=0.0,
    ).to(tl.float32)
    dy_tile = tl.load(
        dy_base + r[:, None] * x_stride_r + c[None, :] * x_stride_c,
        mask=tile_ok,
        other=0.0,
    ).to(tl.float32)
    row_scale = tl.load(inv_rms_base + r, mask=row_ok, other=0.0).to(tl.float32)

    contrib = x_tile * dy_tile * row_scale[:, None]
    # Rows outside the tensor were loaded as 0.0, so they contribute nothing
    # to the reduction over axis 0.
    col_sums = tl.sum(contrib, axis=0)

    tl.store(
        DW + row_blk * N + first_col + c,
        col_sums,
        mask=col_ok,
    )
def rms_norm_forward(x, normalized_shape, weight, eps=1e-5):
    """Run the RMS-norm forward pass and return ``(y, inv_rms)``.

    The trailing ``len(normalized_shape)`` dimensions of ``x`` are flattened
    into ``N`` columns and the leading dimensions into ``M`` rows.

    Args:
        x: input tensor; a contiguous copy is made if necessary.
        normalized_shape: sizes of the trailing dimensions to normalise over.
        weight: scale tensor with ``N`` elements; made contiguous if necessary.
        eps: value added to the mean square before the square root.

    Returns:
        ``y`` with the shape and dtype of ``x``, and ``inv_rms``, a float32
        tensor of shape ``(M,)`` holding ``1 / sqrt(mean(x**2) + eps)`` per row.
    """
    logger.debug("GEMS RMS_NORM FORWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    inv_rms = torch.empty((M,), device=x.device, dtype=torch.float32)

    if M == 0 or N == 0:
        # Nothing to normalise. Returning here avoids a zero-sized launch grid
        # (M == 0) and the division by zero in triton.cdiv(1024, BLOCK_SIZE)
        # that N == 0 would trigger, since next_power_of_2(0) is 0.
        # With N == 0 the mean of squares is 0 / 0, so the statistic is
        # reported as NaN rather than left uninitialised.
        inv_rms.fill_(float("nan"))
        return y, inv_rms

    BLOCK_SIZE = triton.next_power_of_2(N)

    with torch_device_fn.device(x.device):
        if BLOCK_SIZE <= 512:  # [Sunrise] 2d load works for block_size <= 512
            # Several rows per program: TILE_M * BLOCK_SIZE == 1024 elements.
            TILE_M = triton.cdiv(1024, BLOCK_SIZE)
            grid = (triton.cdiv(M, TILE_M),)
            rms_norm_2d_kernel[grid](
                y, inv_rms, x, weight, M, N, eps, TILE_M, BLOCK_SIZE
            )
        elif BLOCK_SIZE <= 1024:
            # One row per program, whole row in a single block.
            # Strides (N, 1) describe the contiguous (M, N) view of y and x.
            rms_norm_kernel[M,](y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)
        else:
            # Rows wider than 1024 columns are processed in 1024-wide chunks.
            BLOCK_SIZE = 1024
            rms_norm_c_split_kernel[M,](
                y, inv_rms, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE, num_warps=16
            )
    return y, inv_rms
def rms_norm_backward(dy, x, inv_rms, normalized_shape, weight, eps=1e-5):
    """Run the RMS-norm backward pass and return ``(dx, dw)``.

    Args:
        dy: gradient of the loss with respect to the forward output; a
            contiguous copy is made if necessary.
        x: the forward input; made contiguous if necessary.
        inv_rms: per-row inverse RMS produced by the forward pass.
        normalized_shape: sizes of the trailing dimensions that were normalised.
        weight: the scale tensor used in the forward pass.
        eps: forwarded to the dx kernel.

    Returns:
        ``dx`` with the shape and dtype of ``x``, and ``dw``, a 1-D tensor of
        ``N`` elements in ``x.dtype``.
    """
    logger.debug("GEMS RMS_NORM BACKWARD")
    dim = x.ndim - len(normalized_shape)
    M = math.prod(x.shape[:dim])
    N = math.prod(normalized_shape)

    x = x.contiguous()
    dy = dy.contiguous()
    weight = weight.contiguous()
    dx = torch.empty_like(x)

    if M == 0 or N == 0:
        # No elements to differentiate. Skipping the launches avoids a
        # BLOCK_SIZE of 0 (N == 0) and zero-sized grids. The weight gradient
        # is a sum over zero rows, i.e. zeros, with the same shape and dtype
        # as the regular path produces.
        return dx, torch.zeros((N,), dtype=x.dtype, device=x.device)

    BLOCK_SIZE = min(triton.next_power_of_2(N), 1024)

    ROW_BLOCK_SIZE = 16
    COL_BLOCK_SIZE = 256
    row_block_num = triton.cdiv(M, ROW_BLOCK_SIZE)
    col_block_num = triton.cdiv(N, COL_BLOCK_SIZE)

    # One row of N float32 partial sums per block of ROW_BLOCK_SIZE input rows.
    partial_buffer = torch.empty(
        (row_block_num, N), dtype=torch.float32, device=x.device
    )

    with torch_device_fn.device(x.device):
        # Strides (N, 1) describe the contiguous (M, N) view of dx, x and dy.
        rms_norm_grad_dx_kernel[M,](
            x, dy, inv_rms, dx, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
        )
        rms_norm_grad_dw_kernel[row_block_num, col_block_num](
            x,
            dy,
            inv_rms,
            partial_buffer,
            N,
            1,
            N,
            1,
            M,
            N,
            ROW_BLOCK_SIZE,
            COL_BLOCK_SIZE,
        )

    # Collapse the per-row-block partial sums into the final weight gradient.
    # NOTE(review): dw is flattened to 1-D and cast to x.dtype rather than to
    # the weight's shape/dtype - confirm this is what callers with a
    # multi-dimensional or differently typed weight expect.
    dw = (
        torch.sum(partial_buffer, dim=0, dtype=torch.float32)
        .to(x.dtype)
        .reshape(-1)
    )

    return dx, dw
class RmsNorm(torch.autograd.Function):
    """Autograd wrapper pairing ``rms_norm_forward`` with ``rms_norm_backward``."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps=1e-5):
        """Compute the normalised output and stash what backward needs."""
        out, rstd = rms_norm_forward(x, normalized_shape, weight, eps)
        # Non-tensor arguments are kept as plain attributes on ctx.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        ctx.save_for_backward(x, rstd, weight)
        return out

    @staticmethod
    def backward(ctx, dy):
        """Return one gradient per forward input: (x, normalized_shape, weight, eps)."""
        x, rstd, weight = ctx.saved_tensors
        grad_x, grad_w = rms_norm_backward(
            dy, x, rstd, ctx.normalized_shape, weight, ctx.eps
        )
        # normalized_shape and eps are not differentiable, hence the None slots.
        return grad_x, None, grad_w, None
def rms_norm(x, normalized_shape, weight, eps=1e-5):
    """Apply RMS normalisation to ``x`` through the ``RmsNorm`` autograd function.

    Args:
        x: input tensor.
        normalized_shape: sizes of the trailing dimensions to normalise over.
        weight: scale tensor applied after normalisation.
        eps: value added to the mean square before the square root.

    Returns:
        The normalised tensor produced by ``RmsNorm.apply``.
    """
    result = RmsNorm.apply(x, normalized_shape, weight, eps)
    return result