Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/batch_norm.py: 0%

172 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, tl_extra_shim 

11 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module's dotted name with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Short alias for the shim's rsqrt; the forward kernel below applies it to
# (variance + eps) to obtain the inverse standard deviation.
rsqrt = tl_extra_shim.rsqrt

14 

15 

def make_3d_for_bn(input: Tensor) -> Tensor:
    """Return ``input`` viewed as a 3-D tensor where its rank calls for it.

    * rank 2: a trailing singleton dimension is appended.
    * rank >= 4: every dimension from index 2 onwards is flattened into one.
    * any other rank (e.g. 3): the tensor is returned unchanged.
    """
    rank = input.ndim
    if rank == 2:
        # (d0, d1) -> (d0, d1, 1)
        return input.unsqueeze(-1)
    if rank >= 4:
        # (d0, d1, d2, ..., dk) -> (d0, d1, d2 * ... * dk)
        return input.flatten(2, -1)
    return input

22 

23 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm forward pass for one feature per program instance.

    The program id on axis 0 selects the feature. Input and output are
    addressed as ``base + feat_stride * feat + batch_stride * b +
    spatial_stride * s`` and are walked in ``BLOCK_M x BLOCK_N`` tiles over
    the (batch, spatial) plane.

    When ``is_train`` is true the kernel computes the mean and the biased
    variance of the feature's elements, stores the mean and
    ``rsqrt(var + eps)`` to ``mean_pointer`` / ``inv_std_pointer``, and then
    reads, blends and writes back the running statistics using ``momentum``
    (the running variance receives the ``n / (n - 1)`` rescaled variance).
    NOTE: both running-stat pointers are loaded from and stored to
    unconditionally on this path, so they must reference writable per-feature
    buffers.

    When ``is_train`` is false the mean and variance are read from the running
    statistics and nothing is written to ``mean_pointer`` /
    ``inv_std_pointer``.

    In both modes the output is ``weight * (x - mean) * inv_std + bias``,
    with ``weight`` / ``bias`` replaced by 1.0 / 0.0 when the corresponding
    pointer is falsy.
    """
    feat_pid = tl.program_id(axis=0)

    if is_train:
        # Two-pass algorithm: first compute sum, then variance
        # Pass 1: Compute sum for mean
        # Element-wise tile accumulator; reduced to a scalar once after the loops.
        total_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                # Out-of-range lanes load 0.0 so they do not contribute to the sum.
                curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                    tl.float32
                )
                total_sum += curr_input

        n_elements = batch_dim * spatial_dim
        mean = tl.sum(total_sum) / n_elements

        # Pass 2: Compute variance
        var_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                    tl.float32
                )
                # Masked lanes would otherwise contribute (0 - mean)^2, so zero
                # their difference explicitly.
                diff = tl.where(mask, curr_input - mean, 0.0)
                var_sum += diff * diff

        # Biased variance (divides by n, not n - 1); used for normalization.
        var = tl.sum(var_sum) / n_elements
        inv_std = rsqrt(var + eps)

        # Save the batch statistics for this feature.
        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        # Blend old running stats with the new batch stats using momentum.
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        # The running variance uses var * n / (n - 1); n == 1 divides by zero here.
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var
            + momentum * var * n_elements / (n_elements - 1),
        )

    else:
        # Evaluation: normalize with the stored running statistics.
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    # Optional affine parameters; fall back to the identity transform.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Final pass: normalize every element of this feature and write the output.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            # No `other=` is given: masked-out lanes are never stored, because
            # the store below uses the same mask.
            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(
                curr_output_pointer,
                output,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

172 

173 

def batch_norm_heur_block_m(args):
    """Heuristic for ``BLOCK_M``: the tile extent along the batch dimension.

    Returns the next power of two of ``args["batch_dim"]``, limited to the
    range ``[1, 64]``.

    Args:
        args: Mapping of kernel argument names to values; only ``"batch_dim"``
            is read (treated as 0 when absent).

    Returns:
        int: A power of two between 1 and 64 inclusive.
    """
    # triton.next_power_of_2(0) evaluates to 0, which is not a usable block
    # extent and would make dependent heuristics divide by zero; clamp to 1.
    block_m = triton.next_power_of_2(args.get("batch_dim", 0))
    return max(1, min(64, block_m))

176 

177 

def batch_norm_heur_block_n(args):
    """Heuristic for ``BLOCK_N``: the tile extent along the spatial dimension.

    Chooses the next power of two of ``args["spatial_dim"]``, capped so that
    ``BLOCK_M * BLOCK_N`` does not exceed ``2**14`` elements.

    Args:
        args: Mapping of kernel argument names to values; ``"spatial_dim"`` is
            read here and ``"batch_dim"`` via ``batch_norm_heur_block_m``
            (both treated as 0 when absent).

    Returns:
        int: A power of two, at least 1.
    """
    # Guard against a zero BLOCK_M (empty or missing batch_dim): the budget
    # computation below would otherwise raise ZeroDivisionError.
    BLOCK_M = max(1, batch_norm_heur_block_m(args))
    # Likewise keep BLOCK_N itself from collapsing to 0 for spatial_dim == 0.
    BLOCK_N = max(1, triton.next_power_of_2(args.get("spatial_dim", 0)))
    return min(BLOCK_N, max(1, 2**14 // BLOCK_M))

182 

183 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
)
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-norm backward pass for one feature per program instance.

    The program id on axis 0 selects the feature. With
    ``x_hat = (x - mean) * inv_std`` and ``dy`` the output gradient, the
    kernel first reduces over the (batch, spatial) plane:

    * ``term1 = sum(x_hat * dy)``, stored as the weight gradient when
      ``weight_grad_mask`` is set;
    * ``term2 = sum(dy)``, stored as the bias gradient when
      ``bias_grad_mask`` is set.

    If ``input_grad_mask`` is false the kernel returns after that. Otherwise
    a second sweep writes
    ``inv_std * weight * (dy - (term1 * x_hat + term2) / count)`` to the
    input gradient, where ``count = batch_dim * spatial_dim`` and ``weight``
    is 1.0 when ``weight_pointer`` is falsy.
    """
    feat_pid = tl.program_id(axis=0)

    # Statistics saved by the forward pass for this feature.
    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    # Element-wise tile accumulators; reduced to scalars after the loops.
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
        batch_mask = batch_offset < batch_dim

        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0).to(tl.float32)

            curr_pre_lin = ((curr_input - mean) * inv_std).to(tl.float32)
            # Masked lanes load dy = 0.0, so they add nothing to either sum even
            # though their curr_pre_lin is generally non-zero.
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    # Nothing more to do when the input gradient was not requested.
    if not input_grad_mask:
        return

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
        weight = weight.to(tl.float32)

    count = batch_dim * spatial_dim

    # Second sweep: compute and store the input gradient tile by tile.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            # No `other=` on these loads: masked-out lanes are never stored,
            # because the store below uses the same mask.
            curr_input = tl.load(
                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
            ).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(
                curr_input_grad_pointer,
                curr_input_grad,
                mask=batch_mask[:, None] & spatial_mask[None, :],
            )

322 

323 

def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training=False,
    momentum=0.1,
    eps=1e-05,
):
    """Run the batch-norm forward kernel over ``input``.

    ``input`` is first brought to 3-D with ``make_3d_for_bn`` (shape
    ``(m, n, k)``), then rearranged to ``(m * k, n, 1)`` so the kernel sees
    one row per (batch, spatial) position and a spatial extent of 1. One
    program is launched per feature (``n``).

    Args:
        input: Tensor to normalize; dimension 1 is the feature dimension.
        weight: Optional per-feature scale; identity when ``None``.
        bias: Optional per-feature shift; zero when ``None``.
        running_mean: Optional per-feature running mean. Updated in place by
            the kernel when ``training`` is true; read when it is false.
        running_var: Optional per-feature running variance, handled like
            ``running_mean``.
        training: Selects batch statistics (true) or running statistics
            (false) for normalization.
        momentum: Blend factor for the running-statistics update.
        eps: Value added to the variance before the reciprocal square root.

    Returns:
        tuple: ``(output, mean, inv_std)``. ``output`` has the shape of
        ``input``. ``mean`` and ``inv_std`` are per-feature tensors that the
        kernel only fills in when ``training`` is true; otherwise they are
        left as allocated (uninitialized).
    """
    logger.debug("GEMS_KUNLUNXIN BATCH_NORM")

    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    # Fold the spatial positions into the batch dimension: (m, n, k) -> (m*k, n, 1).
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # The kernel dereferences both running-stat pointers in every mode, so a
    # valid tensor has to be passed even when the caller supplied none.
    if training:
        # In training the kernel WRITES the updated running statistics back
        # through these pointers. Passing `input` as a stand-in would overwrite
        # the first `feat_dim` elements of the caller's tensor (and, for 2-D
        # input where `input_3d` is a view, corrupt the values being
        # normalized), so use throwaway per-feature buffers instead.
        if running_mean is None:
            running_mean = torch.zeros(
                feat_dim, device=input.device, dtype=input.dtype
            )
        if running_var is None:
            running_var = torch.ones(
                feat_dim, device=input.device, dtype=input.dtype
            )
    else:
        # Evaluation only reads the running statistics, so the historical
        # placeholder is harmless to memory and is kept for compatibility.
        # NOTE(review): without running stats the kernel normalizes with values
        # read from `input`; callers should supply both tensors in eval mode.
        running_mean = input if running_mean is None else running_mean
        running_var = input if running_var is None else running_var

    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            buffer_size_limit=2048,
        )

    # Undo the (m*k, n, 1) rearrangement: back to (m, n, k), then to input's shape.
    output_reshaped = output.reshape(m, k, n).permute(0, 2, 1)
    return output_reshaped.view_as(input), mean, inv_std

372 

373 

def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Run the batch-norm backward kernel.

    ``input`` and ``grad_out`` are rearranged the same way as in the forward
    wrapper: brought to 3-D ``(m, n, k)`` and then to ``(m * k, n, 1)``. One
    program is launched per feature (``n``).

    Args:
        grad_out: Gradient with respect to the forward output.
        input: The forward input.
        weight: Optional per-feature scale used in the forward pass.
        running_mean: Accepted for signature compatibility; not used here.
        running_var: Accepted for signature compatibility; not used here.
        save_mean: Per-feature mean saved by the forward pass.
        save_invstd: Per-feature inverse standard deviation saved by the
            forward pass.
        train: Accepted for signature compatibility; not used here.
        eps: Accepted for signature compatibility; not used here.
        output_mask: Three flags selecting which of (input grad, weight grad,
            bias grad) to compute. ``None`` is treated as all three requested.

    Returns:
        tuple: ``(input_grad, weight_grad, bias_grad)``; each entry is ``None``
        when its flag in ``output_mask`` is false.
    """
    logger.debug("GEMS_KUNLUNXIN BATCH_NORM_BACKWARD")

    # The default used to be subscripted directly and raised TypeError.
    if output_mask is None:
        output_mask = (True, True, True)

    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    # Fold the spatial positions into the batch dimension: (m, n, k) -> (m*k, n, 1).
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    output_grad_3d_i = make_3d_for_bn(grad_out)
    output_grad_3d_f = output_grad_3d_i.permute(0, 2, 1).reshape(-1, n)
    output_grad_3d = make_3d_for_bn(output_grad_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    input_grad = torch.empty_like(input_3d) if output_mask[0] else None
    if output_mask[1]:
        weight_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        weight_grad = None
    if output_mask[2]:
        bias_grad = torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
    else:
        bias_grad = None

    # The launch needs a pointer and three strides for the input gradient even
    # when it is not requested. In that case the kernel returns before touching
    # it (its `input_grad_mask` is false), so `input_3d` serves as a read-only
    # stand-in instead of dereferencing `None.stride()`.
    input_grad_arg = input_3d if input_grad is None else input_grad

    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim, 1, 1)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad_arg,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_arg.stride(),
            *output_mask,
            buffer_size_limit=2048,
        )

    if input_grad is not None:
        # Undo the (m*k, n, 1) rearrangement and restore input's shape.
        input_grad = input_grad.reshape(m, k, n).permute(0, 2, 1).view_as(input)

    return (
        input_grad,
        weight_grad,
        bias_grad,
    )