Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/batch_norm.py: 0%
172 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, tl_extra_shim
# Reciprocal square root taken from the backend shim; used by the kernels below.
rsqrt = tl_extra_shim.rsqrt

# Logger nested under the "flag_gems" hierarchy and named after this module.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
def make_3d_for_bn(input: Tensor) -> Tensor:
    """Reshape ``input`` into a 3-D (batch, feature, spatial) view.

    * 2-D tensors gain a trailing axis of size 1.
    * Tensors with 4 or more dims have every axis after the second flattened
      into one.
    * 3-D (and lower-rank) tensors are returned unchanged.
    """
    rank = input.ndim
    if rank == 2:
        # (N, C) -> (N, C, 1)
        return input.unsqueeze(-1)
    if rank > 3:
        # (N, C, d1, d2, ...) -> (N, C, d1 * d2 * ...)
        return input.flatten(2, -1)
    return input
@libentry()
@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Normalize one feature channel of a (batch, feat, spatial) tensor.

    One program handles the feature index ``program_id(0)``. It walks that
    feature's whole ``batch_dim x spatial_dim`` plane in
    ``BLOCK_M x BLOCK_N`` tiles.

    Training mode (``is_train``):

    * The per-feature mean and inverse standard deviation are computed from
      the data with two passes: a sum, then a sum of squared deviations.
    * They are written to ``mean_pointer`` and ``inv_std_pointer``.
    * The running statistics are updated in place with ``momentum``.

    Evaluation mode reads the running statistics instead. Only the output is
    written.

    ``weight_pointer`` and ``bias_pointer`` may be None. In that case a scale
    of 1 and a shift of 0 are used.
    """
    feat_pid = tl.program_id(axis=0)

    if is_train:
        # Two-pass algorithm: first the mean, then the variance around it
        # (avoids the cancellation of a single-pass E[x^2] - E[x]^2).
        # Pass 1: elementwise partial sums, reduced once after the loops.
        total_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                # Out-of-range lanes load 0 and so do not change the sum.
                curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                    tl.float32
                )
                total_sum += curr_input

        n_elements = batch_dim * spatial_dim
        mean = tl.sum(total_sum) / n_elements

        # Pass 2: sum of squared deviations from the mean.
        var_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                    tl.float32
                )
                # Masked lanes loaded 0 and would otherwise contribute mean**2.
                diff = tl.where(mask, curr_input - mean, 0.0)
                var_sum += diff * diff

        # Biased (population) variance is used for normalization.
        var = tl.sum(var_sum) / n_elements
        inv_std = rsqrt(var + eps)

        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        # Exponential moving average of the statistics. The running variance
        # receives the unbiased estimate var * n / (n - 1).
        # NOTE(review): n_elements == 1 divides by zero here; presumably
        # callers reject single-element channels in training — confirm.
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var
            + momentum * var * n_elements / (n_elements - 1),
        )

    else:
        # Evaluation: use the stored running statistics. Upcast to float32
        # first so reduced-precision running stats are not fed to rsqrt and
        # the normalization in their storage dtype.
        mean = tl.load(feat_pid + running_mean_pointer).to(tl.float32)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer).to(tl.float32) + eps)

    # A None weight/bias pointer selects the identity affine transform.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim
            mask = batch_mask[:, None] & spatial_mask[None, :]

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            # Masked-out lanes are never stored, so no `other` value is needed.
            curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(curr_output_pointer, output, mask=mask)
def batch_norm_heur_block_m(args):
    """Pick BLOCK_M from the launch arguments ``args``.

    The result is the batch size rounded up to a power of two, capped at 64.
    The batch size is clamped to at least 1 first. ``triton.next_power_of_2(0)``
    returns 0, which would give a zero-sized tile (and a division by zero in
    the BLOCK_N heuristic) for an empty batch or a missing ``batch_dim`` key.
    """
    batch_dim = max(1, args.get("batch_dim", 0))
    return min(64, triton.next_power_of_2(batch_dim))
def batch_norm_heur_block_n(args):
    """Pick BLOCK_N from the launch arguments ``args``.

    The result is the spatial size rounded up to a power of two. It is limited
    so that a BLOCK_M x BLOCK_N tile holds at most 2**14 elements.

    Both the spatial size and BLOCK_M are clamped to at least 1. A zero value
    would otherwise give a zero-sized tile or a ``2**14 // 0`` ZeroDivisionError.
    """
    block_m = max(1, batch_norm_heur_block_m(args))
    block_n = triton.next_power_of_2(max(1, args.get("spatial_dim", 0)))
    return min(block_n, max(1, 2**14 // block_m))
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
)
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute batch-norm gradients for one feature channel.

    One program handles the feature index ``program_id(0)``. Its saved mean and
    inverse std are read from ``mean_pointer`` and ``inv_std_pointer``.
    With ``x_hat = (x - mean) * inv_std`` and ``N = batch_dim * spatial_dim``:

    * weight grad = sum(x_hat * dy), stored when ``weight_grad_mask``.
    * bias grad   = sum(dy), stored when ``bias_grad_mask``.
    * input grad  = inv_std * w * (dy - (x_hat * sum(x_hat * dy) + sum(dy)) / N),
      stored when ``input_grad_mask``. ``w`` is 1 when ``weight_pointer`` is None.
    """
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    # Elementwise partial sums of x_hat * dy (term1) and dy (term2).
    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
        batch_mask = batch_offset < batch_dim

        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0).to(tl.float32)

            curr_pre_lin = ((curr_input - mean) * inv_std).to(tl.float32)
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            # Masked lanes load dy == 0, so they add nothing to either sum
            # even though curr_pre_lin is nonzero there.
            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    if not input_grad_mask:
        return

    # A None weight pointer means no affine scale. The loaded value is
    # already float32, and the literal 1.0 is not a tensor (it has no
    # `.to`), so no further cast is applied afterwards.
    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

    count = batch_dim * spatial_dim

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim
            mask = batch_mask[:, None] & spatial_mask[None, :]

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            # Masked-out lanes are never stored, so no `other` value is needed.
            curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(curr_output_grad_pointer, mask=mask).to(
                tl.float32
            )
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            tl.store(curr_input_grad_pointer, curr_input_grad, mask=mask)
def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training=False,
    momentum=0.1,
    eps=1e-05,
):
    """Batch normalization over the channel axis (dim 1) of ``input``.

    Args:
        input: tensor laid out as (N, C, ...). Assumes rank >= 2; rank-1
            input is not reshaped by ``make_3d_for_bn`` and fails to unpack.
        weight, bias: optional per-channel affine parameters, or None.
        running_mean, running_var: per-channel running statistics. They are
            updated in place when ``training`` and required otherwise.
        training: if True, normalize with batch statistics and update the
            running statistics; otherwise normalize with the running ones.
        momentum: moving-average factor for the running-statistic update.
        eps: added to the variance before taking the reciprocal sqrt.

    Returns:
        ``(output, mean, inv_std)``. ``output`` has ``input``'s shape.
        ``mean`` and ``inv_std`` are per-channel tensors of ``input.dtype``.
        They are filled only in training mode; otherwise they are left
        uninitialized.

    Raises:
        RuntimeError: ``training`` is False and a running statistic is None.
    """
    logger.debug("GEMS_KUNLUNXIN BATCH_NORM")

    # (N, C, *) -> (N, C, K) -> (N*K, C) -> (N*K, C, 1). Every non-channel
    # position becomes a batch row, so each kernel program reduces one
    # channel over a single long batch axis with unit spatial extent.
    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape
    output = torch.empty_like(input_3d)

    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    if running_mean is None or running_var is None:
        if not training:
            raise RuntimeError(
                "batch_norm: running_mean and running_var must be provided "
                "when training=False"
            )
        # In training mode the kernel always reads and updates the running
        # statistics. Give it throwaway buffers rather than aliasing `input`,
        # which it would otherwise overwrite in place.
        if running_mean is None:
            running_mean = torch.zeros(
                feat_dim, device=input.device, dtype=torch.float32
            )
        if running_var is None:
            running_var = torch.ones(
                feat_dim, device=input.device, dtype=torch.float32
            )

    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            buffer_size_limit=2048,
        )

    # Undo the flattening: (N*K, C, 1) -> (N, K, C) -> (N, C, K) -> input shape.
    output_reshaped = output.reshape(m, k, n).permute(0, 2, 1)
    return output_reshaped.view_as(input), mean, inv_std
def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Gradients of ``batch_norm`` w.r.t. input, weight and bias.

    Args:
        grad_out: gradient of the forward output, same shape as ``input``.
        input: the forward input, laid out as (N, C, ...).
        weight: optional per-channel scale, or None (treated as 1).
        save_mean, save_invstd: per-channel mean and inverse std read by the
            kernel.
        output_mask: three flags selecting (input, weight, bias) gradients.
            None requests all three.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. An entry is None when its
        ``output_mask`` flag is False. Weight and bias grads use
        ``input.dtype``.

    NOTE(review): ``running_mean``, ``running_var``, ``train`` and ``eps`` are
    not read here. The kernel always applies the training-mode formula with
    ``save_mean``/``save_invstd``. Confirm this is intended for eval-mode
    backward.
    """
    logger.debug("GEMS_KUNLUNXIN BATCH_NORM_BACKWARD")

    if output_mask is None:
        output_mask = (True, True, True)
    need_input_grad, need_weight_grad, need_bias_grad = (
        bool(flag) for flag in output_mask
    )

    # Same (N, C, *) -> (N*K, C, 1) flattening as the forward pass.
    input_3d_i = make_3d_for_bn(input)
    m, n, k = input_3d_i.shape
    input_3d_f = input_3d_i.permute(0, 2, 1).reshape(-1, n)
    input_3d = make_3d_for_bn(input_3d_f)

    output_grad_3d_i = make_3d_for_bn(grad_out)
    output_grad_3d_f = output_grad_3d_i.permute(0, 2, 1).reshape(-1, n)
    output_grad_3d = make_3d_for_bn(output_grad_3d_f)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    input_grad = torch.empty_like(input_3d) if need_input_grad else None
    weight_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if need_weight_grad
        else None
    )
    bias_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if need_bias_grad
        else None
    )

    # When grad_input is not requested, the kernel returns before touching
    # input_grad_pointer. input_3d's strides are then only placeholders, since
    # a None tensor has no .stride().
    input_grad_strides = (
        input_grad.stride() if input_grad is not None else input_3d.stride()
    )

    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim, 1, 1)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            need_input_grad,
            need_weight_grad,
            need_bias_grad,
            buffer_size_limit=2048,
        )

    if input_grad is not None:
        # Undo the flattening back to the original input shape.
        input_grad = input_grad.reshape(m, k, n).permute(0, 2, 1).view_as(input)
    return input_grad, weight_grad, bias_grad