Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mean.py: 0%
192 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import dim_compress, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
# Module logger: a child of the shared "flag_gems" logger, named after this
# module with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage one of the full reduction: per-block partial sums.

    Each program loads one BLOCK_SIZE-wide chunk of ``inp`` (elements at or
    beyond ``M`` are masked to 0), sums it, and stores the partial sum at
    ``mid[program_id]``.
    """
    # float16 / bfloat16 inputs are accumulated in float32; every other
    # element type is accumulated as-is.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    chunk = tl.load(inp + idx, mask=in_bounds, other=0).to(cdtype)
    partial = tl.sum(chunk)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage two of the full reduction: combine partial sums into the mean.

    Loads up to BLOCK_MID partial sums from ``mid`` (entries at or beyond
    ``MID_SIZE`` are masked to 0), adds them up, divides by the total element
    count ``M`` and stores the scalar result at ``out``. The kernel never
    reads the program id.
    """
    # Same accumulation rule as stage one: half types are widened to float32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    idx = tl.arange(0, BLOCK_MID)
    valid = idx < MID_SIZE
    partials = tl.load(mid + idx, mask=valid, other=0).to(cdtype)
    total = tl.sum(partials)
    # Dividing the grand total by the element count yields the mean.
    tl.store(out, total / M)
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block into a scratch buffer, and ``mean_kernel_2`` sums
    that buffer and divides by the element count.

    Args:
        inp: Input tensor; it is handed to the kernel as a flat buffer of
            ``inp.numel()`` elements.
        dtype: Dtype of the result. Defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. For an empty input the
        result is NaN, matching ``torch.mean``.
    """
    logger.debug("GEMS MEAN")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype

    if M == 0:
        # sqrt(0) gives a block size of 0, and cdiv(M, 0) would raise
        # ZeroDivisionError; the mean of nothing is NaN.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # Choose a block size near sqrt(M) so both stages handle a similar
    # number of elements.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Keep partial sums in float32 when the result is a half type: stage one
    # already accumulates in float32, and rounding each partial to
    # float16/bfloat16 loses precision and can overflow to inf.
    if dtype in (torch.float16, torch.bfloat16):
        mid_dtype = torch.float32
    else:
        mid_dtype = dtype

    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the middle axis of a buffer indexed as ``m * N * K + n * K + k``.

    Program axis 0 selects ``m``; program axis 1 selects a TILE_K-wide block
    of ``k``. For every selected ``(m, k)`` the kernel sums over ``n`` in
    ``[0, N)``, divides by ``N`` and stores the result at
    ``output_ptr + m * K + k``. ``M`` is not read inside the kernel body.
    """
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    row = ext.program_id(0)
    k_block = ext.program_id(1)

    # Shape (1, TILE_K): positions along the trailing axis for this program.
    k_idx = k_block * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_valid = k_idx < K
    row_base = row * N * K

    if ONE_TILE_PER_CTA:
        # The whole reduction axis is covered by a single (TILE_N, TILE_K) tile.
        n_idx = tl.arange(0, TILE_N)[:, None]
        tile_mask = (n_idx < N) & k_valid
        tile = tl.load(
            input_ptr + (row_base + n_idx * K + k_idx), mask=tile_mask, other=0
        ).to(cdtype)
        # keep_dims leaves a (1, TILE_K) result that lines up with k_idx.
        col_sum = tl.sum(tile, axis=0, keep_dims=True)
        col_mean = col_sum / N
        tl.store(output_ptr + (row * K + k_idx), col_mean, mask=k_valid)
    else:
        # Walk the reduction axis tile by tile, accumulating element-wise,
        # and collapse the accumulator once at the end.
        acc = tl.zeros([TILE_N, TILE_K], dtype=cdtype)
        for start_n in range(0, N, TILE_N):
            n_idx = start_n + tl.arange(0, TILE_N)[:, None]
            tile_mask = (n_idx < N) & k_valid
            tile = tl.load(
                input_ptr + (row_base + n_idx * K + k_idx), mask=tile_mask, other=0
            ).to(cdtype)
            acc += tile
        col_sum = tl.sum(acc, axis=0, keep_dims=True)
        col_mean = col_sum / N
        tl.store(output_ptr + (row * K + k_idx), col_mean, mask=k_valid)
# NOTE(review): this kernel takes its heuristics from the "softmax_inner"
# config, whereas mean_dim_kernel_non_inner uses "mean_non_inner" - confirm
# that reusing the softmax config here is intentional.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the trailing axis of a buffer indexed as ``m * N + n``.

    Each program handles one ``m`` (program axis 0): it sums the ``N``
    elements starting at ``input_ptr + m * N``, divides by ``N`` and stores
    the scalar at ``output_ptr + m``. ``M`` is not read inside the kernel body.
    """
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    row = ext.program_id(0)
    row_base = row * N

    if ONE_TILE_PER_CTA:
        # A single TILE_N-wide load covers the entire row.
        n_idx = tl.arange(0, TILE_N)
        valid = n_idx < N
        vals = tl.load(input_ptr + (row_base + n_idx), mask=valid, other=0).to(cdtype)
        row_sum = tl.sum(vals, axis=0)
        row_mean = row_sum / N
        tl.store(output_ptr + row, row_mean)
    else:
        # Accumulate the row tile by tile, then collapse the accumulator.
        acc = tl.zeros([TILE_N], dtype=cdtype)
        for start_n in range(0, N, TILE_N):
            n_idx = start_n + tl.arange(0, TILE_N)
            valid = n_idx < N
            vals = tl.load(
                input_ptr + (row_base + n_idx), mask=valid, other=0
            ).to(cdtype)
            acc += vals
        row_sum = tl.sum(acc, axis=0)
        row_mean = row_sum / N
        tl.store(output_ptr + row, row_mean)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise mean of a buffer indexed as ``m * N + n``.

    Each program handles BLOCK_M consecutive rows: it accumulates the row
    contents in BLOCK_N-wide column tiles, reduces along the columns, divides
    by ``N`` and writes one value per row to ``out + m``. Rows at or beyond
    ``M`` are masked out of both the loads and the store.
    """
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Shape (BLOCK_M, 1): the rows this program is responsible for.
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptrs = inp + rows * N
    out_ptrs = out + rows
    row_mask = rows < M

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Element-wise AND of the broadcast masks; `&` matches the other
        # kernels in this file and avoids Python's `and` on tensors.
        mask = row_mask & col_mask

        tile = tl.load(row_ptrs + cols, mask, other=0).to(cdtype)
        acc += tile
    row_sum = tl.sum(acc, axis=1)[:, None]
    # Named row_mean so it does not shadow the module-level `mean` function.
    row_mean = row_sum / N
    tl.store(out_ptrs, row_mean, row_mask)
def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Compute the mean of ``inp`` over ``dim``.

    Dispatch:
      * ``dim`` is ``None`` or empty, or ``inp`` is 0-dim: reduce over every
        element via ``mean``.
      * exactly one dim: ``inp`` is made contiguous and viewed as (M, N, K)
        around the reduced axis; the non-inner kernel runs when ``K > 1``,
        otherwise the inner kernel.
      * several dims: ``dim_compress`` rearranges the input (presumably
        moving the reduced dims innermost - verify against its definition)
        and the row-wise kernel reduces an (M, N) view.

    Args:
        inp: Input tensor.
        dim: An int, an iterable of ints, or ``None`` / an empty iterable
            for a full reduction. Negative values wrap via ``d % inp.ndim``.
        keepdim: When true, reduced dims are kept with size 1.
        dtype: Dtype of the result; defaults to ``inp.dtype``. ``torch.bool``
            is replaced by ``torch.int64``, and the input is cast to match.
        out: Optional preallocated result tensor for the dimension-wise
            paths. It is not consulted on the full-reduction path.

    Returns:
        The tensor of means.

    Raises:
        TypeError: If ``dim`` is neither an int, an iterable, nor ``None``.
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    # -------- normalize dim to a list of ints --------
    if dim is None:
        # The signature defaults to None; treat it as "reduce everything"
        # instead of letting list(None) raise.
        dim = []
    elif isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            ) from None

    # Full reduction. Any empty iterable counts (not only a literal []), and
    # a 0-dim input has nothing to index, so `d % inp.ndim` must be avoided.
    if not dim or inp.ndim == 0:
        reduced = mean(inp, dtype=dtype)
        if keepdim:
            reduced = torch.reshape(reduced, [1] * inp.ndim)
        return reduced

    dim = [d % inp.ndim for d in dim]
    shape = list(inp.shape)

    if inp.numel() == 0:
        # The kernel paths divide by M and N on the host, which would raise
        # ZeroDivisionError here. Any output element that exists is the mean
        # of zero values, i.e. NaN.
        for d in dim:
            shape[d] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)
        if out.numel() > 0:
            out.fill_(float("nan"))
        if not keepdim:
            out = out.squeeze(dim=dim[0] if len(dim) == 1 else dim)
        return out

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # Product of the dims before dim0; initializer 1 covers dim0 == 0.
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        # Product of the dims after dim0.
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out

    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        mean_dim_kernel[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Mean of ``inp`` over ``dim``; logs and forwards to ``mean_dim_comm``.

    All arguments are passed through unchanged, and no ``out`` tensor is
    supplied.
    """
    logger.debug("GEMS MEAN_DIM (wrapper)")
    result = mean_dim_comm(inp, dim, keepdim, dtype=dtype)
    return result