Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mean.py: 0%
88 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import builtins
2import logging
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as ext
13from ..utils.block_size_utils import get_block_size_1d
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped from __name__).
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def mean_scalar_kernel(inp, out, M, BLOCK_SIZE: tl.constexpr):
    """Scalar mean over all M elements of ``inp``, written to ``*out``.

    On XPU (USE_XHPC): intercepted by baidu::xpu::api::mean binding.
    Triton fallback (single CTA): sequential accumulation for correctness.

    ``inp`` is addressed as a flat buffer (``inp + offset`` for offsets
    0..M-1), so the caller must pass contiguous memory. Partial sums are
    kept in float32. NOTE(review): the float32 result is stored into
    ``out`` as-is; presumably Triton converts it to ``out``'s element type
    on store — confirm for the XPU backend.

    Params for binding (the binding is documented against these positions,
    so keep the parameter order stable):
        kernelParams[0] = inp, kernelParams[1] = out
        kernelConsts[2] = M, kernelConsts[3] = BLOCK_SIZE
    """
    # One float32 partial sum per lane; collapsed to a scalar after the loop.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # Sweep the whole input BLOCK_SIZE elements at a time.
    for off in range(0, M, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        # Guard the tail chunk; masked lanes load 0.0 and do not affect the sum.
        mask = offset < M
        v = tl.load(inp + offset, mask=mask, other=0.0).to(tl.float32)
        acc += v
    # NOTE(review): M == 0 is not guarded here (0 / 0) — confirm the
    # expected NaN result on the target backend.
    result = tl.sum(acc) / M
    tl.store(out, result)
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    Args:
        inp: input tensor of any shape. It is made contiguous first because
            ``mean_scalar_kernel`` reads it as a flat buffer.
        dtype: dtype of the result; defaults to ``inp.dtype``.

    Returns:
        A new 0-d tensor of dtype ``dtype`` on ``inp.device``.
    """
    logger.debug("GEMS MEAN")
    # The kernel indexes ``inp + offset`` linearly over numel() elements, so a
    # strided view (e.g. ``x[:, ::2]``) would otherwise read the wrong memory.
    # This is a no-op when ``inp`` is already contiguous.
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    BLOCK_SIZE = get_block_size_1d(M, inp.element_size())
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        # A single program: the kernel itself loops over all M elements.
        mean_scalar_kernel[(1, 1, 1)](inp, out, M, BLOCK_SIZE, buffer_size_limit=2048)
    return out
def heur_m_block_size(args):
    """Heuristic for BLOCK_M: this program's share of the M rows, as a power of two.

    M is divided by a hard-coded 12. The original note calls this
    ``cluster_num``; presumably it is the XPU cluster count (confirm).
    """
    rows_per_share = triton.cdiv(args["M"], 12)  # cluster_num
    return triton.next_power_of_2(rows_per_share)
def heur_n_block_size(args):
    """Heuristic for BLOCK_N: N rounded up to a power of two, capped at 8192."""
    max_block_n = 8192
    padded_n = triton.next_power_of_2(args["N"])
    return builtins.min(max_block_n, padded_n)
@libentry()
# Autotuning is disabled; BLOCK_M / BLOCK_N come from the heuristics below.
# @triton.autotune(
#     configs=runtime.get_tuned_config("mean"),
#     key=["M", "N"],
# )
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """2-D reduction: mean over the N columns of each of the M rows of X.

    X is addressed as a row-major [M, N] buffer (row r starts at X + r * N),
    and Mean receives one value per row at Mean + r. Each program handles
    BLOCK_M consecutive rows and sweeps the columns BLOCK_N at a time,
    accumulating in float32. BLOCK_M / BLOCK_N are filled in by the
    ``triton.heuristics`` decorator, not passed by the caller.

    On XPU (USE_XHPC): intercepted by baidu::xpu::api::mean_dim binding.
    Params for binding:
        kernelParams[0] = X, kernelParams[1] = Mean
        kernelParams[2] = M, kernelParams[3] = N (runtime scalars)
        kernelConsts[4] = BLOCK_M (constexpr), kernelConsts[5] = BLOCK_N (constexpr)
    """
    # Row indices handled by this program, shape [BLOCK_M, 1]
    # (despite the name, ``pid`` holds row indices, not the program id).
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    X = X + pid * N  # per-row base pointers
    Mean = Mean + pid  # per-row output pointers
    row_mask = pid < M  # rows past M belong to a partial last block

    # Compute mean: float32 accumulator tile, reduced along axis 1 at the end.
    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]  # shape [1, BLOCK_N]
        col_mask = cols < N
        # Combined row/column bounds mask. NOTE(review): relies on Triton
        # lowering ``and`` on tensors to an element-wise logical and.
        mask = row_mask and col_mask

        # Out-of-bounds lanes load 0.0 so they do not affect the sum.
        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
    # Local ``mean`` shadows the module-level function only inside the kernel.
    mean = tl.sum(_mean, axis=1) / N
    mean = mean[:, None]  # back to [BLOCK_M, 1] to match the Mean pointers
    tl.store(Mean, mean, row_mask)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Mean of ``x`` over the dimension(s) ``dim``.

    Args:
        x: input tensor.
        dim: an int, a sequence of ints (negative values allowed), or None.
            ``None`` or an empty sequence reduces over every dimension,
            matching ``torch.mean``.
        keepdim: if True, reduced dimensions are kept with size 1.
        dtype: dtype of the result; defaults to ``x.dtype``.

    Returns:
        A new tensor of dtype ``dtype`` holding the means. Reduced dims are
        dropped unless ``keepdim`` is True.
    """
    logger.debug("GEMS MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    # Accept a bare int as well as a list/tuple of dims.
    if isinstance(dim, int):
        dim = [dim]

    # Full reduction: no dims requested (None or empty), or a 0-d input that
    # has no dims to index. With keepdim every dim stays as size 1; without
    # it the result is the 0-d tensor returned by ``mean``.
    if dim is None or len(dim) == 0 or x.ndim == 0:
        out = mean(x, dtype=dtype)
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    dim = [d % x.ndim for d in dim]
    shape = list(x.shape)
    # N = elements reduced per output value; ``shape`` becomes the
    # keepdim-style output shape.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    # Empty input. If a reduced dim has size 0 the mean is NaN (as in
    # torch.mean); otherwise the output itself is empty. torch.full covers
    # both cases and avoids the ``0 // 0`` and zero-sized launch below.
    if x.numel() == 0:
        out = torch.full(shape, float("nan"), dtype=dtype, device=x.device)
        if not keepdim:
            out = out.squeeze(dim)
        return out

    x = dim_compress(x, dim)
    M = x.numel() // N

    # Edge case: M=1 means every kept dim has size 1, so this is a global
    # mean over N elements.
    # mean_dim XPU API does not support M=1.
    if M == 1:
        scalar_out = mean(x, dtype=dtype)  # 0-d tensor
        out = scalar_out.reshape(shape)
        if not keepdim:
            out = out.squeeze(dim)
        return out

    # Edge case: N=1 means reducing a trivial (size-1) dimension.
    # mean of 1 element = that element; just copy with dtype conversion.
    # copy=True forces a fresh tensor. ``.to`` returns ``x`` itself when the
    # dtype already matches, and ``x`` can still share storage with the
    # caller's input.
    # mean_dim XPU API does not support N=1.
    if N == 1:
        out = x.to(dtype=dtype, copy=True).reshape(shape)
        if not keepdim:
            out = out.squeeze(dim)
        return out

    out = torch.empty(shape, dtype=dtype, device=x.device)
    # BLOCK_M is chosen by the kernel's heuristics decorator.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(x.device):
        mean_dim_kernel[grid](x, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim)
    return out