Coverage for src/flag_gems/runtime/backend/_ascend/ops/mean.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module logger, namespaced under the Ascend backend ops using the last
# component of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Write one float32 partial sum per program into ``mid``.

    Program ``i`` sums ``inp[i * BLOCK_SIZE : (i + 1) * BLOCK_SIZE]`` (lanes at
    or beyond ``M`` are masked out) after upcasting to float32, and stores the
    result at ``mid[i]``.
    """
    block_id = tl.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M
    # Masked-off lanes read 0.0 so they contribute nothing to the sum.
    values = tl.load(inp + offsets, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(mid + block_id, tl.sum(values, axis=0))
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-dim tensor.

    Stage one runs ``mean_kernel_1``, which writes one float32 partial sum per
    block into ``mid``. Stage two adds those partial sums with ``mid.sum()``
    and divides by the element count.

    Args:
        inp: Input tensor. It is made contiguous first because the kernel
            addresses elements by flat offset.
        dtype: Output dtype; defaults to ``inp.dtype``. Accumulation is done
            in float32 and the result is cast at the end.

    Returns:
        A 0-dim tensor on ``inp.device``. An empty input yields NaN (0/0),
        matching ``torch.mean``.
    """
    logger.debug("GEMS_ASCEND MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # triton.next_power_of_2(0) is 0, so triton.cdiv below would divide
        # by zero; the mean of zero elements is NaN.
        return torch.full(
            [], float("nan"), dtype=torch.float32, device=inp.device
        ).to(dtype)
    # The kernel reads ``inp + offset`` linearly, so it needs a dense layout
    # (no-op for already-contiguous tensors).
    inp = inp.contiguous()
    # Roughly sqrt(M) elements per block, rounded up to a power of two and
    # capped at 2048.
    block_size = min(triton.next_power_of_2(math.ceil(math.sqrt(M))), 2048)
    num_ctas = triton.cdiv(M, block_size)
    mid = torch.zeros([num_ctas], dtype=torch.float32, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(num_ctas, 1, 1)](inp, mid, M, block_size)
        out = mid.sum() / M
    return out.to(dtype)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Average each row of an ``M x N`` buffer laid out row after row.

    Row ``r`` occupies ``X[r * N : r * N + N]``. Each program handles
    ``BLOCK_M`` consecutive rows and walks the columns in ``BLOCK_N`` chunks,
    accumulating in float32. It writes row ``r``'s mean to ``Mean[r]``.
    """
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = rows < M
    row_ptrs = X + rows * N

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)[None, :]
        # Lanes outside the valid rows/columns read 0.0 and add nothing.
        tile = tl.load(row_ptrs + cols, row_ok & (cols < N), other=0.0)
        acc += tile.to(tl.float32)

    row_mean = tl.sum(acc, axis=1) / N
    tl.store(Mean + rows, row_mean[:, None], row_ok)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Return the mean of ``x`` over the dimensions in ``dim``.

    Args:
        x: Input tensor.
        dim: A dimension index or a sequence of them; negative indices are
            allowed. ``None`` or an empty sequence reduces over every
            dimension, as in ``torch.mean``.
        keepdim: If True, the reduced dimensions are kept with size 1.
        dtype: Output dtype; defaults to ``x.dtype``. The kernel accumulates
            in float32.

    Returns:
        A tensor on ``x.device``. The reduced dimensions are removed, or kept
        as size 1 when ``keepdim`` is True.
    """
    logger.debug("GEMS_ASCEND MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        out = mean(x, dtype=dtype)
        # ``mean`` returns a 0-dim tensor; size-1 dims only when keepdim.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    # NOTE(review): presumably dim_compress moves the reduced dims last so each
    # output element's inputs form one contiguous row of length N — confirm.
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    out = torch.empty(shape, dtype=dtype, device=x.device)
    # One kernel row per output element. out.numel() equals x.numel() // N but
    # also works when a reduced dimension has size 0 (N == 0).
    M = out.numel()
    if M > 0:
        if N == 0:
            # Mean over zero elements is 0/0 = NaN, as in torch.mean.
            out.fill_(float("nan"))
        else:
            grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
            with torch_device_fn.device(x.device):
                mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out