Coverage for src/flag_gems/runtime/backend/_spacemit/ops/layernorm.py: 0%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, tl_extra_shim 

11from flag_gems.utils.type_utils import get_accumulator_dtype 

12 

13pow = tl_extra_shim.pow 

14 

15 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def layer_norm_common_kernel(
    X,
    Y,
    W,
    B,
    Mean,
    Rstd,
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    """Layer-normalize one row of X per program, streaming it in TILE_N tiles.

    Each program handles the row selected by ``tl.program_id(0)``. Rows are
    addressed as ``X + row * N`` with unit stride, i.e. the kernel expects
    row-contiguous data of N elements per row.

    Args:
        X: pointer to the input.
        Y: pointer to the output (same addressing as X).
        W: pointer to the per-column scale, or None for no scaling.
        B: pointer to the per-column shift, or None for no shift.
        Mean: pointer receiving one mean per row (written at index ``row``).
        Rstd: pointer receiving one ``rsqrt(var + eps)`` per row.
        M: number of rows; not read by the kernel body.
        N: number of elements per row.
        eps: value added to the variance before the reciprocal square root.
        TILE_N: compile-time tile width used to walk a row.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)

    X = X + row * N
    Y = Y + row * N

    num_pid_n = tl.cdiv(N, TILE_N)

    # Pass 1: accumulate sum(x) and sum(x^2) over the row in float32.
    mean = 0.0
    var = 0.0
    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )
    for off_n in range(0, num_pid_n):
        # Out-of-range lanes must read as zero: without a padding option the
        # values produced by a boundary-checked load are undefined and would
        # corrupt both sums whenever N is not a multiple of TILE_N.
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
            padding_option="zero",
        ).to(tl.float32)
        mean += tl.sum(a)
        var += tl.sum(a * a)

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

    mean = mean / N
    # E[x^2] - E[x]^2 can come out slightly negative through rounding; clamp
    # it so rsqrt never sees a negative argument.
    var = tl.maximum(var / N - (mean * mean), 0.0)
    rstd = tl.math.rsqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)

    # Pass 2: normalize, apply the optional affine transform, write Y.
    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )
    y_ptr_desc = tl.make_block_ptr(
        base=Y,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )
    # Only build block pointers for operands that exist; a None base is not a
    # pointer and cannot be wrapped in a block pointer.
    if W is not None:
        weight_ptr_desc = tl.make_block_ptr(
            base=W,
            shape=[N],
            strides=[1],
            offsets=[0],
            block_shape=[TILE_N],
            order=[0],
        )
    if B is not None:
        bias_ptr_desc = tl.make_block_ptr(
            base=B,
            shape=[N],
            strides=[1],
            offsets=[0],
            block_shape=[TILE_N],
            order=[0],
        )

    for off_n in range(0, num_pid_n):
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
        ).to(tl.float32)
        x_hat = (a - mean) * rstd

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

        if W is None:
            w = 1
        else:
            w = tl.load(
                weight_ptr_desc,
                boundary_check=[0],
            ).to(tl.float32)
            weight_ptr_desc = tl.advance(weight_ptr_desc, [TILE_N])

        if B is None:
            b = 0
        else:
            b = tl.load(
                bias_ptr_desc,
                boundary_check=[0],
            ).to(tl.float32)
            bias_ptr_desc = tl.advance(bias_ptr_desc, [TILE_N])

        y = x_hat * w + b
        # The math above runs in float32; cast back explicitly so the stored
        # value matches the element type of the output block pointer.
        tl.store(
            y_ptr_desc,
            y.to(Y.type.element_ty),
            boundary_check=[0],
        )
        y_ptr_desc = tl.advance(y_ptr_desc, [TILE_N])

134 

135 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("layer_norm_backward"),
    key=["M", "N"],
)
@triton.jit
def layer_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,
    Rstd,
    dX,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Compute the input gradient of layer norm for a block of rows.

    Each program covers BLOCK_ROW_SIZE consecutive rows and walks the N
    columns in BLOCK_COL_SIZE chunks. Rows are addressed as ``ptr + row * N``,
    i.e. the kernel expects row-contiguous dY, X and dX.

    Args:
        dY: pointer to the output gradient.
        X: pointer to the forward input.
        W: pointer to the per-column scale, or None (treated as 1).
        Mean: pointer to one mean per row.
        Rstd: pointer to one reciprocal standard deviation per row.
        dX: pointer receiving the input gradient.
        M: number of rows.
        N: number of columns per row.
        BLOCK_ROW_SIZE: compile-time number of rows per program.
        BLOCK_COL_SIZE: compile-time column chunk width.
    """
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = pid < M
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    # Masked-off lanes are given an explicit 0 everywhere below: a masked load
    # without `other` returns undefined values, which would otherwise flow
    # into the row reductions.
    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=0.0).to(tl.float32)

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    # Pass 1: per-row sums of dx_hat and dx_hat * x_hat.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        # Element-wise AND of the two boolean masks (broadcast to 2-D).
        mask = row_mask & col_mask
        dy = tl.load(dY + cols[None, :], mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask=mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols[None, :], mask=col_mask, other=0.0).to(tl.float32)
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    # Pass 2: combine the row sums into the final gradient and store it.
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        dy = tl.load(dY + cols[None, :], mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask=mask, other=0.0).to(tl.float32)
        if W is None:
            w = 1
        else:
            w = tl.load(W + cols[None, :], mask=col_mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols[None, :], dx, mask=mask)

202 

203 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_bias_backward"),
    key=["N"],
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,
    Rstd,
    dW,
    dB,
    M,
    N,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Compute the scale and shift gradients of layer norm for a block of columns.

    Each program covers BLOCK_COL_SIZE consecutive columns and walks the M
    rows in BLOCK_ROW_SIZE chunks, reducing over rows. Elements are addressed
    as ``ptr + row * N + col``, i.e. the kernel expects row-contiguous dY and X.

    Args:
        dY: pointer to the output gradient.
        X: pointer to the forward input.
        Mean: pointer to one mean per row.
        Rstd: pointer to one reciprocal standard deviation per row.
        dW: pointer receiving sum over rows of dy * x_hat, or None to skip it.
        dB: pointer receiving sum over rows of dy, or None to skip it.
        M: number of rows.
        N: number of columns per row.
        BLOCK_ROW_SIZE: compile-time row chunk height.
        BLOCK_COL_SIZE: compile-time number of columns per program.
    """
    pid = tl.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]
    col_mask = pid < N
    dY += pid
    X += pid
    accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, M, BLOCK_ROW_SIZE):
        rows = off + tl.arange(0, BLOCK_ROW_SIZE)
        row_mask = rows[:, None] < M
        # Element-wise AND of the two boolean masks (broadcast to 2-D).
        mask = row_mask & col_mask
        # Masked-off lanes must read as 0: the accumulators are summed over
        # the row axis below, so undefined values in rows >= M would end up
        # in the stored gradients.
        dy = tl.load(dY + rows[:, None] * N, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + rows[:, None] * N, mask=mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        rstd = tl.load(Rstd + rows, mask=rows < M, other=0.0)[:, None].to(tl.float32)
        # Zero out both out-of-range rows and out-of-range columns.
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        accW += dy * x_hat
        accB += dy
    if dW is not None:
        dw = tl.sum(accW, axis=0)
        tl.store(dW + pid, dw[None, :], mask=col_mask)
    if dB is not None:
        db = tl.sum(accB, axis=0)
        tl.store(dB + pid, db[None, :], mask=col_mask)

246 

247 

class LayerNorm(torch.autograd.Function):
    """Autograd wrapper around the layer-norm forward and backward kernels."""

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):
        """Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

        Args:
            ctx: autograd context used to stash tensors for backward.
            x: input tensor; made contiguous before the kernel launch.
            normalized_shape: sizes of the normalized dimensions; their
                product is the row length N.
            weight: optional scale tensor, or None.
            bias: optional shift tensor, or None.
            eps: value added to the variance inside the kernel.
            cudnn_enable: accepted for signature compatibility; not used.

        Returns:
            Tuple ``(y, mean, rstd)`` where ``y`` has the layout of the
            contiguous ``x`` and ``mean`` / ``rstd`` hold one value per row.
        """
        logging.debug("GEMS_SPACEMIT LAYERNORM_FORWARD")
        N = math.prod(normalized_shape)
        M = x.numel() // N

        x = x.contiguous()
        if weight is not None:
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y = torch.empty_like(x)

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(M, dtype=acc_type, device=x.device)
        rstd = torch.empty(M, dtype=acc_type, device=x.device)

        TILE_N = 512
        with torch_device_fn.device(x.device):
            layer_norm_common_kernel[(M,)](
                x, y, weight, bias, mean, rstd, M, N, eps, TILE_N=TILE_N
            )

        # Save unconditionally. Gating on x.requires_grad is unreliable:
        # backward is also needed when only weight/bias require gradients,
        # and x.contiguous() above may have produced a copy that no longer
        # reports requires_grad inside forward.
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.M = M
        ctx.N = N
        return y, mean, rstd

    @staticmethod
    def backward(ctx, out_grad, mean_grad, rstd_grad):
        """Compute gradients for ``x``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            out_grad: gradient with respect to ``y``.
            mean_grad: gradient with respect to ``mean``; ignored.
            rstd_grad: gradient with respect to ``rstd``; ignored.

        Returns:
            A 6-tuple matching the inputs of ``forward``: the gradient for
            ``x``, None for ``normalized_shape``, the gradients for ``weight``
            and ``bias``, then None for ``eps`` and ``cudnn_enable``. A
            gradient is None when the input is absent or does not need one.
        """
        logging.debug("GEMS_SPACEMIT LAYERNORM_BACKWARD")
        out_grad = out_grad.contiguous()
        (x, weight, bias, mean, rstd) = ctx.saved_tensors
        M = ctx.M
        N = ctx.N

        # needs_input_grad follows forward's argument order:
        # (x, normalized_shape, weight, bias, eps, cudnn_enable).
        need_x_grad = ctx.needs_input_grad[0]
        need_weight_grad = weight is not None and ctx.needs_input_grad[2]
        need_bias_grad = bias is not None and ctx.needs_input_grad[3]

        in_grad = None
        if need_x_grad:
            with torch_device_fn.device(x.device):
                in_grad = torch.empty_like(x)
                grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)
                layer_norm_backward_kernel[grid](
                    out_grad, x, weight, mean, rstd, in_grad, M, N
                )

        if not (need_weight_grad or need_bias_grad):
            return in_grad, None, None, None, None, None

        with torch_device_fn.device(x.device):
            grid = lambda meta: (triton.cdiv(N, meta["BLOCK_COL_SIZE"]), 1, 1)
            # A None output tells the kernel to skip that reduction's store.
            weight_grad = torch.empty_like(weight) if need_weight_grad else None
            bias_grad = torch.empty_like(bias) if need_bias_grad else None
            weight_bias_backward_kernel[grid](
                out_grad, x, mean, rstd, weight_grad, bias_grad, M, N
            )
        return in_grad, None, weight_grad, bias_grad, None, None

308 

309 

def layer_norm(
    x, normalized_shape, weight=None, bias=None, eps=1e-5, cudnn_enable=True
):
    """Run layer normalization through the ``LayerNorm`` autograd function.

    All arguments are forwarded positionally, in declaration order, to
    ``LayerNorm.apply`` and its result is returned unchanged.
    """
    forwarded = (x, normalized_shape, weight, bias, eps, cudnn_enable)
    return LayerNorm.apply(*forwarded)