Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/instance_norm.py: 0%

343 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils.type_utils import get_accumulator_dtype 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15Tensor = torch.Tensor 

16 

17 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a:
    # round a up to a whole number of b-sized tiles, then step back one tile.
    num_tiles = tl.cdiv(a, b)
    return (num_tiles - 1) * b

22 

23 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Each program normalizes the row selected by its program id. The whole
    # row is read as a single TILE_N-wide tile; lanes at or beyond N are masked.
    row = tl.program_id(0)
    row_in_range = row < M
    channel = row % C

    cols = tl.arange(0, TILE_N)
    col_mask = cols < N
    row_base = row * N

    x = tl.load(in_ptr + row_base + cols, col_mask, other=0.0).to(tl.float32)
    # Masked lanes were loaded as 0, so they do not perturb the sum.
    mean = tl.sum(x) / N
    centered = x - mean
    # Masked lanes hold -mean after centering; zero them before reducing.
    sq_dev = tl.where(col_mask, centered * centered, 0)
    var = tl.sum(sq_dev) / N
    rstd = tl.math.rsqrt(var + eps)

    # Per-row statistics are written out alongside the normalized output.
    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    if HAS_WEIGHT_BIAS:
        # Affine parameters are indexed by channel (row % C).
        w = tl.load(weight_ptr + channel, mask=row_in_range)
        b = tl.load(bias_ptr + channel, mask=row_in_range)
        out = centered * rstd * w + b
    else:
        out = centered * rstd

    tl.store(out_ptr + row_base + cols, out, mask=col_mask)

72 

73 

@libentry()
# NOTE: autotuning over runtime.get_tuned_config("instancenorm") keyed on
# ["M", "N"] is commented out for this kernel.
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Each program normalizes TILE_M consecutive rows; every row is read as a
    # single TILE_N-wide tile, with rows >= M and columns >= N masked off.
    pid = tl.program_id(0)
    rows = pid * TILE_M + tl.arange(0, TILE_M)
    row_mask = rows < M
    channels = rows % C

    cols = tl.arange(0, TILE_N)[None, :]
    col_mask = cols < N
    tile_mask = row_mask[:, None] & col_mask
    offsets = rows[:, None] * N + cols

    x = tl.load(in_ptr + offsets, tile_mask, other=0.0).to(tl.float32)
    # Row-wise mean; masked lanes were loaded as 0 and add nothing.
    mean = tl.sum(x, axis=1) / N
    centered = x - mean[:, None]
    # Masked lanes hold -mean after centering; zero them before reducing.
    sq_dev = tl.where(tile_mask, centered * centered, 0)
    var = tl.sum(sq_dev, axis=1) / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + rows, mean, mask=row_mask)
    tl.store(out_rstd_ptr + rows, rstd, mask=row_mask)

    if HAS_WEIGHT_BIAS:
        # Affine parameters are indexed by channel (row % C).
        w = tl.load(weight_ptr + channels, mask=row_mask)
        b = tl.load(bias_ptr + channels, mask=row_mask)
        out = centered * rstd[:, None] * w[:, None] + b[:, None]
    else:
        out = centered * rstd[:, None]

    tl.store(out_ptr + offsets, out, mask=tile_mask)

126 

127 

def instance_norm_loop_kernel_heur_tile_n(args):
    """Return the fixed TILE_N heuristic value (8192); ``args`` is not consulted."""
    tile_n = 8192
    return tile_n

130 

131 

@libentry()
# NOTE: autotuning over runtime.get_tuned_config("instance_norm_loop") keyed
# on ["M", "N"] is commented out; TILE_N is supplied by the heuristic below.
@triton.heuristics(
    values={
        "TILE_N": instance_norm_loop_kernel_heur_tile_n,
    },
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Each program normalizes the row selected by its program id, walking the
    # row in TILE_N-wide tiles.
    row = tl.program_id(0)
    row_in_range = row < M
    channel = row % C
    row_base = row * N

    # Pass 1: every lane keeps its own running mean, running sum of squared
    # deviations and element count; the lanes are merged after the loops.
    lane_mean = tl.zeros((TILE_N,), dtype=tl.float32)
    lane_m2 = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - mean)^2) per lane
    lane_cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)

    # All tiles except the last one are full, so no mask is needed here.
    for step in range(0, num_steps - 1, 1):
        cols = step * TILE_N + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + row_base + cols).to(tl.float32)
        delta = x - lane_mean
        updated_mean = lane_mean + delta / (step + 1)
        updated_m2 = lane_m2 + (x - updated_mean) * delta
        lane_cnt += 1
        lane_mean = updated_mean
        lane_m2 = updated_m2

    # The last tile may be partial: lanes at or beyond N keep their old state.
    for step in range(num_steps - 1, num_steps, 1):
        cols = step * TILE_N + tl.arange(0, TILE_N)
        tail_mask = cols < N
        x = tl.load(in_ptr + row_base + cols, mask=tail_mask).to(tl.float32)
        delta = x - lane_mean
        updated_mean = tl.where(tail_mask, lane_mean + delta / (step + 1), lane_mean)
        updated_m2 = tl.where(
            tail_mask, lane_m2 + (x - updated_mean) * delta, lane_m2
        )
        lane_cnt += tail_mask.to(tl.int32)
        lane_mean = updated_mean
        lane_m2 = updated_m2

    # Merge the per-lane statistics into the statistics of the whole row.
    mean = tl.sum(lane_mean * lane_cnt) / N
    lane_shift = lane_mean - mean
    var = tl.sum(lane_m2 + lane_cnt * lane_shift * lane_shift) / N
    rstd = tl.math.rsqrt(var + eps)

    # Write mean / rstd
    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + channel, mask=row_in_range)
        b = tl.load(bias_ptr + channel, mask=row_in_range)
    else:
        w = 1
        b = 0

    # Pass 2: normalize and apply the affine transform, visiting the tiles in
    # reverse order (last tile first).
    last_tile_start = prev_multiple_of(N, TILE_N)

    # The first visited tile is the (possibly partial) last one: masked.
    for start_n in range(0, TILE_N, TILE_N):
        cols = (last_tile_start - start_n) + tl.arange(0, TILE_N)
        tail_mask = cols < N
        x = tl.load(
            in_ptr + row_base + cols,
            mask=tail_mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - mean) * rstd + b
        tl.store(out_ptr + row_base + cols, out, mask=tail_mask)

    # The remaining tiles are full, so they are processed unmasked.
    for start_n in range(TILE_N, N, TILE_N):
        cols = (last_tile_start - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + row_base + cols, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - mean) * rstd + b
        tl.store(out_ptr + row_base + cols, out)

227 

228 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def instancenorm_fwd_kernel_xpu(
    X,
    Y,
    W,
    B,
    MEAN,
    RSTRD,
    M: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    eps: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
    XBLOCK: tl.constexpr,
    RBLOCK: tl.constexpr,
):
    # Forward instance norm over an (M, N) row-major view of the input.
    # Each program handles XBLOCK consecutive rows and sweeps the N columns in
    # RBLOCK-wide chunks: once to accumulate statistics, once to normalize.
    #   X / Y        : input / output pointers
    #   W / B        : per-channel weight / bias pointers (read only when
    #                  HAS_WEIGHT_BIAS is set), indexed by row % C
    #   MEAN / RSTRD : per-row mean and 1/sqrt(var + eps) outputs
    pid = tl.program_id(0)
    xoffset = pid * XBLOCK
    _xindex = xoffset + tl.arange(0, XBLOCK)
    xindex = _xindex[:, None]
    xmask = xindex < M
    rbase = tl.arange(0, RBLOCK)[None, :]

    # Running sums of x and x*x; masked lanes are loaded as 0 and add nothing.
    _sum = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    _sum_sq = tl.full([XBLOCK, RBLOCK], 0, tl.float32)

    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        _sum = _sum + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _sum_sq = _sum_sq + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_sum, 1)[:, None] / N
    mean_sq = tl.sum(_sum_sq, 1)[:, None] / N
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation (e.g. near-constant rows); clamp it so that the square root
    # below cannot produce NaN.
    var = tl.maximum(mean_sq - mean * mean, 0.0)
    rstd = 1 / tl.sqrt(var + eps)

    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    # The affine parameters depend only on the row, so load them once instead
    # of once per column chunk.
    cindex = xindex % C
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + cindex, xmask)
        b = tl.load(B + cindex, xmask)
    else:
        w = 1
        b = 0

    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (N * xindex)), y, rmask & xmask)

288 

289 

def instance_norm_use_running_stats_kernel_heur_tile_n(args):
    """Return the fixed TILE_N heuristic value (8192); ``args`` is not consulted."""
    tile_n = 8192
    return tile_n

292 

293 

@libentry()
# NOTE: autotuning over runtime.get_tuned_config("instancenorm") keyed on
# ["M", "N"] is commented out for this kernel.
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    running_mean_ptr,  # pointer to the mean
    running_var_ptr,  # pointer to the var
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,  # M = B * C
    N,
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Each program normalizes the row selected by its program id, using the
    # running mean / variance of that row's channel instead of statistics
    # computed from the input. The row is read as a single TILE_N-wide tile.
    row = tl.program_id(0)
    row_in_range = row < M
    channel = row % C

    cols = tl.arange(0, TILE_N)
    col_mask = cols < N
    row_base = row * N

    x = tl.load(in_ptr + row_base + cols, col_mask, other=0.0).to(tl.float32)
    mean = tl.load(running_mean_ptr + channel, mask=row_in_range).to(tl.float32)
    var = tl.load(running_var_ptr + channel, mask=row_in_range).to(tl.float32)
    rstd = tl.math.rsqrt(var + eps)

    # The statistics actually applied are also written out per row.
    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + channel, mask=row_in_range).to(tl.float32)
        b = tl.load(bias_ptr + channel, mask=row_in_range).to(tl.float32)
        out = (x - mean) * rstd * w + b
    else:
        out = (x - mean) * rstd

    tl.store(out_ptr + row_base + cols, out, mask=col_mask)

341 

342 

@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # pointer to the mean
    rstd_ptr,  # pointer to the 1/std
    running_mean_ptr,
    running_var_ptr,
    momentum,
    B,
    C,
    N,
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    # Blend the batch-averaged per-(batch, channel) statistics into the running
    # statistics in place:
    #   running = (1 - momentum) * running + momentum * batch_average
    # mean_ptr / rstd_ptr are indexed as [batch * C + channel]; rstd is taken
    # to be 1 / sqrt(biased_var + eps), which is how the forward kernels in
    # this module produce it. Each program covers BLOCK_CHANNEL_SIZE channels.
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # `b` already advances by BLOCK_BATCH_SIZE, so it is the first row of
        # this batch block; it must not be scaled by the block size again.
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask & col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # Masked lanes read rstd = 1 so that 1 / rstd^2 stays finite; their
        # contribution is zeroed below.
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=1.0).to(
            tl.float32
        )
        # Invert rstd = 1 / sqrt(var + eps)  =>  var = 1 / rstd^2 - eps.
        # Clamp at 0 to absorb rounding when the true variance is ~0.
        biased_var = tl.maximum(1 / (rstd * rstd) - eps, 0.0)
        # NOTE: use unbiased var to update running_var
        var = tl.where(mask, biased_var * N / (N - 1), 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)

386 

387 

def instance_norm_backward_kernel_heur_block_row_size(args):
    """Return the fixed BLOCK_ROW_SIZE heuristic value (1); ``args`` is not consulted."""
    block_row_size = 1
    return block_row_size

390 

391 

def instance_norm_backward_kernel_heur_block_col_size(args):
    """Return BLOCK_COL_SIZE: the next power of two of ``args["N"]``, capped at 8192.

    ``N`` defaults to 0 when it is absent from ``args``.
    """
    import builtins

    n_pow2 = triton.next_power_of_2(args.get("N", 0))
    return builtins.min(n_pow2, 8192)

396 

397 

@libentry()
# NOTE: autotuning over runtime.get_tuned_config("instance_norm_backward")
# keyed on ["M", "N", "C"] is commented out; the block sizes are supplied by
# the heuristics below.
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": instance_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": instance_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def instance_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    # Input gradient of instance norm over an (M, N) row-major view.
    # Each program handles BLOCK_ROW_SIZE rows in two sweeps over the columns:
    #   1. reduce sum(dx_hat) and sum(dx_hat * x_hat) per row,
    #   2. dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)
    # where x_hat = (x - mean) * rstd and dx_hat = dy * w.
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C
    row_mask = pid < M
    # Advance every pointer to the start of this program's rows.
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask, other=0.0).to(tl.float32)
    else:
        w = 1

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        # `other=0.0` is required: a masked load without it yields undefined
        # values, and dy on padding lanes would leak into the row reductions.
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]

    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols[None, :], dx, mask=mask)

469 

470 

def weight_bias_backward_kernel_heur_block_batch_size(args):
    """Return the fixed BLOCK_BATCH_SIZE heuristic value (1); ``args`` is not consulted."""
    block_batch_size = 1
    return block_batch_size

473 

474 

def weight_bias_backward_kernel_heur_block_col_size(args):
    """Return BLOCK_COL_SIZE: the next power of two of ``ceil(args["C"] / 12)``.

    ``C`` defaults to 1 when it is absent from ``args``.
    """
    cluster_num = 12
    channels_per_cluster = triton.cdiv(args.get("C", 1), cluster_num)
    return triton.next_power_of_2(channels_per_cluster)

477 

478 

@libentry()
# NOTE: autotuning over
# runtime.get_tuned_config("instance_norm_weight_bias_backward") keyed on
# ["N", "B", "C"] is commented out; the block sizes are supplied by the
# heuristics below.
@triton.heuristics(
    values={
        "BLOCK_BATCH_SIZE": weight_bias_backward_kernel_heur_block_batch_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,
    dB,
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Weight / bias gradients of instance norm. The program id is used as the
    # channel index; for that channel the kernel reduces over every batch
    # entry and every column:
    #   dW[c] = sum(dy * x_hat),  dB[c] = sum(dy),  x_hat = (x - mean) * rstd
    cid = tl.program_id(0)[:, None]
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # row index into the (B * C, N) view
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask, other=0.0).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask & col_mask
            # `other=0.0` is required: a masked load without it yields
            # undefined values, and dy on padding lanes would be summed into
            # both accumulators.
            dy = tl.load(dY + mid * N + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.load(X + mid * N + cols[None, :], mask, other=0.0).to(tl.float32)
            x = tl.where(mask, x - mean, 0.0)
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)

533 

534 

class InstanceNorm(torch.autograd.Function):
    """Autograd function wiring the Triton instance-norm kernels of this module."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,
    ):
        """Normalize ``x`` per (batch, channel) slice and return the result.

        ``x`` must be 3-, 4- or 5-dimensional with batch and channel as the
        two leading dimensions. The affine transform is applied only when both
        ``weight`` and ``bias`` are given. With ``use_input_stats`` the
        statistics come from ``x`` (and the running statistics, if provided,
        are updated in place); otherwise ``running_mean`` / ``running_var``
        are used. ``cudnn_enable`` is accepted but not used.
        """
        logger.debug("GEMS_KUNLUNXIN INSTANCE_NORM")
        assert (
            x.ndim in [3, 4, 5]
        ), f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[0], x.shape[1]
        N = math.prod(x.shape[2:])
        M = x.numel() // N

        x = x.contiguous()
        if weight is not None:
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y = torch.empty_like(x)

        # The affine path needs both parameters.
        has_weight_bias = not (weight is None or bias is None)
        has_running_stats = running_mean is not None

        if has_running_stats:
            assert (
                N > 1
            ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should not both be None"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        stats_dtype = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=stats_dtype, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=stats_dtype, device=x.device)

        with torch_device_fn.device(x.device):
            if use_input_stats:
                # Fixed launch of 12 programs; XBLOCK is sized so that they
                # cover all M rows between them.
                instancenorm_fwd_kernel_xpu[(12, 1, 1)](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                    RBLOCK=8192,
                    isCloseUnrollControl=True,
                    buffer_size_limit=512,
                )
                if has_running_stats:  # update running stats

                    def channel_grid(meta):
                        return (triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]), 1, 1)

                    update_running_stats_kernel[channel_grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                        isCloseCoreTiling=True,
                        isCloseVectorization=True,
                        isCloseUnrollControl=True,
                    )
            else:  # use running stats instead of input stats
                tile_n = triton.next_power_of_2(N)
                instance_norm_use_running_stats_kernel[(M, 1, 1)](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    tile_n,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    isCloseUnrollControl=True,
                )

        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        return y

    @staticmethod
    def backward(ctx, out_grad):
        """Return gradients for ``x``, ``weight`` and ``bias``.

        The remaining six forward arguments receive ``None``. The weight and
        bias gradients are ``None`` when the forward pass ran without the
        affine transform.
        """
        logger.debug("GEMS_KUNLUNXIN INSTANCE_NORM_BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M, N, C = ctx.M, ctx.N, ctx.C
        B = M // C

        weight_grad = None
        bias_grad = None
        with torch_device_fn.device(x.device):
            in_grad = torch.empty_like(x)

            def row_grid(meta):
                return (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)

            instance_norm_backward_kernel[row_grid](
                out_grad,
                x,
                weight,
                mean,
                rstd,
                in_grad,
                M,
                N,
                C,
                HAS_WEIGHT_BIAS=ctx.has_weight_bias,
                isCloseCoreTiling=True,
            )

            if ctx.has_weight_bias:
                # One program per channel.
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)
                weight_bias_backward_kernel[(C, 1, 1)](
                    out_grad,
                    x,
                    mean,
                    rstd,
                    weight_grad,
                    bias_grad,
                    M,
                    N,
                    B,
                    C,
                )
        return in_grad, weight_grad, bias_grad, None, None, None, None, None, None

707 

708 

def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Apply instance normalization to every channel of every sample in a batch.

    Args:
        input: input tensor of shape :math:`(N, C, *)`
        weight: optional weight tensor of shape :math:`(C)`
        bias: optional bias tensor of shape :math:`(C)`
        running_mean: optional running mean tensor of shape :math:`(C)`
        running_var: optional running variance tensor of shape :math:`(C)`
        use_input_stats: whether to normalize with the mean and variance of
            the input tensor rather than the running statistics
        momentum: momentum value for the running mean and variance
        eps: epsilon value for numerical stability
        cudnn_enable: accepted for signature compatibility; it is not
            forwarded to the underlying autograd function

    Returns:
        output tensor of shape :math:`(N, C, *)`
    """
    forward_args = (
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
    )
    return InstanceNorm.apply(*forward_args)