Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/instance_norm.py: 0%
343 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils.type_utils import get_accumulator_dtype
# Module logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Short alias for torch.Tensor, used in the type annotations below.
Tensor = torch.Tensor
@triton.jit
def prev_multiple_of(a, b):
    # For a > 0: the largest multiple of b that is strictly less than a,
    # i.e. the start offset of the last (possibly partial) b-wide tile over
    # [0, a). Same value as cdiv(a, b) * b - b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("instancenorm"),
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # per-row mean output, indexed by row id
    out_rstd_ptr,  # per-row 1/std output, indexed by row id
    M,  # number of rows, M = B * C
    N,  # number of elements per row
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Normalize one row of a row-major [M, N] input held in a single tile.

    Program ``tl.program_id(0)`` handles one row. There is no loop over N:
    the row is loaded as one TILE_N-wide tile, so only the first TILE_N
    elements are seen (callers need TILE_N >= N). Writes the row mean and
    1/sqrt(var + eps) (biased variance), then the normalized row, optionally
    scaled and shifted by the row's channel weight/bias.
    """
    row = tl.program_id(0)
    row_ok = row < M
    channel = row % C

    cols = tl.arange(0, TILE_N)
    col_ok = cols < N
    row_base = row * N

    vals = tl.load(in_ptr + row_base + cols, col_ok, other=0.0).to(tl.float32)
    mean = tl.sum(vals) / N
    centered = vals - mean
    # Out-of-range lanes loaded 0.0, so their deviation (-mean) is excluded.
    sq = tl.where(col_ok, centered * centered, 0)
    variance = tl.sum(sq) / N
    inv_std = tl.math.rsqrt(variance + eps)

    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, inv_std)

    normalized = centered * inv_std
    if HAS_WEIGHT_BIAS:
        gamma = tl.load(weight_ptr + channel, mask=row_ok)
        beta = tl.load(bias_ptr + channel, mask=row_ok)
        normalized = normalized * gamma + beta

    tl.store(out_ptr + row_base + cols, normalized, mask=col_ok)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # per-row mean output, indexed by row id
    out_rstd_ptr,  # per-row 1/std output, indexed by row id
    M,  # number of rows, M = B * C
    N,  # number of elements per row
    C,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Normalize TILE_M consecutive rows of a row-major [M, N] input.

    Each program loads a TILE_M x TILE_N tile. There is no loop over N, so
    each row must fit in TILE_N columns. Writes per-row mean and
    1/sqrt(var + eps) (biased variance), then the normalized rows, optionally
    scaled and shifted by each row's channel weight/bias.
    """
    first_row = tl.program_id(0) * TILE_M
    rows = first_row + tl.arange(0, TILE_M)
    row_ok = rows < M
    channels = rows % C

    cols = tl.arange(0, TILE_N)[None, :]
    tile_ok = row_ok[:, None] & (cols < N)
    offsets = rows[:, None] * N + cols

    vals = tl.load(in_ptr + offsets, tile_ok, other=0.0).to(tl.float32)
    mean = tl.sum(vals, axis=1) / N
    centered = vals - mean[:, None]
    # Out-of-range lanes loaded 0.0, so their deviation (-mean) is excluded.
    sq = tl.where(tile_ok, centered * centered, 0)
    variance = tl.sum(sq, axis=1) / N
    inv_std = tl.math.rsqrt(variance + eps)

    tl.store(out_mean_ptr + rows, mean, mask=row_ok)
    tl.store(out_rstd_ptr + rows, inv_std, mask=row_ok)

    normalized = centered * inv_std[:, None]
    if HAS_WEIGHT_BIAS:
        gamma = tl.load(weight_ptr + channels, mask=row_ok)
        beta = tl.load(bias_ptr + channels, mask=row_ok)
        normalized = normalized * gamma[:, None] + beta[:, None]

    tl.store(out_ptr + offsets, normalized, mask=tile_ok)
def instance_norm_loop_kernel_heur_tile_n(args):
    """Return the ``TILE_N`` heuristic for ``instance_norm_loop_kernel``.

    The launch arguments in ``args`` are ignored; the tile width is fixed.
    """
    fixed_tile_width = 8192
    return fixed_tile_width
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_loop"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "TILE_N": instance_norm_loop_kernel_heur_tile_n,
    },
)
@triton.jit(do_not_specialize=["eps"])
def instance_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the per-row mean (written)
    out_rstd_ptr,  # pointer to the per-row 1/std (written)
    M,  # M = B * C
    N,  # number of elements per row
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Normalize one row of a row-major [M, N] input in TILE_N-wide chunks.

    Intended for rows that do not fit in one tile. The first sweep keeps a
    per-lane Welford running mean / sum of squared deviations over the
    chunks, then merges the TILE_N lanes into the row mean and biased
    variance. The second sweep writes the normalized row, with the optional
    per-channel affine transform, visiting chunks from last to first.
    """
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_mask = pid < M
    c_offsets = pid % C  # channel index of this row

    # Per-lane Welford state: lane j accumulates elements j, j + TILE_N, ...
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # running mean per lane
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2) per lane
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)  # elements seen per lane
    num_steps = tl.cdiv(N, TILE_N)
    # All chunks but the last are full, so these loads need no mask.
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        # Welford update; each lane has seen exactly `step` elements so far.
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last (possibly partial) chunk: the masked load leaves out-of-range
    # lanes undefined, so tl.where keeps their previous state instead
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    # Merge the per-lane partial results (parallel variance combination):
    # mean = sum(cnt_j * m_j) / N,
    # total squared deviation = sum(s_j + cnt_j * (m_j - mean)^2);
    # dividing the latter by N gives the biased variance.
    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    if HAS_WEIGHT_BIAS:
        w = tl.load(weight_ptr + c_offsets, mask=m_mask)
        b = tl.load(bias_ptr + c_offsets, mask=m_mask)
    else:
        # identity affine transform
        w = 1
        b = 0

    # reverse the order of the second sweep (presumably so the chunks read
    # last in the first sweep are revisited while still cached; not verified)
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)  # start of the last chunk
    # the first step (the last, possibly partial chunk), masking is needed
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    # remaining chunks are all full; walk them from back to front
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)
@libentry()
@triton.jit(do_not_specialize=["eps"])
def instancenorm_fwd_kernel_xpu(
    X,
    Y,
    W,  # per-channel weight pointer
    B,  # per-channel bias pointer (not the batch size)
    MEAN,
    RSTRD,
    M: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    eps: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
    XBLOCK: tl.constexpr,
    RBLOCK: tl.constexpr,
):
    """Instance-norm forward from the input's own statistics (XPU variant).

    Treats X as a row-major [M, N] matrix (M = B * C instances of N elements).
    Each program owns XBLOCK consecutive rows and walks them twice in
    RBLOCK-wide chunks. The first pass accumulates sum(x) and sum(x * x) to
    get the per-row mean and biased variance, which are stored to MEAN and
    RSTRD (as 1/sqrt(var + eps)). The second pass writes
    Y = (X - mean) * rstd * w + b.
    """
    pid = tl.program_id(0)
    xoffset = pid * XBLOCK
    _xindex = xoffset + tl.arange(0, XBLOCK)
    xindex = _xindex[:, None]  # [XBLOCK, 1] row ids
    xmask = xindex < M
    rbase = tl.arange(0, RBLOCK)[None, :]  # [1, RBLOCK] column offsets
    # Element-wise partial sums, reduced across the RBLOCK axis after the loop.
    _mean = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    _var = tl.full([XBLOCK, RBLOCK], 0, tl.float32)

    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        # Masked lanes load 0.0, so they add nothing to either sum.
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        _mean = _mean + tl.broadcast_to(x, [XBLOCK, RBLOCK])
        _var = _var + tl.broadcast_to(x * x, [XBLOCK, RBLOCK])

    mean = tl.sum(_mean, 1)[:, None] / N  # E[x]
    var = tl.sum(_var, 1)[:, None] / N  # E[x^2]
    # E[x^2] - E[x]^2 can come out slightly negative from floating-point
    # cancellation when |mean| is large relative to the std. Clamp at 0 so
    # var_mean + eps stays positive and rstd never becomes NaN.
    var_mean = var - mean * mean
    var_mean = tl.where(var_mean > 0.0, var_mean, 0.0)
    rstd = 1 / tl.sqrt(var_mean + eps)

    tl.store(MEAN + xindex, mean, xmask)
    tl.store(RSTRD + xindex, rstd, xmask)

    # The per-channel affine parameters do not depend on the column chunk,
    # so load them once rather than once per loop iteration.
    cindex = xindex % C
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + cindex, xmask)
        b = tl.load(B + cindex, xmask)
    else:
        w = 1
        b = 0

    for roffset in range(0, N, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < N
        x = tl.load(X + (rindex + (N * xindex)), rmask & xmask, other=0.0).to(
            tl.float32
        )
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        tl.store(Y + (rindex + (N * xindex)), y, rmask & xmask)
def instance_norm_use_running_stats_kernel_heur_tile_n(args):
    """Return a fixed ``TILE_N`` heuristic for the running-stats kernel.

    The launch arguments in ``args`` are ignored.
    """
    fixed_tile_width = 8192
    return fixed_tile_width
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instancenorm"),
#     key=["M", "N"],
# )
@triton.jit(do_not_specialize=["eps"])
def instance_norm_use_running_stats_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    running_mean_ptr,  # per-channel running mean
    running_var_ptr,  # per-channel running variance
    out_mean_ptr,  # per-row mean output (copied from the running mean)
    out_rstd_ptr,  # per-row 1/std output
    M,  # number of rows, M = B * C
    N,  # number of elements per row
    C,
    eps,
    TILE_N: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Normalize one row using its channel's running mean / variance.

    Program ``tl.program_id(0)`` handles one row, loaded as a single
    TILE_N-wide tile (no loop over N). The channel statistics are also
    written to the per-row mean / rstd outputs.
    """
    row = tl.program_id(0)
    row_ok = row < M
    channel = row % C

    cols = tl.arange(0, TILE_N)
    col_ok = cols < N
    row_offsets = row * N + cols

    vals = tl.load(in_ptr + row_offsets, col_ok, other=0.0).to(tl.float32)
    mean = tl.load(running_mean_ptr + channel, mask=row_ok).to(tl.float32)
    variance = tl.load(running_var_ptr + channel, mask=row_ok).to(tl.float32)
    inv_std = tl.math.rsqrt(variance + eps)

    tl.store(out_mean_ptr + row, mean)
    tl.store(out_rstd_ptr + row, inv_std)

    normalized = (vals - mean) * inv_std
    if HAS_WEIGHT_BIAS:
        gamma = tl.load(weight_ptr + channel, mask=row_ok).to(tl.float32)
        beta = tl.load(bias_ptr + channel, mask=row_ok).to(tl.float32)
        normalized = normalized * gamma + beta

    tl.store(out_ptr + row_offsets, normalized, mask=col_ok)
@triton.jit
def update_running_stats_kernel(
    mean_ptr,  # per-instance mean, indexed as [B, C] row-major
    rstd_ptr,  # per-instance 1/sqrt(var + eps), indexed as [B, C] row-major
    running_mean_ptr,
    running_var_ptr,
    momentum,
    B,
    C,
    N,
    eps,
    BLOCK_BATCH_SIZE: tl.constexpr = 1,
    BLOCK_CHANNEL_SIZE: tl.constexpr = 2048,
):
    """Fold the batch's instance statistics into the running statistics.

    For each channel c handled by this program:
        running_mean[c] = (1 - momentum) * running_mean[c]
                          + momentum * mean_b(mean[b, c])
        running_var[c]  = (1 - momentum) * running_var[c]
                          + momentum * mean_b(unbiased_var[b, c])
    where the biased variance is recovered from rstd and rescaled by
    N / (N - 1). Requires N > 1.
    """
    cid = tl.program_id(0) * BLOCK_CHANNEL_SIZE + tl.arange(0, BLOCK_CHANNEL_SIZE)
    col_mask = cid < C
    running_mean = tl.load(running_mean_ptr + cid, mask=col_mask).to(tl.float32)
    running_var = tl.load(running_var_ptr + cid, mask=col_mask).to(tl.float32)

    new_mean = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    new_var = tl.zeros((BLOCK_CHANNEL_SIZE,), dtype=tl.float32)
    for b in range(0, B, BLOCK_BATCH_SIZE):
        # `b` already advances in steps of BLOCK_BATCH_SIZE, so it is the
        # first batch index of this block (do not scale it again).
        bid = b + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        row_mask = bid < B
        mask = row_mask and col_mask[None, :]
        mean = tl.load(mean_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        rstd = tl.load(rstd_ptr + bid * C + cid[None, :], mask=mask, other=0.0).to(
            tl.float32
        )
        # rstd = 1 / sqrt(var + eps)  =>  var = 1 / rstd^2 - eps.
        # NOTE: use unbiased var (x N / (N - 1)) to update running_var.
        var = (1 / (rstd * rstd) - eps) * N / (N - 1)
        # Masked lanes loaded rstd = 0 (-> inf); keep them out of the sum.
        var = tl.where(mask, var, 0.0)

        new_mean += tl.sum(mean, axis=0)
        new_var += tl.sum(var, axis=0)

    new_running_mean = (1 - momentum) * running_mean + momentum * new_mean / B
    new_running_var = (1 - momentum) * running_var + momentum * new_var / B

    tl.store(running_mean_ptr + cid, new_running_mean, mask=col_mask)
    tl.store(running_var_ptr + cid, new_running_var, mask=col_mask)
def instance_norm_backward_kernel_heur_block_row_size(args):
    """Return the ``BLOCK_ROW_SIZE`` heuristic: one row per program.

    The launch arguments in ``args`` are ignored.
    """
    rows_per_program = 1
    return rows_per_program
def instance_norm_backward_kernel_heur_block_col_size(args):
    """Return the ``BLOCK_COL_SIZE`` heuristic for the input-gradient kernel.

    Uses the next power of two of ``args["N"]`` (0 when absent), capped at
    8192.
    """
    import builtins

    row_len = args.get("N", 0)
    return builtins.min(8192, triton.next_power_of_2(row_len))
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_backward"),
#     key=["M", "N", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": instance_norm_backward_kernel_heur_block_row_size,
        "BLOCK_COL_SIZE": instance_norm_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def instance_norm_backward_kernel(
    dY,
    X,
    W,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dX,
    M,  # M = B * C
    N,
    C,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
    HAS_WEIGHT_BIAS: tl.constexpr,
):
    """Input gradient of instance norm from the saved per-row statistics.

    For each row of N elements, with x_hat = (x - mean) * rstd and
    dx_hat = dy * w:
        dx = rstd * (dx_hat - (sum(dx_hat) + x_hat * sum(dx_hat * x_hat)) / N)
    The first sweep computes both row sums; the second writes dx.
    """
    pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    c_offsets = pid % C
    row_mask = pid < M
    dY += pid * N
    X += pid * N
    dX += pid * N
    Mean += pid
    Rstd += pid

    mean = tl.load(Mean, mask=row_mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd, mask=row_mask, other=1.0).to(tl.float32)
    if HAS_WEIGHT_BIAS:
        w = tl.load(W + c_offsets, mask=row_mask).to(tl.float32)
    else:
        w = 1

    dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        # other=0.0: masked lanes are summed into dx_part2 / dx_part3 below,
        # so they must read as zero rather than an undefined value.
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx_part2 += dx_hat
        dx_part3 += dx_hat * x_hat

    dx_2 = tl.sum(dx_part2, axis=1)[:, None]  # sum(dx_hat) per row
    dx_3 = tl.sum(dx_part3, axis=1)[:, None]  # sum(dx_hat * x_hat) per row

    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)
        col_mask = cols[None, :] < N
        mask = row_mask and col_mask
        dy = tl.load(dY + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols[None, :], mask, other=0.0).to(tl.float32)
        x = tl.where(mask, x - mean, 0.0)
        x_hat = x * rstd
        dx_hat = dy * w
        dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)
        tl.store(dX + cols[None, :], dx, mask=mask)
def weight_bias_backward_kernel_heur_block_batch_size(args):
    """Return the ``BLOCK_BATCH_SIZE`` heuristic: one batch entry per step.

    The launch arguments in ``args`` are ignored.
    """
    batches_per_step = 1
    return batches_per_step
def weight_bias_backward_kernel_heur_block_col_size(args):
    """Return the ``BLOCK_COL_SIZE`` heuristic for the weight/bias gradient.

    Computes next_power_of_2(cdiv(C, 12)), where ``C`` defaults to 1.
    """
    # NOTE(review): 12 is labelled "cluster_num" in the original; it looks like
    # the XPU cluster count (the forward also launches 12 programs). Confirm.
    cluster_num = 12
    channels = args.get("C", 1)
    channels_per_cluster = triton.cdiv(channels, cluster_num)
    return triton.next_power_of_2(channels_per_cluster)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
#     key=["N", "B", "C"],
# )
@triton.heuristics(
    values={
        "BLOCK_BATCH_SIZE": weight_bias_backward_kernel_heur_block_batch_size,
        "BLOCK_COL_SIZE": weight_bias_backward_kernel_heur_block_col_size,
    },
)
@triton.jit
def weight_bias_backward_kernel(
    dY,
    X,
    Mean,  # [B, C]
    Rstd,  # [B, C]
    dW,
    dB,
    M,
    N,
    B,
    C,
    BLOCK_BATCH_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Per-channel gradients of the affine weight and bias.

    Program ``c`` reduces over every batch index b and every element of row
    ``b * C + c``:
        dW[c] = sum(dy * x_hat),   dB[c] = sum(dy)
    with x_hat = (x - mean) * rstd taken from the saved statistics.
    """
    cid = tl.program_id(0)[:, None]
    dW += cid
    dB += cid
    c_mask = cid < C

    accW = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    accB = tl.zeros([BLOCK_BATCH_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)

    for b_off in range(0, B, BLOCK_BATCH_SIZE):
        bid = b_off + tl.arange(0, BLOCK_BATCH_SIZE)[:, None]
        mid = bid * C + cid  # row index into the [B * C, N] view
        row_mask = bid < B
        mean = tl.load(Mean + mid, mask=row_mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + mid, mask=row_mask, other=0.0).to(tl.float32)
        for off in range(0, N, BLOCK_COL_SIZE):
            cols = off + tl.arange(0, BLOCK_COL_SIZE)
            col_mask = cols[None, :] < N
            mask = row_mask and col_mask
            # other=0.0: masked lanes are accumulated into accW / accB below,
            # so they must read as zero rather than an undefined value.
            dy = tl.load(dY + mid * N + cols[None, :], mask, other=0.0).to(
                tl.float32
            )
            x = tl.load(X + mid * N + cols[None, :], mask, other=0.0).to(
                tl.float32
            )
            x = tl.where(mask, x - mean, 0.0)
            x_hat = x * rstd
            accW += dy * x_hat
            accB += dy
    dw = tl.sum(accW)
    db = tl.sum(accB)
    tl.store(dW, dw, mask=c_mask)
    tl.store(dB, db, mask=c_mask)
class InstanceNorm(torch.autograd.Function):
    """Instance normalization as a ``torch.autograd.Function`` (Kunlunxin).

    The input is viewed as ``M = B * C`` rows of ``N = prod(x.shape[2:])``
    elements, and each row is normalized independently. The per-row mean and
    1/std (accumulator precision for half inputs) are saved for backward.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        use_input_stats=False,
        momentum=0.1,
        eps=1e-05,
        cudnn_enable=False,
    ):
        """Normalize ``x`` per (batch, channel) instance.

        Args:
            x: input of rank 3, 4 or 5 laid out as ``[B, C, *spatial]``.
            weight, bias: optional per-channel affine parameters of shape
                ``[C]``. If only one is given, the other is filled with the
                identity (ones for weight, zeros for bias).
            running_mean, running_var: optional per-channel running
                statistics of shape ``[C]``. When ``use_input_stats`` is true
                they are updated in place using ``momentum``; otherwise they
                are used to normalize.
            use_input_stats: normalize with the input's own statistics.
            momentum: running-statistics update factor.
            eps: value added to the variance for numerical stability.
            cudnn_enable: unused; kept for signature compatibility.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        logger.debug("GEMS_KUNLUNXIN INSTANCENORM FORWARD")
        assert len(x.shape) in [
            3,
            4,
            5,
        ], f"x.shape should be [B, C, N] or [B, C, H, W] or [B, C, H, W, L], but got {x.shape}"
        B, C = x.shape[:2]
        N = math.prod(x.shape[2:])
        M = x.numel() // N

        x = x.contiguous()
        # The kernels only support "both affine params" or "neither". When just
        # one is supplied, fill in the identity for the other so the supplied
        # one is still applied. Its gradient is dropped again in backward.
        weight_given = weight is not None
        bias_given = bias is not None
        if weight_given and not bias_given:
            bias = torch.zeros_like(weight)
        elif bias_given and not weight_given:
            weight = torch.ones_like(bias)
        weight = weight.contiguous() if weight is not None else None
        bias = bias.contiguous() if bias is not None else None
        y = torch.empty_like(x)

        has_weight_bias = weight is not None and bias is not None

        has_running_stats = running_mean is not None
        if has_running_stats:
            if use_input_stats:
                # Only needed when updating running stats from the input: the
                # unbiased-variance update divides by N - 1. PyTorch likewise
                # only rejects a single spatial element when training.
                assert (
                    N > 1
                ), f"Expected more than 1 spatial element when training, got input size {x.shape}"
            assert (
                running_mean is not None and running_var is not None
            ), "running_mean and running_var should both be defined"
            assert (
                running_mean.shape == running_var.shape and running_mean.shape[0] == C
            ), f"running_mean and running_var should have shape as {[C,]}"
            assert (
                running_mean.dtype == running_var.dtype
            ), "running_mean and running_var should have the same dtype"
        if not use_input_stats:
            assert (
                has_running_stats
            ), "Expected running_mean and running_var to be defined when use_input_stats is False"

        # NOTE: when the input is half-precision(either float16 or bfloat16)
        # these statistical data saved for backward is in single precision
        acc_type = get_accumulator_dtype(x.dtype)
        mean = torch.empty(size=(B, C), dtype=acc_type, device=x.device)
        rstd = torch.empty(size=(B, C), dtype=acc_type, device=x.device)

        with torch_device_fn.device(x.device):
            if use_input_stats:
                grid = (12, 1, 1)
                instancenorm_fwd_kernel_xpu[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    XBLOCK=triton.next_power_of_2(triton.cdiv(M, 12)),
                    RBLOCK=8192,
                    isCloseUnrollControl=True,
                    buffer_size_limit=512,
                )
                if has_running_stats and use_input_stats:  # update running stats
                    grid = lambda meta: (
                        triton.cdiv(C, meta["BLOCK_CHANNEL_SIZE"]),
                        1,
                        1,
                    )
                    update_running_stats_kernel[grid](
                        mean,
                        rstd,
                        running_mean,
                        running_var,
                        momentum,
                        B,
                        C,
                        N,
                        eps,
                        isCloseCoreTiling=True,
                        isCloseVectorization=True,
                        isCloseUnrollControl=True,
                    )
            else:  # use running stats instead of input stats
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                instance_norm_use_running_stats_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    mean,
                    rstd,
                    M,
                    N,
                    C,
                    eps,
                    TILE_N,
                    HAS_WEIGHT_BIAS=has_weight_bias,
                    isCloseUnrollControl=True,
                )

        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
        ctx.C = C
        ctx.has_weight_bias = has_weight_bias
        ctx.use_input_stats = use_input_stats
        ctx.weight_given = weight_given
        ctx.bias_given = bias_given
        return y

    @staticmethod
    def backward(ctx, out_grad):
        """Return gradients for (x, weight, bias); all other inputs get None.

        When the forward used the input statistics, the full instance-norm
        input gradient is computed. When it used running statistics, those
        are constants with respect to ``x``, so ``dx = dy * w * rstd``.
        """
        logger.debug("GEMS_KUNLUNXIN INSTANCENORM BACKWARD")
        out_grad = out_grad.contiguous()
        x, weight, mean, rstd = ctx.saved_tensors
        M = ctx.M
        N = ctx.N
        C = ctx.C
        B = M // C

        with torch_device_fn.device(x.device):
            if ctx.use_input_stats:
                in_grad = torch.empty_like(x)
                grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1)

                instance_norm_backward_kernel[grid](
                    out_grad,
                    x,
                    weight,
                    mean,
                    rstd,
                    in_grad,
                    M,
                    N,
                    C,
                    HAS_WEIGHT_BIAS=ctx.has_weight_bias,
                    isCloseCoreTiling=True,
                )
            else:
                # rstd holds rsqrt(running_var + eps) per (b, c) instance.
                scale = rstd
                if ctx.has_weight_bias:
                    scale = scale * weight.to(scale.dtype)  # broadcast over B
                scale = scale.view(B, C, *([1] * (x.dim() - 2)))
                in_grad = (out_grad.to(scale.dtype) * scale).to(x.dtype)

            if ctx.has_weight_bias:
                grid = lambda meta: (C, 1, 1)
                weight_grad = torch.empty_like(weight)
                bias_grad = torch.empty_like(weight)
                weight_bias_backward_kernel[grid](
                    out_grad,
                    x,
                    mean,
                    rstd,
                    weight_grad,
                    bias_grad,
                    M,
                    N,
                    B,
                    C,
                )
            else:
                weight_grad = None
                bias_grad = None

        # A parameter that was filled in by forward (not passed by the caller)
        # must get a None gradient.
        return (
            in_grad,
            weight_grad if ctx.weight_given else None,
            bias_grad if ctx.bias_given else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def instance_norm(
    input: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    cudnn_enable: bool = False,
) -> Tensor:
    r"""Applies Instance Normalization for each channel in each data sample in a
    batch.

    Inputs:
        input: input tensor of shape :math:`(N, C, *)` with one to three
            trailing spatial dimensions (3-D, 4-D or 5-D in total)
        weight: optional weight tensor of shape :math:`(C)`
        bias: optional bias tensor of shape :math:`(C)`
        running_mean: running mean tensor of shape :math:`(C)`
        running_var: running variance tensor of shape :math:`(C)`
        use_input_stats: whether to use the mean and variance of the input tensor
        momentum: momentum value for the running mean and variance
        eps: epsilon value for numerical stability
        cudnn_enable: accepted for signature compatibility; ignored by this
            backend
    Returns:
        output tensor of shape :math:`(N, C, *)`
    """
    # Pass every argument (including the ignored cudnn_enable) so the number
    # of forward inputs matches the 9 gradients returned by backward.
    return InstanceNorm.apply(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
        cudnn_enable,
    )