Coverage for src/flag_gems/runtime/backend/_ascend/ops/var_mean.py: 0%

156 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

14 

15 

@triton.jit
def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
    # Merge two partial statistics triples (mean, count, M) into one triple.
    #
    # The merged M is M_x + count_x * mean_x^2 + M_y + count_y * mean_y^2
    # minus count * mean^2, which is the combine rule when M holds a sum of
    # squared deviations from the mean.
    #
    # Returns the merged (mean, count, M).
    weighted_x = mean_x * count_x
    weighted_y = mean_y * count_y
    total = count_x + count_y
    # Clamp the divisor to 1 so merging two empty partitions does not divide
    # by zero.
    safe_total = tl.maximum(total, 1)
    merged_mean = (weighted_x + weighted_y) / safe_total
    merged_M = (
        M_x
        + weighted_x * mean_x
        + M_y
        + weighted_y * mean_y
        - total * merged_mean * merged_mean
    )
    return merged_mean, total, merged_M

25 

26 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise variance and mean of an input addressed as M rows of N
    # consecutive elements (element (r, c) is read from X + r * N + c).
    # Each program handles BLOCK_M rows and writes one value per row to
    # Var and Mean; the variance divisor is (N - correction).

    # Map the program id to the row of X it should compute.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    # Per-lane running statistics, accumulated in float32: every
    # (row, column-lane) pair keeps its own mean, squared-deviation sum and
    # element count across the column tiles.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # Online update per lane; masked-out lanes add 0 to the count and
        # contribute nothing to _acc because of the trailing "* mask".
        count = _count + mask
        # Clamp to 1 so lanes that have not seen any element do not divide
        # by zero.
        cnt = tl.maximum(count, 1)
        cur_mean = (_mean * _count + x) / cnt
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    # Manually implement what tl.reduce would do, reducing along axis=1.
    # tl.sum is used for the reduction, which is equivalent to the Welford
    # combine in this situation.

    # Total element count of each row.
    total_count = tl.sum(_count, axis=1)  # shape: (BLOCK_M,)

    # Count-weighted mean of the per-lane means.
    weighted_sum = tl.sum(_mean * _count, axis=1)  # shape: (BLOCK_M,)
    mean = weighted_sum / tl.maximum(total_count, 1)  # shape: (BLOCK_M,)

    # Accumulate the variance numerator.
    # For each lane, compute its contribution to the overall variance.
    mean_expanded = mean[:, None]  # shape: (BLOCK_M, 1)

    # Contribution of each local statistic to the overall variance: its own
    # squared-deviation sum plus count * (local mean - row mean)^2.
    # This is the parallelised form of Welford's algorithm.
    local_var_contrib = _acc + _count * (_mean - mean_expanded) * (
        _mean - mean_expanded
    )
    acc = tl.sum(local_var_contrib, axis=1)  # shape: (BLOCK_M,)

    var = acc / (N - correction)
    # Restore the trailing axis so the values line up with the
    # (BLOCK_M, 1)-shaped Mean / Var pointers and row_mask.
    mean = mean[:, None]
    var = var[:, None]

    # Write mean / var
    tl.store(Mean, mean, row_mask)
    tl.store(Var, var, row_mask)

92 

93 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel_simple(
    X,
    Var,
    Mean,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Element-by-element Welford variant of the row-wise variance / mean
    # kernel: rows of a block are visited one at a time and every element
    # updates a scalar running mean and squared-deviation sum.
    #
    # NOTE(review): no call site for this kernel is visible in this file;
    # confirm whether it is still needed before relying on it.

    # Map the program id to rows.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    Mean = Mean + pid
    row_mask = pid < M

    # Process each row separately.
    for row in range(BLOCK_M):
        # This condition always holds for row in range(BLOCK_M).
        if row < BLOCK_M:
            current_row_mask = (tl.arange(0, BLOCK_M) == row)[:, None] & row_mask

            if tl.sum(current_row_mask.to(tl.int32)) > 0:
                # Initialise the statistics of the current row.
                running_mean = 0.0
                running_M = 0.0
                count = 0

                # Walk the current row tile by tile.
                for off in range(0, N, BLOCK_N):
                    cols = off + tl.arange(0, BLOCK_N)
                    col_mask = cols < N

                    # Load the data.
                    # NOTE(review): X already carries the per-row offset
                    # pid * N added above, and row * N is added again here;
                    # confirm the intended addressing.
                    x_vals = tl.load(X + row * N + cols, col_mask, other=0.0).to(
                        tl.float32
                    )

                    # Online update for every valid element of the tile.
                    for i in range(BLOCK_N):
                        # "i < BLOCK_N" always holds inside range(BLOCK_N);
                        # the second test skips columns past the end of the row.
                        if i < BLOCK_N and (off + i) < N:
                            count += 1
                            # NOTE(review): scalar indexing of a loaded tensor;
                            # verify the target Triton version supports it.
                            x = x_vals[i]

                            delta = x - running_mean
                            running_mean += delta / count
                            delta2 = x - running_mean
                            running_M += delta * delta2

                # Compute the variance; yields 0.0 when N <= correction.
                variance = running_M / (N - correction) if N > correction else 0.0

                # Store the results.
                # NOTE(review): Mean / Var were already offset by pid above
                # and row is added again here; confirm the addressing and
                # that integer indexing in current_row_mask[:, 0] is
                # supported.
                tl.store(Mean + row, running_mean, current_row_mask[:, 0])
                tl.store(Var + row, variance, current_row_mask[:, 0])

152 

153 

@libentry()
@triton.jit
def var_mean_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    # First stage of the full reduction: each program reduces one chunk of
    # BLOCK_N consecutive elements of X (the first N elements are valid) and
    # writes three partial results at index program_id(0):
    #   Count   - number of valid elements in the chunk
    #   Average - mean of the chunk
    #   Acc     - sum of squared deviations from the chunk mean
    #
    # Assumes every launched program covers at least one valid element
    # (grid = cdiv(N, BLOCK_N)); otherwise count is 0 and the mean is NaN.

    # Map the program id to the chunk of X it should compute.
    pid = ext.program_id(0)
    offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)

    X = X + offset
    Acc = Acc + pid
    Average = Average + pid
    Count = Count + pid
    mask = offset < N

    x = tl.load(X, mask, other=0.0).to(tl.float32)

    count = tl.sum(mask.to(tl.float32))
    average = tl.sum(x) / count
    # Sum the squared deviations directly instead of using
    # sum(x * x) - count * average^2: the latter subtracts two nearly equal
    # large numbers when |mean| is large compared with the spread, which
    # destroys precision and can even produce a negative result.
    # Out-of-range lanes were loaded as 0.0, so they must be masked out
    # explicitly before squaring.
    centered = tl.where(mask, x - average, 0.0)
    acc = tl.sum(centered * centered)

    tl.store(Average, average)
    tl.store(Acc, acc)
    tl.store(Count, count)

183 

184 

@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["var_mean"])
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    Mean,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    # Second stage of the full reduction: merge BLOCK_NUM partial
    # (acc, average, count) triples into one mean and one variance, stored
    # at Mean and Var. The variance divisor is (N - correction).
    lanes = tl.arange(0, BLOCK_N)
    valid = lanes < BLOCK_NUM

    # Lanes beyond BLOCK_NUM load zeros, so they add nothing to any sum.
    part_acc = tl.load(Acc + lanes, valid, other=0.0).to(tl.float32)
    part_mean = tl.load(Average + lanes, valid, other=0.0).to(tl.float32)
    part_count = tl.load(Count + lanes, valid, other=0.0).to(tl.float32)

    # Hand-written replacement for
    # tl.reduce((average, count, acc), axis=0, combine_fn=welford_func),
    # expressed with plain tl.sum reductions.
    n_total = tl.sum(part_count)

    # Count-weighted mean of the partial means; the divisor is clamped to 1.
    global_mean = tl.sum(part_mean * part_count) / tl.maximum(n_total, 1)

    # Each partial contributes its own squared-deviation sum plus
    # count * (partial mean - global mean)^2.
    shift = part_mean - global_mean
    sq_dev_total = tl.sum(part_acc + part_count * shift * shift)

    result_var = sq_dev_total / (N - correction)
    tl.store(Mean, global_mean)
    tl.store(Var, result_var)

226 

227 

def var_mean(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance and mean of ``x`` over the given dimensions.

    Args:
        x: Input tensor.
        dim: Dimension index, sequence of dimension indices, or ``None``.
            ``None`` (or naming every dimension) reduces over all elements.
            Negative indices are wrapped.
        correction: Value subtracted from the element count in the variance
            divisor. ``None`` means ``1.0``.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are squeezed out of the results.

    Returns:
        A ``(var, mean)`` tuple of tensors with the dtype and device of ``x``.
    """
    logger.debug("GEMS_ASCEND VAR MEAN")
    if correction is None:
        correction = 1.0
    # Accept a bare integer dimension; everything below expects a sequence.
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        # The first-stage kernel reads N consecutive elements starting at the
        # data pointer, so a strided view must be materialised first
        # (this is a no-op for tensors that are already contiguous).
        x = x.contiguous()
        N = x.numel()
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)
        BLOCK_N = 1024
        BLOCK_NUM = triton.cdiv(N, BLOCK_N)
        # Keep the per-chunk partial results in float32: both kernels compute
        # in float32, and half-precision buffers cannot hold every chunk
        # count exactly and may overflow on the squared sums.
        acc = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        average = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
        count = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            # Stage 1: one partial (acc, average, count) triple per chunk.
            var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
            # Stage 2: a single program merges all partial triples.
            var_mean_kernel_2[(1,)](
                acc, average, count, var, mean, N, correction, BLOCK_NUM
            )
    else:
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        # The row-wise kernel addresses x as M rows of N consecutive
        # elements; dim_compress presumably rearranges x so the reduced
        # dimensions are innermost (verify against its implementation).
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        var = torch.empty(shape, dtype=x.dtype, device=x.device)
        mean = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            var_mean_welford_kernel[grid](x, var, mean, M, N, correction)

    if not keepdim:
        var = var.squeeze(dim=dim)
        mean = mean.squeeze(dim=dim)
    return var, mean