Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mean.py: 0%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, libtuner
11from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op
# Module logger: a child of the package-wide "flag_gems" logger, suffixed with
# this module's dotted name (any leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def mean_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Accumulate ``sum(inp[0:M]) / M`` into the scalar at ``out``.

    Each program walks the flat input with a grid-stride loop (stride
    ``num_programs * BLOCK_SIZE``), keeps a BLOCK_SIZE-wide float32 partial
    sum, divides its share by ``M`` and atomically adds it to ``out``.
    ``out`` must start at zero: the caller allocates it with zeros and the
    tuner resets it between trial runs (``reset_to_zero=["out"]``).
    """
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # int64 offsets so `off + arange` cannot overflow for > 2**31 elements.
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Explicit fp32 cast: `acc` is loop-carried and must keep its type.
        # Without it, float64 input would promote `acc` to fp64 inside the
        # loop, which Triton rejects for loop-carried variables.
        inp_val = tl.load(inp + offset, mask=mask, other=0.0).to(tl.float32)
        acc += inp_val

    mean_val = tl.sum(acc, axis=0) / M
    tl.atomic_add(out, mean_val)
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    Args:
        inp: input tensor (made contiguous before the kernel reads it flat).
        dtype: dtype of the returned tensor; defaults to ``inp.dtype``.
            Accumulation always happens in a float32 buffer.

    Returns:
        A 0-d tensor of ``dtype``. An empty input yields NaN, matching
        ``torch.mean``.
    """
    logger.debug("GEMS_CAMBRICON MEAN")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if M == 0:
        # The grid below would have zero programs, so the zero-initialised
        # buffer would come back as 0 (or the launch could be rejected);
        # torch.mean of an empty tensor is NaN.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # At most one program per core; each program strides over the input.
    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    # float32 accumulator; must be zero because programs atomically add into it.
    out = torch.zeros([], dtype=torch.float32, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[grid](inp, out, M)
    return out.to(dtype)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    """Write the mean of each row of X into Mean.

    X is addressed as a contiguous row-major (M, N) matrix
    (element ``X + row * N + col``) and Mean as M contiguous outputs.
    Programs loop over BLOCK_M-row tiles with a grid-stride pattern;
    each tile sums its rows in BLOCK_N-wide column chunks in float32.
    """
    num_prog = tl.num_programs(0)
    task_num = tl.cdiv(M, BLOCK_M)
    iter_num = tl.cdiv(task_num, num_prog)
    for i in range(0, iter_num):
        task_id = i * num_prog + tl.program_id(0)
        # int64 row indices: `rows * N` overflows int32 once the input has
        # more than 2**31 elements (mean_kernel_1 promotes its offsets too).
        rows = task_id.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        row_mask = rows < M
        X_ptr = X + rows * N
        Mean_ptr = Mean + rows

        acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            mask = row_mask & (cols < N)
            acc += tl.load(X_ptr + cols, mask, other=0.0).to(tl.float32)
        # Sum first, then divide once per row.
        row_mean = tl.sum(acc, axis=1)[:, None] / N
        tl.store(Mean_ptr, row_mean, row_mask)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Mean of ``x`` over the dimension(s) ``dim``.

    Args:
        x: input tensor.
        dim: an int, a sequence of ints (negative values allowed), or
            ``None`` / an empty sequence to reduce over every element.
        keepdim: if True, reduced dimensions are kept with size 1.
        dtype: dtype of the result; defaults to ``x.dtype``.

    Returns:
        The reduced tensor. A full reduction returns a 0-d tensor, or a
        tensor of shape ``[1] * x.ndim`` when ``keepdim`` is True, matching
        ``torch.mean``. Reducing over a size-0 dimension yields NaN.
    """
    logger.debug("GEMS_CAMBRICON MEAN DIM")

    if dtype is None:
        dtype = x.dtype
    if isinstance(dim, int):
        dim = [dim]
    if dim is None or len(dim) == 0:
        out = mean(x, dtype=dtype)
        # Only keepdim=True keeps every dimension as size 1; keepdim=False
        # returns the 0-d result as is.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    # Reorganise x so the reduced dims form the trailing N elements of each
    # row, as mean_dim_kernel expects.
    x = dim_compress(x, dim)
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1
    out = torch.empty(shape, dtype=dtype, device=x.device)
    # Number of output rows. Taken from `out` instead of x.numel() // N,
    # which divides by zero when a reduced dimension has size 0.
    M = out.numel()
    if M > 0:
        if N == 0:
            # Mean over zero elements is NaN, as in torch.mean.
            out.fill_(float("nan"))
        else:
            grid = lambda META: (min(triton.cdiv(M, META["BLOCK_M"]), TOTAL_CORE_NUM),)
            with torch_device_fn.device(x.device):
                mean_dim_kernel[grid](x, out, M, N)
    if not keepdim:
        out = out.squeeze(dim)
    return out