Coverage for src/flag_gems/ops/var.py: 50%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as ext
# Module-scoped logger; the public entry points below emit a debug line on each call.
logger = logging.getLogger(name=__name__)
15@triton.jit
16def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
17 count = count_x + count_y
18 _count = tl.maximum(count, 1)
19 mc_x = mean_x * count_x
20 mc_y = mean_y * count_y
21 mean = (mc_x + mc_y) / _count
22 M = M_x + mc_x * mean_x + M_y + mc_y * mean_y - count * mean * mean
23 return mean, count, M
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("var_mean"), key=["M", "N"])
@triton.jit(do_not_specialize=["correction"])
def var_welford_kernel(
    X,
    Var,
    M,
    N,
    correction,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Per-row variance of an (M, N) row-major buffer using a Welford update.

    Each program handles BLOCK_M rows, walking the N columns in BLOCK_N
    chunks, and writes one value per row to ``Var``.
    """
    # Row indices owned by this program, as a (BLOCK_M, 1) column.
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N
    Var = Var + pid
    row_mask = pid < M

    # Per-lane running (mean, count, M2) accumulators, merged at the end.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Element-wise AND of the broadcast row and column masks.
        mask = row_mask & col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)

        # Masked-out lanes add 0 to the count and leave mean/acc unchanged.
        count = _count + mask
        cnt = tl.maximum(count, 1)
        cur_mean = (_mean * _count + x) / cnt
        _acc += (x - cur_mean) * (x - _mean) * mask
        _mean = cur_mean
        _count = count

    _, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)
    # torch.var divides by max(0, N - correction); without the clamp a
    # correction larger than N would produce a negative variance.
    var = acc / tl.maximum(N - correction, 0.0)
    var = var[:, None]
    tl.store(Var, var, row_mask)
@libentry()
@triton.jit
def var_kernel_1(
    X,
    Acc,
    Average,
    Count,
    N,
    BLOCK_N: tl.constexpr,
):
    """First stage of the full reduction: per-block partial statistics.

    Program ``pid`` reads elements [pid * BLOCK_N, (pid + 1) * BLOCK_N) of X
    (masked to N) and stores that block's mean, element count and sum of
    squared deviations from the block mean at index ``pid``.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)

    X = X + offset
    Acc = Acc + pid
    Average = Average + pid
    Count = Count + pid
    mask = offset < N

    x = tl.load(X, mask, other=0.0).to(tl.float32)

    count = tl.sum(mask.to(tl.float32))
    average = tl.sum(x) / count
    # The block is already in registers, so sum squared deviations from the
    # block mean directly. sum(x*x) - count*mean^2 cancels catastrophically
    # when |mean| >> std. Masked lanes are zeroed so they add nothing.
    diff = tl.where(mask, x - average, 0.0)
    acc = tl.sum(diff * diff)

    tl.store(Average, average)
    tl.store(Acc, acc)
    tl.store(Count, count)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("var_mean"))
@triton.jit(do_not_specialize=["correction"])
def var_kernel_2(
    Acc,
    Average,
    Count,
    Var,
    N,
    correction,
    BLOCK_NUM,
    BLOCK_N: tl.constexpr,
):
    """Second stage of the full reduction: merge BLOCK_NUM partial states.

    Loads the per-block (average, count, acc) triples, combines them with
    ``welford_func`` and writes the single variance value to ``Var``.
    """
    offset = tl.arange(0, BLOCK_N)
    mask = offset < BLOCK_NUM
    Acc = Acc + offset
    Average = Average + offset
    Count = Count + offset
    # Padding lanes load (0, 0, 0), which welford_func treats as an empty state.
    acc = tl.load(Acc, mask, other=0.0).to(tl.float32)
    average = tl.load(Average, mask, other=0.0).to(tl.float32)
    count = tl.load(Count, mask, other=0.0).to(tl.float32)

    _, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)

    # torch.var divides by max(0, N - correction); without the clamp a
    # correction larger than N would produce a negative variance.
    var = nvar / tl.maximum(N - correction, 0.0)
    tl.store(Var, var)
def var(x, dim=None, *, correction=None, keepdim=False):
    """Compute the variance of ``x`` over ``dim`` using Triton kernels.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce. ``None``, or a sequence whose length
            equals ``x.ndim``, reduces over every element. A bare ``int`` is
            accepted and treated as a one-element list.
        correction: Amount subtracted from the element count in the
            denominator. Defaults to ``1.0`` when ``None``.
        keepdim: If ``False``, the reduced dimensions are squeezed out.

    Returns:
        A tensor of dtype ``x.dtype`` holding the variance.
    """
    logger.debug("GEMS VAR")
    if correction is None:
        correction = 1.0
    if isinstance(dim, int):
        dim = [dim]

    if dim is None or len(dim) == x.ndim:
        # Full reduction: per-block partials, then a single combining program.
        dim = list(range(x.ndim))
        shape = [1] * x.ndim
        N = x.numel()
        out = torch.empty(shape, dtype=x.dtype, device=x.device)
        if N == 0:
            # Nothing to reduce; torch.var returns NaN for empty input.
            out.fill_(float("nan"))
        else:
            BLOCK_N = 1024
            BLOCK_NUM = triton.cdiv(N, BLOCK_N)
            # Keep the partials in float32 regardless of x.dtype. The kernels
            # compute in float32, and a half-precision buffer could overflow
            # (acc) or round (count/average) before the final combine.
            acc = torch.empty([BLOCK_NUM], dtype=torch.float32, device=x.device)
            average = torch.empty_like(acc)
            count = torch.empty_like(acc)

            with torch_device_fn.device(x.device):
                var_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)
                var_kernel_2[(1,)](acc, average, count, out, N, correction, BLOCK_NUM)
    else:
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        # Move the reduced dims last so each output row is a contiguous run of N.
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        # Output row count = product of the kept dims. Taking it from the shape
        # rather than x.numel() // N avoids ZeroDivisionError when N == 0.
        M = torch.Size(shape).numel()
        out = torch.empty(shape, dtype=x.dtype, device=x.device)

        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            var_welford_kernel[grid](x, out, M, N, correction)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def var_dim(x, dim=None, *, correction=None, keepdim=False):
    """Alias of :func:`var` that logs its own debug tag before delegating.

    All arguments are forwarded unchanged and the result of :func:`var` is
    returned as-is.
    """
    logger.debug("GEMS VAR_DIM")
    forwarded = {"dim": dim, "correction": correction, "keepdim": keepdim}
    return var(x, **forwarded)
def var_correction(x, dim=None, *, correction=None, keepdim=False):
    """Alias of :func:`var` that logs its own debug tag before delegating.

    All arguments are forwarded unchanged and the result of :func:`var` is
    returned as-is.
    """
    logger.debug("GEMS VAR_CORRECTION")
    forwarded = {"dim": dim, "correction": correction, "keepdim": keepdim}
    return var(x, **forwarded)