Coverage for src/flag_gems/runtime/backend/_ascend/ops/batch_norm.py: 0%

159 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, tl_extra_shim 

10 

11logger = logging.getLogger(__name__) 

12rsqrt = tl_extra_shim.rsqrt 

13 

14 

def make_3d_for_bn(input: Tensor) -> Tensor:
    """Reshape a batch-norm input into a rank-3 tensor.

    A rank-2 tensor gains a trailing singleton dimension, a tensor of rank 4
    or higher has every dimension from index 2 onward flattened into one, and
    any other rank is returned unchanged.
    """
    rank = input.ndim
    if rank >= 4:
        # Collapse all trailing dimensions into a single one.
        return input.flatten(2, -1)
    if rank == 2:
        # Append a singleton trailing dimension.
        return input.unsqueeze(-1)
    return input

21 

22 

def _block_m(batch_dim):
    """Return the batch-axis tile size: next power of two of ``batch_dim``, capped at 64."""
    rounded = triton.next_power_of_2(batch_dim)
    return rounded if rounded < 64 else 64

25 

26 

def _block_n(batch_dim, spatial_dim):
    """Return the spatial-axis tile size.

    The result is the next power of two of ``spatial_dim``, limited so that
    the tile (``_block_m(batch_dim)`` rows by the returned width) holds at
    most 1024 elements, and never limited below 1.
    """
    rows = _block_m(batch_dim)
    width_budget = max(1, 1024 // rows)
    return min(triton.next_power_of_2(spatial_dim), width_budget)

31 

32 

@libentry()
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm forward pass for one feature per program.

    The program index on axis 0 selects the feature. The feature's
    (batch_dim x spatial_dim) slice is walked in BLOCK_M x BLOCK_N tiles.

    When ``is_train`` is set, the batch mean and biased variance are computed,
    the mean and ``rsqrt(var + eps)`` are stored to ``mean_pointer`` /
    ``inv_std_pointer``, and the running statistics are updated in place with
    ``momentum`` (the running variance uses the n / (n - 1) correction).
    Otherwise the mean and variance are read from the running statistics.

    The output is ``weight * (x - mean) * inv_std + bias``, with weight
    defaulting to 1 and bias to 0 when their pointers are falsy.
    """
    feat_pid = tl.program_id(axis=0)

    if is_train:
        # Per-lane running statistics: every lane of the tile tracks the mean,
        # the sum of squared deviations (M2) and the sample count of the
        # elements it has visited so far.
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

                # The incremental-mean divisor must be the number of samples
                # THIS lane has seen. A global tile index over-counts for
                # lanes that were masked out in a partial tile, which skews
                # their mean and M2. Inactive lanes get a dummy divisor of 1;
                # their result is discarded by the tl.where below.
                cnt += mask.to(tl.int32)
                lane_cnt = tl.where(mask, cnt, 1).to(tl.float32)
                new_mean = tl.where(
                    mask, mean + (curr_input - mean) / lane_cnt, mean
                )
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                mean = new_mean
                var = new_var

        # Merge the per-lane statistics: count-weighted mean, then the summed
        # M2 plus each lane's squared offset from the merged mean.
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        n = batch_dim * spatial_dim
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        # The running variance is updated with var * n / (n - 1).
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Second pass: normalize and write the output tile by tile.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

162 

163 

@libentry()
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm backward pass for one feature per program.

    The program index on axis 0 selects the feature. A first pass over the
    feature's (batch_dim x spatial_dim) slice reduces

        term1 = sum(x_hat * dy)    (stored as the weight gradient)
        term2 = sum(dy)            (stored as the bias gradient)

    where ``x_hat = (x - mean) * inv_std``. If ``input_grad_mask`` is set, a
    second pass writes

        dx = inv_std * weight * (dy - (term1 * x_hat + term2) / count)

    with ``count = batch_dim * spatial_dim`` and weight defaulting to 1 when
    ``weight_pointer`` is falsy. Each gradient is stored only when its
    ``*_grad_mask`` flag is set.
    """
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            # Every lane is summed below, so out-of-bounds lanes must hold a
            # defined zero. With other=0.0 on both loads they contribute
            # exactly 0 to term1 (finite x_hat times 0) and to term2.
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                tl.float32
            )

            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    # Nothing below touches input_grad_pointer unless the input gradient
    # was requested.
    if not input_grad_mask:
        return

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

    count = batch_dim * spatial_dim

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            # Element-wise pass: masked-out lanes are never stored, so their
            # loaded values do not matter here.
            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

295 

296 

def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training=False,
    momentum=0.1,
    eps=1e-05,
):
    """Launch the batch-norm forward kernel, one program per feature.

    The input is reshaped to rank 3 (batch, feature, spatial) by
    ``make_3d_for_bn`` before the launch.

    Args:
        input: Tensor to normalize; dimension 1 is the feature dimension.
        weight: Optional per-feature scale; the kernel uses 1 when None.
        bias: Optional per-feature shift; the kernel uses 0 when None.
        running_mean: Optional per-feature running mean. In training mode the
            kernel updates it in place.
        running_var: Optional per-feature running variance. In training mode
            the kernel updates it in place.
        training: Compute batch statistics when True; otherwise the kernel
            normalizes with the running statistics.
        momentum: Blend factor for the running-statistics update.
        eps: Value added to the variance before the reciprocal square root.

    Returns:
        ``(output, mean, inv_std)``. ``output`` has the shape of ``input``.
        ``mean`` and ``inv_std`` have one entry per feature and are only
        written by the kernel in training mode.
    """
    logger.debug("GEMS_ASCEND BATCHNORM FORWARD")

    input_3d = make_3d_for_bn(input)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # In training mode the kernel always loads and stores the running
    # statistics. Using `input` as a stand-in pointer would let it overwrite
    # the first `feat_dim` input elements before they are normalized, so
    # missing statistics get private scratch buffers instead.
    # NOTE(review): in eval mode with missing statistics the kernel now
    # normalizes with these placeholders (mean 0, variance 1) - confirm
    # whether that case should be rejected instead.
    if running_mean is None:
        running_mean = torch.zeros(
            feat_dim, device=input.device, dtype=torch.float32
        )
    if running_var is None:
        running_var = torch.ones(feat_dim, device=input.device, dtype=torch.float32)

    BM = _block_m(batch_dim)
    BN = _block_n(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            BLOCK_M=BM,
            BLOCK_N=BN,
        )

    return output.view_as(input), mean, inv_std

345 

346 

def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Launch the batch-norm backward kernel, one program per feature.

    ``input`` and ``grad_out`` are reshaped to rank 3 (batch, feature,
    spatial) by ``make_3d_for_bn`` before the launch.

    Args:
        grad_out: Gradient with respect to the forward output.
        input: The forward input.
        weight: Optional per-feature scale; the kernel uses 1 when None.
        running_mean: Accepted for signature compatibility; not used here.
        running_var: Accepted for signature compatibility; not used here.
        save_mean: Per-feature mean read by the kernel.
        save_invstd: Per-feature inverse standard deviation read by the kernel.
        train: Accepted for signature compatibility; not used here.
        eps: Accepted for signature compatibility; not used here.
        output_mask: Three flags selecting which of (input, weight, bias)
            gradients to compute. None requests all three.

    Returns:
        ``(input_grad, weight_grad, bias_grad)``; each entry is None when its
        flag in ``output_mask`` is false. ``input_grad`` has the shape of
        ``input``.
    """
    logger.debug("GEMS_ASCEND BATCHNORM BACKWARD")
    input_3d = make_3d_for_bn(input)
    output_grad_3d = make_3d_for_bn(grad_out)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    # A missing mask previously failed on indexing; treat it as "compute all".
    if output_mask is None:
        output_mask = (True, True, True)
    want_input_grad, want_weight_grad, want_bias_grad = (
        bool(flag) for flag in output_mask
    )

    if want_input_grad:
        input_grad = torch.empty_like(input_3d)
    else:
        input_grad = None
    if want_weight_grad:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if want_bias_grad:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None

    # The kernel returns before touching the input-gradient pointer when the
    # input gradient is not requested, so the strides are placeholders then.
    if input_grad is not None:
        input_grad_strides = input_grad.stride()
    else:
        input_grad_strides = input_3d.stride()

    BM = _block_m(batch_dim)
    BN = _block_n(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim,)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            want_input_grad,
            want_weight_grad,
            want_bias_grad,
            BLOCK_M=BM,
            BLOCK_N=BN,
        )

    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )