Coverage for src/flag_gems/ops/mean.py: 46%
223 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import dim_compress, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the full reduction: one partial sum per program.

    Each program loads up to ``BLOCK_SIZE`` consecutive elements of ``inp``
    (positions at or beyond ``M`` are masked and read as 0), sums them and
    stores the result at ``mid + program_id``.
    """
    # float16 / bfloat16 inputs are accumulated in float32; every other
    # element type is accumulated in its own dtype.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    chunk = tl.load(inp + idx, mask=in_bounds, other=0).to(acc_dtype)
    partial = tl.sum(chunk)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second pass of the full reduction: fold partial sums into the mean.

    Loads the first ``MID_SIZE`` entries of ``mid`` (lanes beyond that are
    masked and read as 0), sums them, divides the total by ``M`` and stores
    the scalar result at ``out``.
    """
    # float16 / bfloat16 partial sums are accumulated in float32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < MID_SIZE
    partials = tl.load(mid + lanes, mask=valid, other=0).to(acc_dtype)
    total = tl.sum(partials)
    # The divisor is M (the caller passes the input's element count),
    # not the number of partial sums.
    tl.store(out, total / M)
def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    The reduction runs in two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block into a temporary buffer, then ``mean_kernel_2``
    sums that buffer and divides by the element count.

    Args:
        inp: Input tensor; it is made contiguous before being reduced.
        dtype: Dtype of the returned tensor. Defaults to ``inp.dtype``.

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``. For an empty input the
        result is NaN (matching ``torch.mean`` on floating-point tensors).
    """
    logger.debug("GEMS MEAN")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    if M == 0:
        # With M == 0 the block size below would be 0 and triton.cdiv would
        # divide by zero; the mean of nothing is NaN.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    inp = inp.contiguous()
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Keep partial sums in float32 for half-precision outputs: a per-block
    # sum can exceed the float16 range (or lose bits in bfloat16) even when
    # the final mean is representable. Both kernels already accumulate
    # half-precision values in float32, and tl.store casts to the
    # destination pointer's element type.
    if dtype in (torch.float16, torch.bfloat16):
        mid_dtype = torch.float32
    else:
        mid_dtype = dtype

    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out
@libentry()
@triton.jit
def mean_dim_kernel_non_inner_vec(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_K: tl.constexpr,  # rows of the per-program K tile
    VEC_SIZE: tl.constexpr,  # consecutive K elements per tile row
):
    """Mean over the middle axis of an input addressed as ``[M, N, K]``.

    Program ``(pid_m, pid_k)`` handles ``BLOCK_SIZE_K * VEC_SIZE`` consecutive
    K positions of slab ``pid_m``: it walks the N axis one step at a time,
    accumulates the loaded values, divides by ``N`` and writes the result to
    ``output_ptr`` addressed as ``[M, K]``. ``M`` is not read by the kernel
    body.
    """
    # float16 / bfloat16 inputs are accumulated in float32; other element
    # types are accumulated in their own dtype.
    elem_ty = input_ptr.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(
        elem_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    slab_id = tle.program_id(0)
    k_block = tle.program_id(1)

    # [BLOCK_SIZE_K, VEC_SIZE] grid of K indices: tile row r covers VEC_SIZE
    # consecutive indices starting at k_start + r * VEC_SIZE.
    k_start = k_block * BLOCK_SIZE_K * VEC_SIZE
    k_rows = tl.arange(0, BLOCK_SIZE_K)[:, None] * VEC_SIZE
    k_cols = tl.arange(0, VEC_SIZE)[None, :]
    k_idx = k_start + k_rows + k_cols
    in_range = k_idx < K

    total = tl.zeros((BLOCK_SIZE_K, VEC_SIZE), dtype=acc_ty)
    slab_base = slab_id * N * K

    for n in range(N):
        src = slab_base + n * K + k_idx
        vals = tl.load(input_ptr + src, mask=in_range, other=0.0)
        total += vals.to(acc_ty)

    dst = slab_id * K + k_idx
    tl.store(output_ptr + dst, total / N, mask=in_range)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the middle axis of an input addressed as ``[M, N, K]``.

    Program ``(pid_m, pid_k)`` reduces ``TILE_K`` columns of slab ``pid_m``
    along N and writes the means to ``output_ptr`` addressed as ``[M, K]``.
    When ``ONE_TILE_PER_CTA`` is set, a single ``[TILE_N, TILE_K]`` load is
    reduced; otherwise N is walked in steps of ``TILE_N`` and the tiles are
    accumulated element-wise before the final reduction.
    """
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)

    cols = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]
    col_ok = cols < K
    slab_base = pid_m * N * K

    if ONE_TILE_PER_CTA:
        rows = tl.arange(0, TILE_N)[:, None]
        src = slab_base + rows * K + cols
        tile = tl.load(
            input_ptr + src, mask=(rows < N) & col_ok, other=0
        ).to(acc_dtype)
        # keep_dims keeps the result at [1, TILE_K], matching `cols`.
        col_sum = tl.sum(tile, axis=0, keep_dims=True)
    else:
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_dtype)
        for start_n in range(0, N, TILE_N):
            rows = start_n + tl.arange(0, TILE_N)[:, None]
            src = slab_base + rows * K + cols
            tile = tl.load(
                input_ptr + src, mask=(rows < N) & col_ok, other=0
            ).to(acc_dtype)
            acc += tile
        col_sum = tl.sum(acc, axis=0, keep_dims=True)

    # Both paths finish identically: divide by N and store the valid columns.
    tl.store(output_ptr + (pid_m * K + cols), col_sum / N, mask=col_ok)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the last axis of an input addressed as ``[M, N]``.

    Program ``pid_m`` reduces row ``pid_m`` and stores one scalar at
    ``output_ptr + pid_m``. When ``ONE_TILE_PER_CTA`` is set the row is read
    with a single ``TILE_N``-wide load; otherwise it is walked in steps of
    ``TILE_N`` and accumulated lane-wise before the final reduction.
    """
    # NOTE(review): the tile heuristics are fetched under the "softmax_inner"
    # key — confirm that sharing this config with softmax is intentional.
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    row_base = pid_m * N

    if ONE_TILE_PER_CTA:
        cols = tl.arange(0, TILE_N)
        vals = tl.load(
            input_ptr + (row_base + cols), mask=cols < N, other=0
        ).to(acc_dtype)
        row_sum = tl.sum(vals, axis=0)
    else:
        acc = tl.zeros([TILE_N], dtype=acc_dtype)
        for start_n in range(0, N, TILE_N):
            cols = start_n + tl.arange(0, TILE_N)
            vals = tl.load(
                input_ptr + (row_base + cols), mask=cols < N, other=0
            ).to(acc_dtype)
            acc += vals
        row_sum = tl.sum(acc, axis=0)

    # Both paths finish identically: divide by N and store one value per row.
    tl.store(output_ptr + pid_m, row_sum / N)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise mean of an input addressed as ``[M, N]``.

    Each program handles ``BLOCK_M`` rows: it walks the N axis in steps of
    ``BLOCK_N``, accumulates a ``[BLOCK_M, BLOCK_N]`` tile, reduces it along
    axis 1, divides by ``N`` and stores one value per row at ``out + row``.
    Rows at or beyond ``M`` and columns at or beyond ``N`` are masked.
    """
    # float16 / bfloat16 inputs are accumulated in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptrs = inp + rows * N
    out_ptrs = out + rows
    row_mask = rows < M

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Element-wise `&` (not Python `and`) combines the two block masks,
        # consistent with the other mean kernels in this file.
        mask = row_mask & col_mask

        acc += tl.load(row_ptrs + cols, mask, other=0).to(cdtype)

    row_sum = tl.sum(acc, axis=1)[:, None]
    tl.store(out_ptrs, row_sum / N, row_mask)
def _shape_product(dims):
    """Return the product of ``dims`` (1 for an empty sequence)."""
    return reduce(lambda x, y: x * y, dims, 1)


def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Compute the mean of ``inp`` over ``dim``.

    Args:
        inp: Input tensor. A bool ``dtype`` is converted to int64 first.
        dim: An int, an iterable of ints, or ``None`` / an empty iterable.
            ``None`` and empty iterables reduce over every element. Negative
            values are wrapped with ``d % inp.ndim``.
        keepdim: If true, reduced dimensions are kept with size 1.
        dtype: Dtype of the result. Defaults to ``inp.dtype``.
        out: Optional pre-allocated output with the keepdim shape. It is not
            used on the full-reduction path.

    Returns:
        The tensor of means; reduced dimensions are squeezed unless
        ``keepdim`` is true.

    Raises:
        TypeError: If ``dim`` is neither an int, an iterable, nor ``None``.
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    # -------- normalize dim to a list of ints --------
    if dim is None:
        # The declared default must mean "reduce everything", not crash.
        dim = []
    elif isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError as err:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            ) from err
    # -------------------------------------------------

    # Full reduction: no dims given (None, [], (), ...) or a 0-d input, for
    # which `d % inp.ndim` below would divide by zero.
    if not dim or inp.ndim == 0:
        result = mean(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    if len(dim) == 1:
        dim0 = dim[0]
        N = shape[dim0]  # reduction length
        # Sizes before / after the reduced axis, taken from the shape so that
        # zero-length dimensions cannot cause a division by zero.
        M = _shape_product(shape[:dim0])
        K = _shape_product(shape[dim0 + 1:])
        inp = inp.contiguous()
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        if N == 0:
            # Mean over a zero-length axis is NaN; no kernel launch needed.
            out.fill_(float("nan"))
        elif M * K > 0:  # an empty output has nothing to compute
            with torch_device_fn.device(inp.device):
                if K >= 1024:
                    if inp.dtype in (torch.float16, torch.bfloat16):
                        VEC_SIZE = 8
                        BLOCK_SIZE_K = 128
                    else:
                        VEC_SIZE = 1
                        BLOCK_SIZE_K = min(triton.next_power_of_2(K), 512)
                    grid = (M, triton.cdiv(K, BLOCK_SIZE_K * VEC_SIZE))
                    mean_dim_kernel_non_inner_vec[grid](
                        out,
                        inp,
                        M,
                        N,
                        K,
                        BLOCK_SIZE_K=BLOCK_SIZE_K,
                        VEC_SIZE=VEC_SIZE,
                        num_warps=8 if BLOCK_SIZE_K <= 128 else 16,
                    )
                elif K > 1:
                    grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                    mean_dim_kernel_non_inner[grid](
                        out,
                        inp,
                        M,
                        N,
                        K,
                    )
                else:
                    grid = (M, 1, 1)
                    mean_dim_kernel_inner[grid](
                        out,
                        inp,
                        M,
                        N,
                    )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out

    # Several dims: dim_compress rearranges the input so the kernel can treat
    # it as [M, N] — presumably by moving the reduced dims innermost; confirm
    # against flag_gems.utils.dim_compress.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    # After the loop `shape` holds only the kept extents, so their product is
    # the number of output elements (and is safe when N == 0).
    M = _shape_product(shape)
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=inp.device)

    if N == 0:
        out.fill_(float("nan"))
    elif M > 0:
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Thin wrapper that forwards to ``mean_dim_comm`` without an ``out`` tensor."""
    logger.debug("GEMS MEAN_DIM (wrapper)")
    result = mean_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype)
    return result