Coverage for src/flag_gems/runtime/backend/_spacemit/ops/layernorm.py: 0%
166 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, tl_extra_shim
11from flag_gems.utils.type_utils import get_accumulator_dtype
13pow = tl_extra_shim.pow
@libentry()
@triton.jit(do_not_specialize=["eps"])
def layer_norm_common_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    """Layer-norm forward pass; each program handles one row of length N.

    The row is walked twice in tiles of TILE_N elements: the first pass
    accumulates sum(x) and sum(x^2) to derive the mean and reciprocal
    standard deviation, the second pass writes the normalized (and
    optionally scaled/shifted) output.

    Args:
        X: input pointer; row ``r`` is read starting at ``X + r * N`` with
            unit stride.
        Y: output pointer, addressed the same way as X.
        W: weight pointer of length N, or None to skip scaling.
        B: bias pointer of length N, or None to skip the shift.
        Mean: per-row output buffer that receives the row mean.
        Rstd: per-row output buffer that receives rsqrt(var + eps).
        M: number of rows (not read inside the kernel body).
        N: number of elements per row.
        eps: value added to the variance before the reciprocal square root.
        TILE_N: compile-time tile width along the row.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)

    X = X + row * N
    Y = Y + row * N

    # Pass 1: accumulate sum(x) and sum(x^2) over the row.
    mean = 0.0
    var = 0.0
    num_pid_n = tl.cdiv(N, TILE_N)
    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )
    for off_n in range(0, num_pid_n):
        # Out-of-range lanes of the last tile must contribute exactly zero
        # to both sums; without an explicit padding option their value is
        # undefined.
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
            padding_option="zero",
        )
        mean += tl.sum(a)
        var += tl.sum(pow(a, (2).to(X.type.element_ty)))

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

    mean = mean / N
    # Var[x] = E[x^2] - E[x]^2
    var = var / N - (mean * mean)
    rstd = tl.math.rsqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)

    # Pass 2: rewind to the start of the row and emit the output.
    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )

    # Only build the affine block pointers when the tensors exist; W and B
    # may be None, in which case there is no base pointer to describe.
    if W is not None:
        weight_ptr_desc = tl.make_block_ptr(
            base=W,
            shape=[N],
            strides=[1],
            offsets=[0],
            block_shape=[TILE_N],
            order=[0],
        )

    if B is not None:
        bias_ptr_desc = tl.make_block_ptr(
            base=B,
            shape=[N],
            strides=[1],
            offsets=[0],
            block_shape=[TILE_N],
            order=[0],
        )

    y_ptr_desc = tl.make_block_ptr(
        base=Y,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )

    for off_n in range(0, num_pid_n):
        # Out-of-range lanes are never stored (the store is boundary
        # checked), so no padding is required here.
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
        )
        x_hat = (a - mean) * rstd

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

        if W is None:
            w = 1
        else:
            w = tl.load(
                weight_ptr_desc,
                boundary_check=[0],
            )
            weight_ptr_desc = tl.advance(weight_ptr_desc, [TILE_N])

        if B is None:
            b = 0
        else:
            b = tl.load(
                bias_ptr_desc,
                boundary_check=[0],
            )
            bias_ptr_desc = tl.advance(bias_ptr_desc, [TILE_N])

        y = x_hat * w + b
        tl.store(
            y_ptr_desc,
            y,
            boundary_check=[0],
        )
        y_ptr_desc = tl.advance(y_ptr_desc, [TILE_N])
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Layer-norm backward with respect to the input.

    Each program handles BLOCK_ROW_SIZE rows. A first sweep over the columns
    reduces ``sum(dy * w)`` and ``sum(dy * w * x_hat)`` per row; a second
    sweep combines them into dX.

    Args:
        dY: upstream gradient pointer; row ``r`` starts at ``dY + r * N``.
        X: forward input pointer, addressed like dY.
        W: weight pointer of length N, or None (treated as all ones).
        Mean: per-row mean saved by the forward pass.
        Rstd: per-row reciprocal standard deviation saved by the forward pass.
        dX: output pointer for the input gradient, addressed like dY.
        M: number of rows.
        N: number of elements per row.
        BLOCK_ROW_SIZE: compile-time number of rows per program.
        BLOCK_COL_SIZE: compile-time number of columns per sweep step.
    """
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    # Masked-out lanes of a load are undefined unless `other` is given, so
    # every masked load below supplies an explicit zero fill.
    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=0.0).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Sweep 1: accumulate the two per-row reductions.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        # dy must be exactly zero outside the mask: it feeds both
        # accumulators directly.
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Sweep 2: recompute x_hat / dx_hat per tile and write dX.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols, dx, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Layer-norm backward with respect to weight and bias.

    Each program owns BLOCK_COL_SIZE columns and reduces over all M rows:
    ``dW[c] = sum_r dy[r, c] * x_hat[r, c]`` and ``dB[c] = sum_r dy[r, c]``.

    Args:
        dY: upstream gradient pointer; element (r, c) is at ``dY + r * N + c``.
        X: forward input pointer, addressed like dY.
        Mean: per-row mean saved by the forward pass.
        Rstd: per-row reciprocal standard deviation saved by the forward pass.
        dW: output pointer for the weight gradient, or None to skip it.
        dB: output pointer for the bias gradient, or None to skip it.
        M: number of rows.
        N: number of columns.
        BLOCK_ROW_SIZE: compile-time number of rows per reduction step.
        BLOCK_COL_SIZE: compile-time number of columns per program.
    """
    pid = tl.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = pid < N
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        mask = row_mask and col_mask
        # Masked-out lanes of a load are undefined unless `other` is given;
        # zero-fill them so tail rows and columns add nothing to the sums.
        dy = tl.load(dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        # Zero both out-of-range rows and out-of-range columns.
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    if dW is not None:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB is not None:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)
class LayerNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton layer-norm kernels."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):
        """Run the forward kernel and stash what backward needs.

        Args:
            ctx: autograd context.
            x: input tensor; made contiguous before the launch.
            normalized_shape: trailing dimensions to normalize over; their
                product is the row length N.
            weight: optional scale tensor, or None.
            bias: optional shift tensor, or None.
            eps: value added to the variance inside the kernel.
            cudnn_enable: accepted for signature compatibility; unused here.

        Returns:
            Tuple ``(y, mean, rstd)``: the output with the shape of ``x`` and
            the per-row statistics, each of length ``M = x.numel() // N``.
        """
        logging.debug("GEMS_SPACEMIT LAYERNORM_FORWARD")
        N = math.prod(normalized_shape)
        M = x.numel() // N

        x = x.contiguous()
        if weight is not None:
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y = torch.empty_like(x)

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(M, dtype=acc_type, device=x.device)
        rstd = torch.empty(M, dtype=acc_type, device=x.device)

        TILE_N = 512
        with torch_device_fn.device(x.device):
            layer_norm_common_kernel[(M,)](
                x, y, weight, bias, mean, rstd, M, N, eps, TILE_N=TILE_N
            )

        # Save unconditionally: backward can also be reached when only
        # weight/bias require grad, and it reads these regardless of
        # x.requires_grad.
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.M = M
        ctx.N = N
        return y, mean, rstd

    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        """Compute gradients for x, weight and bias.

        Args:
            ctx: autograd context populated by ``forward``.
            out_grad: gradient with respect to ``y``.
            mean_grad: gradient with respect to ``mean``; not used.
            rstd_grad: gradient with respect to ``rstd``; not used.

        Returns:
            One entry per ``forward`` argument:
            ``(in_grad, None, weight_grad, bias_grad, None, None)``, where
            ``weight_grad`` / ``bias_grad`` are None when the corresponding
            tensor was None.
        """
        logging.debug("GEMS_SPACEMIT LAYERNORM_BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, bias, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N

        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)
            # One program per BLOCK_ROW_SIZE rows.
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
            layer_norm_backward_kernel[grid](
                out_grad, x, weight, mean, rstd, in_grad, M, N
            )

        if weight is None and bias is None:
            return in_grad, None, None, None, None, None

        with torch_device_fn.device(x.device):
            # One program per BLOCK_COL_SIZE columns.
            grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
            weight_grad = None if weight is None else torch.empty_like(weight)
            bias_grad = None if bias is None else torch.empty_like(bias)
            weight_bias_backward_kernel[grid](
                out_grad, x, mean, rstd, weight_grad, bias_grad, M, N
            )
        return in_grad, None, weight_grad, bias_grad, None, None
def layer_norm(
    x, normalized_shape, weight=None, bias=None, eps=1e-5, cudnn_enable=True
):
    """Apply layer normalization through the ``LayerNorm`` autograd function.

    All arguments are forwarded positionally and unchanged; the result is
    whatever ``LayerNorm.apply`` returns, i.e. the ``(y, mean, rstd)`` tuple
    produced by ``LayerNorm.forward``.
    """
    forwarded = (x, normalized_shape, weight, bias, eps, cudnn_enable)
    result = LayerNorm.apply(*forwarded)
    return result