Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mean.py: 0%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as tle
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the two-pass mean: store one partial sum per program.

    The program with id ``p`` loads the ``BLOCK_SIZE`` elements of ``inp``
    starting at linear offset ``p * BLOCK_SIZE``, substitutes 0.0 for any
    position at or beyond ``M``, and writes the sum of that block to
    ``mid[p]``.
    """
    block_id = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    idx = block_id * BLOCK_SIZE + lanes
    # Lanes past the end of the input are masked off and contribute 0.0.
    in_bounds = idx < M
    values = tl.load(inp + idx, mask=in_bounds, other=0.0)
    partial = tl.sum(values, axis=0)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second pass of the two-pass mean: combine the partial sums.

    Loads the first ``MID_SIZE`` entries of ``mid`` in a single block of
    ``BLOCK_MID`` lanes (extra lanes read as 0.0), sums them, divides the
    total by ``M`` and stores the result to ``out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < MID_SIZE
    partials = tl.load(mid + idx, mask=valid, other=0.0)
    total = tl.sum(partials, axis=0)
    tl.store(out, total / M)
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    The reduction runs in two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block into a scratch buffer, then ``mean_kernel_2`` sums
    that buffer and divides by the element count.

    Args:
        inp: Tensor to reduce.
        dtype: dtype of the scratch buffer and of the result. Defaults to
            ``inp.dtype``.

    Returns:
        A 0-d tensor on ``inp.device``. For an input with no elements the
        result is NaN.
    """
    logging.debug("GEMS_SPACEMIT MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype

    if M == 0:
        # With M == 0 the block size below would be 0 and triton.cdiv(0, 0)
        # would raise ZeroDivisionError; the mean of nothing is NaN.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # mean_kernel_1 addresses the input by linear offset, so a strided view
    # must be materialised first. This is a no-op for contiguous tensors.
    inp = inp.contiguous()

    # Blocks of roughly sqrt(M) elements keep both passes about the same size.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mean"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(X, Mean, M, N, TILE_N: tl.constexpr):
    """Compute the mean of one row of ``N`` unit-stride elements per program.

    Program ``row`` reads ``X[row * N : (row + 1) * N]`` in tiles of
    ``TILE_N`` elements, accumulates their sum, and stores ``sum / N`` to
    ``Mean[row]``. ``M`` is only used as an autotune key.
    """
    row = tl.program_id(0)
    X = X + row * N
    Mean = Mean + row
    _mean = 0.0

    num_pid_n = tl.cdiv(N, TILE_N)

    x_ptr_desc = tl.make_block_ptr(
        base=X,
        shape=[N],
        strides=[1],
        offsets=[0],
        block_shape=[TILE_N],
        order=[0],
    )

    for off_n in range(0, num_pid_n):
        # Without an explicit padding option the out-of-bounds lanes of the
        # last tile hold undefined values; pad with zero so they do not
        # perturb the sum when N is not a multiple of TILE_N.
        a = tl.load(
            x_ptr_desc,
            boundary_check=[0],
            padding_option="zero",
        )

        _mean += tl.sum(a)

        x_ptr_desc = tl.advance(x_ptr_desc, [TILE_N])

    mean = _mean / N

    tl.store(Mean, mean)
def mean_dim(x, dim, keepdim=False, *, dtype=None):
    """Return the mean of ``x`` over the dimensions listed in ``dim``.

    Args:
        x: Tensor to reduce.
        dim: Dimension index, sequence of dimension indices (negative values
            count from the end), or ``None`` to reduce over every element.
        keepdim: When true, reduced dimensions are kept with size 1;
            otherwise they are squeezed out of the result.
        dtype: dtype of the result. Defaults to ``x.dtype``.

    Returns:
        The reduced tensor on ``x.device``. Entries whose reduction covers
        zero elements are NaN.
    """
    logging.debug("GEMS_SPACEMIT MEAN_DIM")

    if dtype is None:
        dtype = x.dtype
    if dim is None:
        out = mean(x, dtype=dtype)
        # mean() already yields a 0-d tensor, which is the keepdim=False
        # result; only keepdim=True needs the all-ones shape.
        if keepdim:
            out = out.reshape([1] * x.ndim)
        return out

    # Accept a bare int as well as a sequence of ints.
    if isinstance(dim, int):
        dim = [dim]

    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if N == 0:
        # Nothing to average along the reduced dims; x.numel() // N below
        # would raise ZeroDivisionError, so fill the result with NaN directly.
        out = torch.full(shape, float("nan"), dtype=dtype, device=x.device)
    else:
        x = dim_compress(x, dim)
        M = x.numel() // N
        out = torch.empty(shape, dtype=dtype, device=x.device)
        # With M == 0 the output has no elements, so skip the empty launch.
        if M > 0:
            grid = (M,)
            with torch_device_fn.device(x.device):
                mean_dim_kernel[grid](x, out, M, N)

    if not keepdim:
        out = out.squeeze(dim)
    return out
def global_avg_pool(x, _output_size=None):
    """Average ``x`` over dimensions 2 and 3, keeping them with size 1.

    ``_output_size`` is accepted but never read.
    """
    pooled = mean_dim(x, dim=[2, 3], keepdim=True)
    return pooled