Coverage for src/flag_gems/runtime/backend/_sunrise/ops/layernorm.py: 0%

243 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly below a: round a up to a
    # multiple of b, then step back by one multiple.
    return (tl.cdiv(a, b) - 1) * b

20 

21 

# Single-row layer norm: each program loads one whole row (N <= TILE_N) in a
# single tile, computes its mean and reciprocal standard deviation, and
# writes the normalized (optionally affine-transformed) row.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/sqrt(var + eps) per row
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # The program id selects the row; a 1D tile spans its columns.
    row = ext.program_id(0)

    cols = tl.arange(0, TILE_N)
    in_row = cols < N

    # Padding lanes load as 0.0 so they do not disturb the sum.
    x = tl.load(in_ptr + row * N + cols, in_row, other=0.0).to(tl.float32)
    mean = tl.sum(x) / N
    centered = x - mean
    # Padding lanes hold -mean after centering; zero them before squaring.
    sq_dev = tl.where(in_row, centered * centered, 0)
    var = tl.sum(sq_dev) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    # Missing weight / bias behave as scale 1 / shift 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + cols, mask=in_row)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + cols, mask=in_row)
    out = centered * rstd * w + b

    tl.store(out_ptr + row * N + cols, out, mask=in_row)

69 

70 

# Multi-row layer norm: each program handles a TILE_M x TILE_N tile, i.e.
# TILE_M complete rows at once, reducing along axis 1 for per-row statistics.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_persistent"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/sqrt(var + eps) per row
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
):
    # The program id selects a group of TILE_M consecutive rows.
    block = ext.program_id(0)
    rows = block * TILE_M + tl.arange(0, TILE_M)
    row_ok = rows < M

    cols = tl.arange(0, TILE_N)[None, :]
    col_ok = cols < N
    valid = row_ok[:, None] & col_ok

    row_starts = rows[:, None] * N
    # Out-of-range elements load as 0.0 so they do not disturb the row sums.
    x = tl.load(in_ptr + row_starts + cols, valid, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=1) / N
    centered = x - mean[:, None]
    # Zero the padding again: after centering it holds -mean, not 0.
    sq_dev = tl.where(valid, centered * centered, 0)
    var = tl.sum(sq_dev, axis=1) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + rows, mean, mask=row_ok)
    tl.store(out_rstd_ptr + rows, rstd, mask=row_ok)

    # Missing weight / bias behave as scale 1 / shift 0.
    if weight_ptr is None:
        w = 1
    else:
        w = tl.load(weight_ptr + cols, mask=col_ok)
    if bias_ptr is None:
        b = 0
    else:
        b = tl.load(bias_ptr + cols, mask=col_ok)
    out = centered * rstd[:, None] * w + b

    tl.store(out_ptr + row_starts + cols, out, mask=valid)

123 

124 

# Looping layer norm for long rows: each program walks one row in tiles of
# TILE_N columns. Sweep 1 accumulates per-lane running mean / squared
# deviation and merges the lanes; sweep 2 revisits the tiles in reverse order
# to normalize and apply the optional affine transform.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_loop"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # receives one mean per row
    out_rstd_ptr,  # receives one 1/sqrt(var + eps) per row
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # The program id selects the row.
    row = ext.program_id(0)
    row_in = in_ptr + row * N
    row_out = out_ptr + row * N

    # Per-lane running statistics: lane i sees columns i, i + TILE_N, ...
    lane_mean = tl.zeros((TILE_N,), dtype=tl.float32)
    lane_m2 = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - mean)^2) per lane
    lane_cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_tiles = tl.cdiv(N, TILE_N)

    # Every tile except the last is completely inside the row: no mask needed.
    for tile in range(0, num_tiles - 1, 1):
        cols = tile * TILE_N + tl.arange(0, TILE_N)
        x = tl.load(row_in + cols).to(tl.float32)
        next_mean = lane_mean + (x - lane_mean) / (tile + 1)
        next_m2 = lane_m2 + (x - next_mean) * (x - lane_mean)
        lane_cnt += 1
        lane_mean = next_mean
        lane_m2 = next_m2

    # The last tile may be partial: lanes beyond N keep their previous state.
    for tile in range(num_tiles - 1, num_tiles, 1):
        cols = tile * TILE_N + tl.arange(0, TILE_N)
        valid = cols < N
        x = tl.load(row_in + cols, mask=valid).to(tl.float32)
        next_mean = tl.where(valid, lane_mean + (x - lane_mean) / (tile + 1), lane_mean)
        next_m2 = tl.where(valid, lane_m2 + (x - next_mean) * (x - lane_mean), lane_m2)
        lane_cnt += valid.to(tl.int32)
        lane_mean = next_mean
        lane_m2 = next_m2

    # Merge the lanes: count-weighted mean, then each lane's squared
    # deviation plus the correction for its offset from the row mean.
    mean = tl.sum(lane_mean * lane_cnt) / N
    var = tl.sum(lane_m2 + lane_cnt * (lane_mean - mean) * (lane_mean - mean)) / N
    rstd = tl.math.rsqrt(var + eps)
    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    # Sweep 2 runs from the last tile back to the first, i.e. in the reverse
    # order of sweep 1.
    last_tile_start = prev_multiple_of(N, TILE_N)
    # Single iteration: the trailing (possibly partial) tile, masked.
    for back in range(0, TILE_N, TILE_N):
        cols = (last_tile_start - back) + tl.arange(0, TILE_N)
        valid = cols < N
        x = tl.load(
            row_in + cols,
            mask=valid,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + cols, mask=valid)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + cols, mask=valid)
        out = w * (x - mean) * rstd + b
        tl.store(row_out + cols, out, mask=valid)

    # Remaining tiles are full, so loads and stores are unmasked.
    for back in range(TILE_N, N, TILE_N):
        cols = (last_tile_start - back) + tl.arange(0, TILE_N)
        x = tl.load(row_in + cols, eviction_policy="evict_first").to(tl.float32)
        if weight_ptr is None:
            w = 1
        else:
            w = tl.load(weight_ptr + cols)
        if bias_ptr is None:
            b = 0
        else:
            b = tl.load(bias_ptr + cols)
        out = w * (x - mean) * rstd + b
        tl.store(row_out + cols, out)

220 

221 

# Input-gradient kernel: each program handles BLOCK_ROW_SIZE rows and walks
# their columns in chunks of BLOCK_COL_SIZE. Pass 1 accumulates the two
# per-row reductions sum(dy * w) and sum(dy * w * x_hat); pass 2 combines
# them into dx = rstd * (dy * w - (sum1 + x_hat * sum2) / N).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    has_w: tl.constexpr,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Use the same program-id helper as every other kernel in this file.
    # NOTE(review): ext.program_id presumably widens the id to int64 so that
    # `pid * N` cannot overflow on large inputs - confirm against the helper.
    pid = ext.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M

    # Per-row base pointers, shaped (BLOCK_ROW_SIZE, 1).
    dY_ptr = dY + pid * N
    X_ptr = X + pid * N
    dX_ptr = dX + pid * N
    Mean_ptr = Mean + pid
    Rstd_ptr = Rstd + pid

    mean = tl.load(Mean_ptr, mask=row_mask).to(tl.float32)
    rstd = tl.load(Rstd_ptr, mask=row_mask).to(tl.float32)

    # Element-wise accumulators; reduced along the column axis after pass 1.
    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Pass 1: accumulate the row-wise reductions.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        dy = tl.load(dY_ptr + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + cols[None, :], mask, other=0.0).to(tl.float32)
        # Force padding to 0 after centering so it adds nothing to the sums.
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        if has_w:
            w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
        else:
            w = 1.0
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Pass 2: recompute x_hat / dx_hat per chunk and write the gradient.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        dy = tl.load(dY_ptr + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + cols[None, :], mask, other=0.0).to(tl.float32)
        if has_w:
            w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
        else:
            w = 1.0
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX_ptr + cols[None, :], dx, mask=mask)

290 

291 

# Weight / bias gradient kernel: each program owns BLOCK_COL_SIZE columns and
# walks all M rows in chunks of BLOCK_ROW_SIZE, accumulating
# dW = sum_rows(dy * x_hat) and dB = sum_rows(dy) for its columns.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Triton's automatic address broadcasting may end up misaligned, so the
    # offsets are broadcast manually at the point of use.
    pid = (
        ext.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    )
    col_mask = pid < N
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        # Row offsets are likewise broadcast manually below.
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        mask = row_mask & col_mask
        # Every masked load gets an explicit 0.0 fill: the accumulators are
        # later summed over all BLOCK_ROW_SIZE rows, padding rows included,
        # so padding must contribute exactly zero rather than whatever an
        # unfilled masked load returns.
        dy = tl.load(dY + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    # Either output may be absent; only reduce and store the requested ones.
    if dW:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)

336 

337 

def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Apply layer normalization over the trailing ``normalized_shape`` dims.

    The input is viewed as ``M`` rows of ``N = prod(normalized_shape)``
    elements and one of three Triton kernels is chosen by ``N``:
    a multi-row kernel for ``N <= 128``, a single-tile kernel for
    ``N <= 4096`` and a looping kernel for anything larger.

    Args:
        input: Tensor to normalize; made contiguous before launch.
        normalized_shape: Sizes of the trailing dimensions to normalize over.
        weight: Optional scale, made contiguous if given.
        bias: Optional shift, made contiguous if given.
        eps: Value added to the variance before the reciprocal square root.

    Returns:
        Tuple ``(y, mean, rstd)``: ``y`` has the shape and dtype of
        ``input``; ``mean`` and ``rstd`` are 1-D tensors of ``M`` elements.
    """
    logger.debug("GEMS LAYERNORM FORWARD")

    N = math.prod(normalized_shape)
    M = input.numel() // N

    input = input.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    # Allocate directly on the input's device instead of creating the buffer
    # on the default device and copying it over.
    y = torch.empty(input.shape, dtype=input.dtype, device=input.device)

    # NOTE(review): the kernels compute the statistics in float32, but these
    # buffers use the input dtype, so half-precision inputs store rounded
    # mean / rstd for backward. An earlier comment here claimed single
    # precision - confirm which is intended before changing the dtype.
    mean = torch.empty(M, dtype=input.dtype, device=input.device)
    rstd = torch.empty(M, dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if N <= 128:
            # Short rows: pack several rows into each program's tile.
            TILE_N = triton.next_power_of_2(N)
            TILE_M = triton.cdiv(1024, TILE_N)
            grid = (triton.cdiv(M, TILE_M), 1, 1)
            layer_norm_persistent_kernel_multiline[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_M,
                TILE_N,
            )
        elif N <= 4096:
            # Medium rows: one program per row, whole row in a single tile.
            TILE_N = triton.next_power_of_2(N)
            grid = (M, 1, 1)
            layer_norm_persistent_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
                TILE_N,
            )
        else:
            # Long rows: one program per row, looping over column tiles.
            grid = (M, 1, 1)
            layer_norm_loop_kernel[grid](
                input,
                y,
                weight,
                bias,
                mean,
                rstd,
                M,
                N,
                eps,
            )
    return y, mean, rstd

401 

402 

def layer_norm_backward(
    grad_out,
    input,
    normalized_shape,
    mean,
    rstd,
    weight=None,
    bias=None,
    output_mask=None,
):
    """Compute layer-norm gradients for input, weight and bias.

    Args:
        grad_out: Gradient of the loss w.r.t. the layer-norm output.
        input: The forward input.
        normalized_shape: Sizes of the trailing dimensions that were
            normalized; determines the row length ``N``.
        mean: Per-row means saved by the forward pass.
        rstd: Per-row reciprocal standard deviations saved by the forward pass.
        weight: Optional scale used in the forward pass.
        bias: Optional shift used in the forward pass.
        output_mask: Sequence of three flags selecting which of
            ``(input, weight, bias)`` gradients to compute. ``None`` requests
            the input gradient plus a gradient for each of ``weight`` /
            ``bias`` that was actually supplied.

    Returns:
        Tuple ``(in_grad, weight_grad, bias_grad)``; entries that were not
        requested are ``None``.
    """
    logger.debug("GEMS LAYERNORM BACKWARD")

    # The declared default used to crash on the first subscript; give it a
    # usable meaning instead.
    if output_mask is None:
        output_mask = (True, weight is not None, bias is not None)
    want_input, want_weight, want_bias = (
        output_mask[0],
        output_mask[1],
        output_mask[2],
    )

    grad_out = grad_out.contiguous()
    input = input.contiguous()
    mean = mean.contiguous()
    rstd = rstd.contiguous()
    weight = None if weight is None else weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # Derive the row layout from normalized_shape, exactly as the forward
    # pass does. Using input.shape[0] as the row count is only right for 2-D
    # inputs; with extra leading dims it would disagree with the number of
    # saved mean / rstd entries.
    N = math.prod(normalized_shape)
    M = input.numel() // N

    if want_input:
        # Allocate directly on the target device (no host-side detour).
        in_grad = torch.empty(input.shape, dtype=input.dtype, device=input.device)
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
        has_w = 1 if weight is not None else 0
        with torch_device_fn.device(input.device):
            layer_norm_backward_kernel[grid](
                grad_out, input, weight, mean, rstd, in_grad, M, N, has_w
            )
    else:
        in_grad = None

    # Truthiness rather than `is False`, so falsy non-bool flags (0, numpy
    # bools, ...) also skip the parameter-gradient kernel.
    if not want_weight and not want_bias:
        return in_grad, None, None

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
    weight_grad = (
        torch.empty(weight.shape, dtype=weight.dtype, device=weight.device)
        if want_weight
        else None
    )
    bias_grad = (
        torch.empty(bias.shape, dtype=bias.dtype, device=bias.device)
        if want_bias
        else None
    )
    with torch_device_fn.device(input.device):
        weight_bias_backward_kernel[grid](
            grad_out, input, mean, rstd, weight_grad, bias_grad, M, N
        )
    return in_grad, weight_grad, bias_grad