Coverage for src/flag_gems/runtime/backend/_ascend/ops/var_mean.py: 0%
156 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module logger: the last dotted component of this module's name is appended
# to the Ascend-backend ops logger namespace.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rsplit(".", 1)[-1]}'
)
@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    # Merge two partial statistics (mean, count, M) into one, suitable as a
    # tl.reduce combine function over (mean, count, M) tuples.
    # If M_x / M_y are sums of squared deviations from mean_x / mean_y, then
    # M + count * mean * mean is each side's raw sum of squares, and the
    # returned M is again a sum of squared deviations about the merged mean.
    # NOTE(review): in the code visible here this is only referenced from a
    # commented-out tl.reduce call; confirm whether it is still needed.
    total = count_x + count_y
    weighted_x = mean_x * count_x
    weighted_y = mean_y * count_y
    # Clamp the denominator so merging two empty states yields mean 0, not NaN.
    merged_mean = (weighted_x + weighted_y) / tl.maximum(total, 1)
    # Accumulate left to right, in the same order as a single flat expression.
    merged_M = M_x + weighted_x * mean_x
    merged_M = merged_M + M_y
    merged_M = merged_M + weighted_y * mean_y
    merged_M = merged_M - total * merged_mean * merged_mean
    return merged_mean, total, merged_M
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise variance and mean over a row-major (M, N) view of X.
    # Each program owns BLOCK_M consecutive rows. Each row is streamed in
    # BLOCK_N-wide tiles. Every (row, lane) pair keeps its own running
    # (mean, sum of squared deviations, count), and the BLOCK_N lane states
    # of a row are merged after the loop. One value per row is written to
    # Var and Mean.
    # NOTE(review): pid * N is computed in the index dtype Triton picks
    # (int32 for small N); very large M * N may overflow — confirm limits.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask & col_mask
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # Per-lane Welford update. For masked lanes the count does not grow
        # and the acc increment is multiplied by 0, so padding does not
        # contribute.
        count = _count + mask
        cnt = tl.maximum(count, 1)
        cur_mean = (_mean * _count + x) / cnt
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    # Merge the per-lane states of each row along axis=1 (the role a
    # tl.reduce with a Welford combine function would play):
    #   n    = sum(n_i)
    #   mean = sum(mean_i * n_i) / n
    #   M2   = sum(M2_i + n_i * (mean_i - mean)^2)
    total_count = tl.sum(_count, axis=1)  # shape: (BLOCK_M,)
    weighted_sum = tl.sum(_mean * _count, axis=1)  # shape: (BLOCK_M,)
    mean = weighted_sum / tl.maximum(total_count, 1)  # shape: (BLOCK_M,)

    delta = _mean - mean[:, None]  # shape: (BLOCK_M, BLOCK_N)
    acc = tl.sum(_acc + _count * delta * delta, axis=1)  # shape: (BLOCK_M,)

    # Clamp the divisor at 0 like PyTorch's max(0, N - correction). When
    # N <= correction this gives inf/nan instead of a negative or
    # sign-flipped variance.
    var = acc / tl.maximum(N - correction, 0.0)

    tl.store(Mean, mean[:, None], row_mask)
    tl.store(Var, var[:, None], row_mask)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel_simple(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Simple alternative to var_mean_welford_kernel with the same arguments
    # and the same row-major (M, N) layout, producing one mean/var per row.
    # It uses two passes instead of per-lane Welford state:
    #   pass 1 sums each row to get its mean;
    #   pass 2 re-reads the row and sums squared deviations from that mean,
    #   which avoids the cancellation of the sum-of-squares formula.
    # The previous body could not compile: it used integer indexing on Triton
    # tensors and applied the row offset twice.
    # NOTE(review): var_mean() in this file launches var_mean_welford_kernel,
    # not this kernel.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    # Pass 1: per-row sums -> per-row means. Masked loads read 0.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        mask = row_mask & (cols < N)
        _sum += tl.load(X + cols, mask, other=0.0).to(tl.float32)
    mean = tl.sum(_sum, axis=1)[:, None] / N  # shape: (BLOCK_M, 1)

    # Pass 2: sum of squared deviations. Masked lanes are zeroed explicitly
    # because (0 - mean)^2 would otherwise leak into the sum.
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        mask = row_mask & (cols < N)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        diff = tl.where(mask, x - mean, 0.0)
        _acc += diff * diff

    # Divisor clamped at 0 like PyTorch's max(0, N - correction).
    var = tl.sum(_acc, axis=1)[:, None] / tl.maximum(N - correction, 0.0)

    tl.store(Mean, mean, row_mask)
    tl.store(Var, var, row_mask)
@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    # Stage 1 of the full (all-elements) reduction. Program pid summarizes
    # elements [pid * BLOCK_N, (pid + 1) * BLOCK_N) of the flat buffer X and
    # writes, at index pid:
    #   Average[pid]: mean of the chunk
    #   Acc[pid]:     sum of squared deviations from that chunk mean
    #   Count[pid]:   number of valid elements in the chunk
    # With a grid of cdiv(N, BLOCK_N) programs (as var_mean launches it),
    # every program has at least one valid element, so count >= 1.
    pid = ext.program_id(0)
    offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offset < N

    x = tl.load(X + offset, mask, other=0.0).to(tl.float32)

    count = tl.sum(mask.to(tl.float32))
    average = tl.sum(x) / count
    # Two-pass within the chunk: sum (x - mean)^2 directly instead of
    # sum(x*x) - count*mean^2. The latter cancels catastrophically in float32
    # when |mean| >> std. Padding lanes are zeroed so they add nothing.
    diff = tl.where(mask, x - average, 0.0)
    acc = tl.sum(diff * diff)

    tl.store(Average + pid, average)
    tl.store(Acc + pid, acc)
    tl.store(Count + pid, count)
@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["var_mean"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    # Stage 2 of the full reduction. A single program merges BLOCK_NUM
    # partial (acc, average, count) triples into one mean and one variance.
    # NOTE(review): only the first BLOCK_N partials are read, so this relies
    # on the heuristic choosing BLOCK_N >= BLOCK_NUM — confirm in
    # heuristics_config_utils.
    offset = tl.arange(0, BLOCK_N)
    mask = offset < BLOCK_NUM
    # Masked lanes load count = 0 and acc = 0, so they drop out of every sum.
    acc = tl.load(Acc + offset, mask, other=0.0).to(tl.float32)
    average = tl.load(Average + offset, mask, other=0.0).to(tl.float32)
    count = tl.load(Count + offset, mask, other=0.0).to(tl.float32)

    # Mathematically equivalent to tl.reduce((average, count, acc), axis=0,
    # combine_fn=welford_func), written with plain sums:
    #   n    = sum(n_i)
    #   mean = sum(mean_i * n_i) / n
    #   M2   = sum(M2_i + n_i * (mean_i - mean)^2)
    total_count = tl.sum(count)
    mean = tl.sum(average * count) / tl.maximum(total_count, 1)
    delta = average - mean
    nvar = tl.sum(acc + count * delta * delta)

    # Clamp the divisor at 0 like PyTorch's max(0, N - correction). When
    # N <= correction this gives inf/nan instead of a negative or
    # sign-flipped variance.
    var = nvar / tl.maximum(N - correction, 0.0)
    tl.store(Mean, mean)
    tl.store(Var, var)
def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Return ``(var, mean)`` of ``x`` reduced over ``dim`` (Ascend backend).

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce. ``None``, an empty list, or a list that
            covers every dimension reduces over all elements. This matches
            ATen, where an empty dim list means "all dims". A single ``int``
            is also accepted.
        correction: Subtracted from the element count to form the variance
            divisor. Defaults to 1 when ``None``.
        keepdim: If true, reduced dimensions are kept with size 1.

    Returns:
        Tuple ``(var, mean)``, both allocated with ``x.dtype`` on ``x.device``.
    """
    logger.debug("GEMS_ASCEND VAR MEAN")
    if correction is None:
        correction = 1.0
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == 0 or len(dim) == x.ndim:
        # Full reduction. The kernels index x as one flat buffer, so it must
        # be contiguous; strided views would otherwise read the wrong memory.
        x = x.contiguous()
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        BLOCK_N = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N)
        # Per-chunk partials stay in float32 regardless of x.dtype. The
        # kernels compute in float32, and fp16/bf16 buffers would round the
        # chunk counts and sums of squares (bf16 cannot represent every
        # integer above 256 exactly).
        acc = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        average = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        count = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
            var_mean_kernel_2[(1,)](
                acc, average, count, var, mean, N, correction, BLOCK_NUM
            )
    else:
        # Partial reduction: dim_compress rearranges x so the reduced dims
        # form the trailing N elements of each of the M rows.
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](x, var, mean, M, N, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean