Coverage for src/flag_gems/runtime/backend/_ascend/ops/batch_norm.py: 0%
159 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, tl_extra_shim
# Module-level logger; the host wrappers below emit one debug record per call.
logger = logging.getLogger(__name__)
# Short alias for the shim's rsqrt, called inside the Triton kernels as
# rsqrt(var + eps) (presumably the reciprocal square root — confirm against
# flag_gems.utils.tl_extra_shim).
rsqrt = tl_extra_shim.rsqrt
def make_3d_for_bn(input: Tensor) -> Tensor:
    """Return ``input`` reshaped to three dimensions for the batch-norm kernels.

    * 2-D input gets a trailing singleton dimension appended.
    * Input with four or more dimensions has every dimension from index 2
      onward flattened into one.
    * Any other rank (e.g. 3-D) is returned unchanged.
    """
    rank = input.ndim
    if rank == 2:
        return input.unsqueeze(-1)
    if rank >= 4:
        return input.flatten(2, -1)
    return input
def _block_m(batch_dim):
    """Pick the tile height (BLOCK_M) used along the batch dimension.

    The result is the next power of two >= ``batch_dim``, capped at 64.

    ``batch_dim`` is clamped to at least 1 before rounding: for 0,
    ``triton.next_power_of_2`` yields 0, which is not a usable tile size and
    would make ``2**10 // BLOCK_M`` in ``_block_n`` divide by zero.
    """
    return min(64, triton.next_power_of_2(max(1, batch_dim)))
def _block_n(batch_dim, spatial_dim):
    """Pick the tile width (BLOCK_N) used along the spatial dimension.

    The width is the next power of two >= ``spatial_dim``, limited so that
    BLOCK_M * BLOCK_N does not exceed 2**10 elements per tile.

    Both extents are clamped to at least 1 so that an empty batch or empty
    spatial dimension yields a valid tile size of 1 instead of a zero-sized
    tile or a ZeroDivisionError.
    """
    # Guard locally as well, so this helper never divides by zero regardless
    # of what _block_m returns for degenerate sizes.
    block_m = max(1, _block_m(batch_dim))
    block_n = triton.next_power_of_2(max(1, spatial_dim))
    return min(block_n, max(1, 2**10 // block_m))
@libentry()
@triton.jit
def batch_norm_forward_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    mean_pointer,
    inv_std_pointer,
    output_pointer,
    running_mean_pointer,
    running_var_pointer,
    batch_dim,
    spatial_dim,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    output_batch_stride,
    output_feat_stride,
    output_spatial_stride,
    momentum,
    eps,
    is_train: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batch-normalise one feature per program instance.

    The program index along axis 0 selects the feature. The feature's
    (batch, spatial) plane is visited in BLOCK_M x BLOCK_N tiles, addressed
    through the batch/feat/spatial strides passed in.

    Training (``is_train``):
        * Each tile lane keeps its own running mean, sum of squared deviations
          and sample count (Welford update). The lanes are then merged into the
          feature mean and the biased variance.
        * The mean and ``rsqrt(var + eps)`` are stored at ``mean_pointer`` and
          ``inv_std_pointer``.
        * ``running_mean_pointer`` / ``running_var_pointer`` are read and
          overwritten in place with the momentum blend; the running variance
          uses the ``n / (n - 1)`` correction. The update is unconditional, so
          both pointers must reference writable per-feature buffers.
          NOTE(review): n == 1 divides by zero here — callers should reject it.

    Inference:
        The mean is loaded from ``running_mean_pointer`` and the inverse
        standard deviation is ``rsqrt(running_var + eps)``.

    In both modes the output is ``weight * (x - mean) * inv_std + bias``, with
    weight defaulting to 1 and bias to 0 when the respective pointer is falsy.
    """
    feat_pid = tl.program_id(axis=0)

    if is_train:
        # Per-lane Welford accumulators; `var` holds the sum of squared
        # deviations (M2) until the final reduction.
        mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

        m_num_steps = tl.cdiv(batch_dim, BLOCK_M)
        n_num_steps = tl.cdiv(spatial_dim, BLOCK_N)

        for m_step in range(0, m_num_steps):
            for n_step in range(0, n_num_steps):
                spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
                spatial_mask = spatial_offset < spatial_dim

                batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
                batch_mask = batch_offset < batch_dim

                curr_input_pointer = (
                    input_pointer
                    + input_feat_stride * feat_pid
                    + input_batch_stride * batch_offset[:, None]
                    + input_spatial_stride * spatial_offset[None, :]
                )

                mask = batch_mask[:, None] & spatial_mask[None, :]
                curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)

                # The Welford divisor must be the number of samples THIS lane
                # has seen, not the global tile index: a lane masked off in an
                # earlier ragged tile has seen fewer samples than tiles visited.
                cnt += mask.to(tl.int32)
                # Inactive lanes may still have cnt == 0; give them a harmless
                # divisor (their result is discarded by the tl.where below).
                lane_count = tl.where(mask, cnt, 1).to(tl.float32)
                new_mean = tl.where(
                    mask, mean + (curr_input - mean) / lane_count, mean
                )
                new_var = tl.where(
                    mask, var + (curr_input - new_mean) * (curr_input - mean), var
                )
                mean = new_mean
                var = new_var

        # Merge the per-lane partial statistics into one mean / biased variance.
        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
            batch_dim * spatial_dim
        )
        inv_std = rsqrt(var + eps)
        mean = final_mean

        tl.store(feat_pid + mean_pointer, mean)
        tl.store(feat_pid + inv_std_pointer, inv_std)

        running_mean_pointer += feat_pid
        running_var_pointer += feat_pid

        running_mean = tl.load(running_mean_pointer)
        running_var = tl.load(running_var_pointer)

        n = batch_dim * spatial_dim
        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
        tl.store(
            running_var_pointer,
            (1 - momentum) * running_var + momentum * var * n / (n - 1),
        )

    else:
        mean = tl.load(feat_pid + running_mean_pointer)
        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0
    if bias_pointer:
        bias = tl.load(feat_pid + bias_pointer).to(tl.float32)
    else:
        bias = 0.0

    # Second pass: normalise every element of this feature and write it out.
    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            tile_mask = batch_mask[:, None] & spatial_mask[None, :]

            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_output_pointer = (
                output_pointer
                + output_feat_stride * feat_pid
                + output_batch_stride * batch_offset[:, None]
                + output_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(curr_input_pointer, mask=tile_mask).to(tl.float32)
            output = weight * (curr_input - mean) * inv_std + bias

            tl.store(curr_output_pointer, output, mask=tile_mask)
@libentry()
@triton.jit
def batch_norm_backward_kernel(
    output_grad_pointer,
    input_pointer,
    mean_pointer,
    inv_std_pointer,
    weight_pointer,
    input_grad_pointer,
    weight_grad_pointer,
    bias_grad_pointer,
    batch_dim,
    spatial_dim,
    output_grad_batch_stride,
    output_grad_feat_stride,
    output_grad_spatial_stride,
    input_batch_stride,
    input_feat_stride,
    input_spatial_stride,
    input_grad_batch_stride,
    input_grad_feat_stride,
    input_grad_spatial_stride,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute batch-norm gradients for one feature per program instance.

    The program index along axis 0 selects the feature. Using the saved
    per-feature mean and inverse standard deviation, a first pass over the
    feature's (batch, spatial) plane accumulates

        term1 = sum(x_hat * dy)    (stored as the weight gradient)
        term2 = sum(dy)            (stored as the bias gradient)

    where ``x_hat = (x - mean) * inv_std``. When ``input_grad_mask`` is set, a
    second pass writes

        dx = inv_std * weight * (dy - (term1 * x_hat + term2) / count)

    with ``count = batch_dim * spatial_dim`` and weight defaulting to 1 when
    ``weight_pointer`` is falsy. ``weight_grad_mask`` / ``bias_grad_mask`` gate
    the respective stores, so the matching pointers are only dereferenced when
    their flag is true.
    """
    feat_pid = tl.program_id(axis=0)

    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)

    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )

            mask = batch_mask[:, None] & spatial_mask[None, :]
            # Masked-off lanes are accumulated unconditionally below, so they
            # must load as exactly zero; without `other` their value is
            # undefined and could pollute both sums.
            curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                tl.float32
            )

            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)

            # For padded lanes dy == 0, so both contributions vanish.
            term1 += curr_pre_lin * curr_output_grad
            term2 += curr_output_grad

    term1 = tl.sum(term1)
    term2 = tl.sum(term2)

    if weight_grad_mask:
        tl.store(feat_pid + weight_grad_pointer, term1)
    if bias_grad_mask:
        tl.store(feat_pid + bias_grad_pointer, term2)

    if not input_grad_mask:
        return

    if weight_pointer:
        weight = tl.load(feat_pid + weight_pointer).to(tl.float32)
    else:
        weight = 1.0

    count = batch_dim * spatial_dim

    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
            batch_mask = batch_offset < batch_dim

            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
            spatial_mask = spatial_offset < spatial_dim

            mask = batch_mask[:, None] & spatial_mask[None, :]

            curr_output_grad_pointer = (
                output_grad_pointer
                + output_grad_feat_stride * feat_pid
                + output_grad_batch_stride * batch_offset[:, None]
                + output_grad_spatial_stride * spatial_offset[None, :]
            )
            curr_input_pointer = (
                input_pointer
                + input_feat_stride * feat_pid
                + input_batch_stride * batch_offset[:, None]
                + input_spatial_stride * spatial_offset[None, :]
            )
            curr_input_grad_pointer = (
                input_grad_pointer
                + input_grad_feat_stride * feat_pid
                + input_grad_batch_stride * batch_offset[:, None]
                + input_grad_spatial_stride * spatial_offset[None, :]
            )

            curr_input = tl.load(curr_input_pointer, mask=mask, other=0.0).to(
                tl.float32
            )
            curr_pre_lin = (curr_input - mean) * inv_std
            curr_output_grad = tl.load(
                curr_output_grad_pointer, mask=mask, other=0.0
            ).to(tl.float32)
            curr_input_grad = (
                inv_std
                * weight
                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
            )
            # The store is masked, so padded lanes never reach memory.
            tl.store(curr_input_grad_pointer, curr_input_grad, mask=mask)
def batch_norm(
    input: Tensor,
    weight=None,
    bias=None,
    running_mean=None,
    running_var=None,
    training=False,
    momentum=0.1,
    eps=1e-05,
):
    """Run the batch-norm forward kernel and return ``(output, mean, inv_std)``.

    Args:
        input: Tensor to normalise. It is viewed as (batch, feature, spatial)
            via ``make_3d_for_bn``; the kernel launches one program per feature.
        weight: Optional per-feature scale; ``None`` means a scale of 1.
        bias: Optional per-feature shift; ``None`` means a shift of 0.
        running_mean: Optional per-feature running mean. In training mode it is
            updated in place by the kernel; in inference mode it is read.
        running_var: Optional per-feature running variance, handled like
            ``running_mean``.
        training: Passed to the kernel as ``is_train``.
        momentum: Blend factor for the running-statistics update.
        eps: Value added to the variance before the reciprocal square root.

    Returns:
        A tuple ``(output, mean, inv_std)``: ``output`` has the shape of
        ``input``; ``mean`` and ``inv_std`` are 1-D tensors of length
        ``feat_dim`` with ``input``'s dtype and device. The kernel only writes
        them in training mode; otherwise they are returned as allocated by
        ``torch.empty``.
    """
    logger.debug("GEMS_ASCEND BATCHNORM FORWARD")

    input_3d = make_3d_for_bn(input)
    batch_dim, feat_dim, spatial_dim = input_3d.shape

    output = torch.empty_like(input_3d)
    mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
    inv_std = torch.empty(feat_dim, device=input.device, dtype=input.dtype)

    # The kernel always dereferences the running-stat pointers and, in training
    # mode, writes the updated statistics back through them. Passing `input` as
    # a placeholder therefore overwrote the first `feat_dim` elements of the
    # caller's data (and fed input values in as statistics in inference mode).
    # Use private scratch buffers instead: mean 0 / variance 1, so the
    # inference path is at least well defined when no statistics are supplied.
    if running_mean is None:
        running_mean = torch.zeros(feat_dim, device=input.device, dtype=input.dtype)
    if running_var is None:
        running_var = torch.ones(feat_dim, device=input.device, dtype=input.dtype)

    BM = _block_m(batch_dim)
    BN = _block_n(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        batch_norm_forward_kernel[(feat_dim,)](
            input_3d,
            weight,
            bias,
            mean,
            inv_std,
            output,
            running_mean,
            running_var,
            batch_dim,
            spatial_dim,
            *input_3d.stride(),
            *output.stride(),
            momentum,
            eps,
            is_train=training,
            BLOCK_M=BM,
            BLOCK_N=BN,
        )

    return output.view_as(input), mean, inv_std
def batch_norm_backward(
    grad_out,
    input,
    weight=None,
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-05,
    output_mask=None,
):
    """Run the batch-norm backward kernel.

    Args:
        grad_out: Gradient with respect to the forward output.
        input: The forward input. Both tensors are viewed as
            (batch, feature, spatial) via ``make_3d_for_bn``.
        weight: Optional per-feature scale used in the forward pass.
        running_mean, running_var, train, eps: Accepted for signature
            compatibility; this implementation does not read them.
        save_mean: Per-feature mean saved by the forward pass.
        save_invstd: Per-feature inverse standard deviation saved by the
            forward pass.
        output_mask: Three flags selecting which of (input gradient, weight
            gradient, bias gradient) to compute. ``None`` is treated as
            "compute all three".

    Returns:
        ``(input_grad, weight_grad, bias_grad)``; each entry is ``None`` when
        its flag in ``output_mask`` is false.
    """
    logger.debug("GEMS_ASCEND BATCHNORM BACKWARD")

    # The declared default of None used to fail on indexing; treat it as
    # "all gradients requested".
    if output_mask is None:
        output_mask = (True, True, True)
    need_input_grad, need_weight_grad, need_bias_grad = (
        bool(flag) for flag in output_mask
    )

    input_3d = make_3d_for_bn(input)
    output_grad_3d = make_3d_for_bn(grad_out)

    batch_dim, feat_dim, spatial_dim = input_3d.shape

    input_grad = torch.empty_like(input_3d) if need_input_grad else None
    weight_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if need_weight_grad
        else None
    )
    bias_grad = (
        torch.empty((feat_dim,), dtype=input.dtype, device=input.device)
        if need_bias_grad
        else None
    )

    # The kernel signature always expects three input-grad strides. When no
    # input gradient is requested the kernel returns before using them, so any
    # integers will do; reuse the input's strides instead of dereferencing None.
    input_grad_strides = (
        input_grad.stride() if input_grad is not None else input_3d.stride()
    )

    BM = _block_m(batch_dim)
    BN = _block_n(batch_dim, spatial_dim)

    with torch_device_fn.device(input.device):
        batch_norm_backward_kernel[(feat_dim,)](
            output_grad_3d,
            input_3d,
            save_mean,
            save_invstd,
            weight,
            input_grad,
            weight_grad,
            bias_grad,
            batch_dim,
            spatial_dim,
            *output_grad_3d.stride(),
            *input_3d.stride(),
            *input_grad_strides,
            need_input_grad,
            need_weight_grad,
            need_bias_grad,
            BLOCK_M=BM,
            BLOCK_N=BN,
        )

    return (
        input_grad.view_as(input) if input_grad is not None else None,
        weight_grad,
        bias_grad,
    )