Coverage for src/flag_gems/runtime/backend/_sunrise/ops/layernorm.py: 0%
243 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
13logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Return the largest multiple of b that is strictly smaller than a
    # (for positive a): round a up to a whole number of b-sized blocks,
    # then step back one block.
    num_blocks = tl.cdiv(a, b)
    return (num_blocks - 1) * b
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/std per row
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # Each program normalizes exactly one row, held in a single 1D tile of
    # TILE_N lanes (lanes at or beyond N are padding).
    row = ext.program_id(0)
    row_start = row * N

    cols = tl.arange(0, TILE_N)
    in_bounds = cols < N

    # Padding lanes are loaded as 0 so they do not disturb the sum.
    x = tl.load(in_ptr + row_start + cols, in_bounds, other=0.0).to(tl.float32)
    mean = tl.sum(x) / N
    centered = x - mean
    # After centering, padding lanes hold -mean; zero them before reducing.
    sq_dev = tl.where(in_bounds, centered * centered, 0)
    var = tl.sum(sq_dev) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    # Optional affine parameters: fall back to identity scale / zero shift.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + cols, mask=in_bounds)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + cols, mask=in_bounds)
    y = centered * rstd * w + b

    tl.store(out_ptr + row_start + cols, y, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/std per row
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
):
    # Each program normalizes a group of TILE_M consecutive rows, holding the
    # whole group in a single (TILE_M, TILE_N) tile.
    block = ext.program_id(0)
    rows = block * TILE_M + tl.arange(0, TILE_M)
    row_ok = rows < M

    cols = tl.arange(0, TILE_N)[None, :]
    col_ok = cols < N
    tile_ok = row_ok[:, None] & col_ok

    # Flat element offsets of the tile inside the row-major (M, N) buffer.
    offsets = rows[:, None] * N + cols

    # Out-of-range elements are loaded as 0 so they do not disturb the sums.
    x = tl.load(in_ptr + offsets, tile_ok, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=1) / N
    centered = x - mean[:, None]
    # Padding elements hold -mean after centering; zero them before reducing.
    sq_dev = tl.where(tile_ok, centered * centered, 0)
    var = tl.sum(sq_dev, axis=1) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + rows, mean, mask=row_ok)
    tl.store(out_rstd_ptr + rows, rstd, mask=row_ok)

    # Optional affine parameters: fall back to identity scale / zero shift.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + cols, mask=col_ok)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + cols, mask=col_ok)
    y = centered * rstd[:, None] * w + b

    tl.store(out_ptr + offsets, y, mask=tile_ok)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/std per row
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # Each program normalizes one row, streaming it in tiles of TILE_N lanes.
    row = ext.program_id(0)
    row_in = in_ptr + row * N
    row_out = out_ptr + row * N

    # ---- Pass 1: per-lane running statistics ------------------------------
    # Every lane keeps its own running mean, sum of squared deviations and
    # element count; the lanes are merged after the loops.
    lane_mean = tl.zeros((TILE_N,), dtype=tl.float32)
    lane_m2 = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - mean)^2)
    lane_cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_tiles = tl.cdiv(N, TILE_N)

    # All tiles except the last are full, so no mask is needed.
    for tile in range(0, num_tiles - 1, 1):
        cols = tile * TILE_N + tl.arange(0, TILE_N)
        x = tl.load(row_in + cols).to(tl.float32)
        delta = x - lane_mean
        updated_mean = lane_mean + delta / (tile + 1)
        lane_m2 = lane_m2 + (x - updated_mean) * delta
        lane_cnt += 1
        lane_mean = updated_mean

    # The last tile may be partial: lanes past N keep their previous state.
    for tile in range(num_tiles - 1, num_tiles, 1):
        cols = tile * TILE_N + tl.arange(0, TILE_N)
        valid = cols < N
        x = tl.load(row_in + cols, mask=valid).to(tl.float32)
        delta = x - lane_mean
        updated_mean = tl.where(valid, lane_mean + delta / (tile + 1), lane_mean)
        lane_m2 = tl.where(valid, lane_m2 + (x - updated_mean) * delta, lane_m2)
        lane_cnt += valid.to(tl.int32)
        lane_mean = updated_mean

    # Merge the per-lane statistics into the row mean and variance.
    mean = tl.sum(lane_mean * lane_cnt) / N
    shift = lane_mean - mean
    var = tl.sum(lane_m2 + lane_cnt * shift * shift) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    # ---- Pass 2: normalize and apply the affine transform -----------------
    # The row is walked back to front, i.e. in the reverse order of pass 1.
    last_tile_start = prev_multiple_of(N, TILE_N)

    # The trailing tile comes first and is the only one that needs a mask.
    for back in range(0, TILE_N, TILE_N):
        cols = (last_tile_start - back) + tl.arange(0, TILE_N)
        valid = cols < N
        x = tl.load(
            row_in + cols,
            mask=valid,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + cols, mask=valid)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + cols, mask=valid)
        y = w * (x - mean) * rstd + b
        tl.store(row_out + cols, y, mask=valid)

    # Remaining tiles are full, so loads and stores are unmasked.
    for back in range(TILE_N, N, TILE_N):
        cols = (last_tile_start - back) + tl.arange(0, TILE_N)
        x = tl.load(row_in + cols, eviction_policy="evict_first").to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + cols)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + cols)
        y = w * (x - mean) * rstd + b
        tl.store(row_out + cols, y)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    has_w: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Computes the input gradient dX for BLOCK_ROW_SIZE rows per program.
    # NOTE(review): this kernel uses tl.program_id while the sibling kernels
    # use ext.program_id - confirm whether the difference is intentional.
    rows = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_ok = rows < M

    # Per-row base pointers into the row-major (M, N) buffers.
    row_start = rows * N
    dy_row = dY + row_start
    x_row = X + row_start
    dx_row = dX + row_start

    mean = tl.load(Mean + rows, mask=row_ok).to(tl.float32)
    rstd = tl.load(Rstd + rows, mask=row_ok).to(tl.float32)

    # ---- Pass 1: accumulate the two per-row reductions --------------------
    #   sum(dx_hat)  and  sum(dx_hat * x_hat)
    acc_g = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    acc_gx = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for col_start in range(0, N, BLOCK_COL_SIZE):
        cols = col_start + tl.arange(0, BLOCK_COL_SIZE)
        col_ok = cols[None, :] < N
        tile_ok = row_ok & col_ok
        dy = tl.load(dy_row + cols[None, :], tile_ok, other=0.0).to(tl.float32)
        x = tl.load(x_row + cols[None, :], tile_ok, other=0.0).to(tl.float32)
        if has_w:
            w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
        else:
            w = 1.0
        # Zero the centered value on padding so it adds nothing to the sums.
        centered = tl.where(tile_ok, x - mean, 0.0)
        x_hat = centered * rstd
        dx_hat = dy * w
        acc_g += dx_hat
        acc_gx += dx_hat * x_hat

    sum_g = tl.sum(acc_g, axis=1)[:, None]
    sum_gx = tl.sum(acc_gx, axis=1)[:, None]

    # ---- Pass 2: recompute x_hat / dx_hat per tile and write dX -----------
    for col_start in range(0, N, BLOCK_COL_SIZE):
        cols = col_start + tl.arange(0, BLOCK_COL_SIZE)
        col_ok = cols[None, :] < N
        tile_ok = row_ok & col_ok
        dy = tl.load(dy_row + cols[None, :], tile_ok, other=0.0).to(tl.float32)
        x = tl.load(x_row + cols[None, :], tile_ok, other=0.0).to(tl.float32)
        if has_w:
            w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
        else:
            w = 1.0
        centered = tl.where(tile_ok, x - mean, 0.0)
        x_hat = centered * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (sum_g + x_hat * sum_gx) / N)
        tl.store(dx_row + cols[None, :], dx, mask=tile_ok)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Computes the weight and bias gradients for BLOCK_COL_SIZE columns per
    # program by reducing over all M rows:
    #   dW[j] = sum_i dY[i, j] * x_hat[i, j]
    #   dB[j] = sum_i dY[i, j]
    # dW and/or dB may be absent, in which case that output is skipped.

    # Triton's automatic address broadcasting may end up misaligned, so the
    # column index is broadcast manually to shape (1, BLOCK_COL_SIZE).
    pid = (
        ext.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    )
    col_mask = pid < N
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        # Same manual broadcasting as above: rows are expanded to
        # (BLOCK_ROW_SIZE, 1) at each point of use.
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        mask = row_mask & col_mask
        # Masked-off lanes must read as 0: without `other` their contents are
        # undefined, and they would leak into accB (and, if non-finite, turn
        # accW into NaN through 0 * inf / 0 * nan).
        dy = tl.load(dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    if dW:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Apply layer normalization over the trailing ``normalized_shape`` dims.

    The input is viewed as an ``(M, N)`` matrix where ``N`` is the product of
    ``normalized_shape`` and ``M`` is the number of remaining elements; each
    of the ``M`` rows is normalized independently.

    Args:
        input: Tensor to normalize. It is made contiguous before launch.
        normalized_shape: Sizes of the trailing dimensions to normalize over.
        weight: Optional per-element scale; made contiguous if given.
        bias: Optional per-element shift; made contiguous if given.
        eps: Value added to the variance before the reciprocal square root.

    Returns:
        A tuple ``(y, mean, rstd)``: ``y`` has the shape and dtype of
        ``input``; ``mean`` and ``rstd`` are 1D tensors of length ``M`` with
        the dtype of ``input``.
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    N = math.prod(normalized_shape)
    M = input.numel() // N

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    # Allocate directly on the input's device instead of allocating on the
    # default device and copying the uninitialized buffer across.
    y = torch.empty(input.shape, dtype=input.dtype, device=input.device)

    # NOTE(review): the kernels compute the statistics in float32, but these
    # buffers use input.dtype, so half-precision inputs store half-precision
    # mean/rstd. Kept as is to preserve the returned dtypes - confirm whether
    # single precision was intended here.
    mean = torch.empty(M, dtype=input.dtype, device=input.device)
    rstd = torch.empty(M, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if N <= 128:
            # Short rows: pack several rows into one program.
            TILE_N = triton.next_power_of_2(N)
            TILE_M = triton.cdiv(1024, TILE_N)
            grid = (triton.cdiv(M, TILE_M), 1, 1)
            layer_norm_persistent_kernel_multiline[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_M,
                TILE_N,
            )
        elif N <= 4096:
            # Medium rows: one program per row, whole row in a single tile.
            TILE_N = triton.next_power_of_2(N)
            grid = (M, 1, 1)
            layer_norm_persistent_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
            )
        else:
            # Long rows: one program per row, streamed tile by tile. TILE_N
            # is not passed here; it is left to the kernel's tuned configs.
            grid = (M, 1, 1)
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
            )
    return y, mean, rstd
def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """Compute the gradients of layer normalization.

    Args:
        grad_out: Gradient with respect to the forward output.
        input: The forward input tensor.
        normalized_shape: Sizes of the trailing dimensions that were
            normalized in the forward pass.
        mean: Per-row mean produced by the forward pass.
        rstd: Per-row reciprocal standard deviation from the forward pass.
        weight: Optional scale used in the forward pass.
        bias: Optional shift used in the forward pass.
        output_mask: Sequence of three flags selecting which of the input,
            weight and bias gradients to compute. ``None`` requests the input
            gradient plus the gradients of whichever of ``weight`` / ``bias``
            were supplied.

    Returns:
        A tuple ``(in_grad, weight_grad, bias_grad)``; entries that were not
        requested are ``None``.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    if output_mask is None:
        # The declared default used to fail on the first subscript.
        output_mask = (True, weight is not None, bias is not None)

    # Use the same (M, N) view as the forward pass. Deriving M from
    # input.shape[0] is only correct for 2D inputs: with more leading
    # dimensions it would disagree with the length of mean/rstd and index
    # weight past its end.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    if output_mask[0]:
        in_grad = torch.empty(input.shape, dtype=input.dtype, device=input.device)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        has_w = 1 if weight is not None else 0
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel[grid](
                grad_out, input, weight, mean, rstd, in_grad, M, N, has_w
            )
    else:
        in_grad = None

    # Truthiness rather than `is False`, so falsy non-bool flags (e.g. 0)
    # also skip the parameter-gradient kernel.
    if not output_mask[1] and not output_mask[2]:
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    weight_grad = (
        torch.empty(weight.shape, dtype=weight.dtype, device=weight.device)
        if output_mask[1]
        else None
    )
    bias_grad = (
        torch.empty(bias.shape, dtype=bias.dtype, device=bias.device)
        if output_mask[2]
        else None
    )
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[grid](
            grad_out, input, mean, rstd, weight_grad, bias_grad, M, N
        )
    return in_grad, weight_grad, bias_grad