Coverage for src/flag_gems/runtime/backend/_sunrise/fused/skip_layernorm.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
# Module-level logger: a child of the package-wide "flag_gems" logger, named
# after this module with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # how much to increase the output pointer when moving by 1 row
    y_stride_c,  # how much to increase the output pointer when moving by 1 col
    x_stride_r,  # how much to increase the input pointer when moving by 1 row
    x_stride_c,  # how much to increase the input pointer when moving by 1 col
    r_stride_r,  # how much to increase the residual pointer when moving by 1 row
    r_stride_c,  # how much to increase the residual pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Layer-normalize one row of ``X + R`` and store it into ``Y``.

    The program with id ``p`` (axis 0) processes row ``p``.  The whole row is
    handled as a single tile of ``BLOCK_SIZE`` columns, so only columns
    ``0 .. BLOCK_SIZE - 1`` are ever touched and lanes with column index
    ``>= N`` are masked out.  ``W`` and ``B`` are indexed directly by column
    offset (no stride argument).  All arithmetic is done in float32 and the
    result is cast to the element type of ``Y`` before being stored.
    """
    row = ext.program_id(0)
    out_row = Y + row * y_stride_r
    x_row = X + row * x_stride_r
    r_row = R + row * r_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Fused residual add; masked lanes load 0.0 so they do not affect the sum.
    x = tl.load(x_row + cols * x_stride_c, mask, other=0.0).to(tl.float32)
    x += tl.load(r_row + cols * r_stride_c, mask, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=0) / N

    # Zero the centered values on masked lanes; otherwise each of them would
    # add (0 - mean)^2 to the variance.
    centered = tl.where(mask, x - mean, 0.0)
    var = tl.sum(centered * centered, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)

    y = w * ((x - mean) * rstd) + b
    tl.store(out_row + cols * y_stride_c, y.to(Y.dtype.element_ty), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_c_split_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,  # how much to increase the output pointer when moving by 1 row
    y_stride_c,  # how much to increase the output pointer when moving by 1 col
    x_stride_r,  # how much to increase the input pointer when moving by 1 row
    x_stride_c,  # how much to increase the input pointer when moving by 1 col
    r_stride_r,  # how much to increase the residual pointer when moving by 1 row
    r_stride_c,  # how much to increase the residual pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Layer-normalize one row of ``X + R`` into ``Y``, tiling over columns.

    The program with id ``p`` (axis 0) processes row ``p`` in chunks of
    ``BLOCK_SIZE`` columns, so ``N`` may exceed ``BLOCK_SIZE``.  A first sweep
    accumulates the sum and the sum of squares of ``x + r`` in float32; a
    second sweep reloads the data, normalizes it, applies ``W`` and ``B``
    (indexed directly by column offset) and stores the result cast to the
    element type of ``Y``.
    """
    pid = ext.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    _sum = tl.zeros((), dtype=tl.float32)
    _sq_sum = tl.zeros((), dtype=tl.float32)

    # Sweep 1: running sum and sum of squares of the fused input.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Honor the column strides, matching the store below and the
        # single-tile kernel; masked lanes load 0.0 and add nothing.
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        _sum += tl.sum(x, axis=0)
        _sq_sum += tl.sum(x * x, axis=0)

    mean = _sum / N
    # E[x^2] - mean^2 can come out slightly negative through float32
    # cancellation; clamp at zero so sqrt(var + eps) cannot produce NaN.
    var = tl.maximum(_sq_sum / N - mean * mean, 0.0)
    rstd = 1 / tl.sqrt(var + eps)

    # Sweep 2: reload, normalize, scale/shift and store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        y = y.to(Y.dtype.element_ty)
        tl.store(Y + cols * y_stride_c, y, mask=mask)
class SkipLayerNorm(torch.autograd.Function):
    """Autograd function wrapping the fused residual-add + layer-norm kernels.

    Only ``forward`` is defined in this class.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Return layer norm of ``x + residual`` over the trailing dimensions.

        The last ``len(normalized_shape)`` dimensions of ``x`` are flattened
        into one row of ``N = prod(normalized_shape)`` elements and the
        leading dimensions into ``M`` rows; one kernel program is launched
        per row.  All tensor arguments are made contiguous first.  ``residual``
        is addressed with the same row/column strides as ``x`` (assumes it
        has the same shape as ``x`` -- TODO confirm with callers).

        Rows whose padded width ``next_power_of_2(N)`` is at most 1024 use
        the single-tile kernel; wider rows use the column-split kernel with
        a fixed block of 1024 and 16 warps.

        Returns a new tensor created with ``torch.empty_like`` on the
        contiguous ``x``.
        """
        logger.debug("GEMS SKIP LAYERNORM FORWARD")

        lead_ndim = x.ndim - len(normalized_shape)
        n_rows = math.prod(x.shape[:lead_ndim])
        n_cols = math.prod(normalized_shape)

        x, residual, weight, bias = (
            t.contiguous() for t in (x, residual, weight, bias)
        )
        out = torch.empty_like(x)

        # Arguments shared by both kernels: tensors, then (row, col) strides
        # for output, input and residual, then the row length and epsilon.
        common_args = (
            out,
            x,
            residual,
            weight,
            bias,
            n_cols,
            1,
            n_cols,
            1,
            n_cols,
            1,
            n_cols,
            eps,
        )

        padded_cols = triton.next_power_of_2(n_cols)
        with torch_device_fn.device(x.device):
            if padded_cols > 1024:
                skip_layer_norm_c_split_kernel[n_rows,](
                    *common_args, 1024, num_warps=16
                )
            else:
                skip_layer_norm_kernel[n_rows,](*common_args, padded_cols)
        return out
def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Apply fused residual-add + layer norm via ``SkipLayerNorm``.

    All arguments are forwarded positionally, in order, to
    ``SkipLayerNorm.apply`` and its result is returned unchanged.
    """
    result = SkipLayerNorm.apply(
        x, residual, normalized_shape, weight, bias, eps
    )
    return result