Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sum.py: 0%
280 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
3from functools import reduce
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.zeros import zero_
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import dim_compress, libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
# Module-level logger created as a child of the "flag_gems" logger, so records
# logged here propagate to that logger's handlers. lstrip(".") removes any
# leading dots from __name__ before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # First pass of the full reduction: program `block_id` sums its
    # BLOCK_SIZE-wide slice of the flat input (masked past M) and stores one
    # partial result at mid[block_id]. float16/bfloat16 accumulate in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = inp.dtype.element_ty

    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    vals = tl.load(inp + idx, mask=idx < M, other=0).to(acc_ty)
    tl.store(mid + block_id, tl.sum(vals))
@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Second pass of the full reduction: a single program loads the first
    # `mid_size` partials (the callers in this file pass BLOCK_MID as
    # next_power_of_2(mid_size)) and stores their total through `out`.
    # float16/bfloat16 partials are summed in float32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = mid.dtype.element_ty

    idx = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + idx, mask=idx < mid_size, other=0).to(acc_ty)
    tl.store(out, tl.sum(partials))
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("sum"), key=["M", "N"])
@triton.jit
def sum_kernel_dim0(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Reduces the leading axis of a buffer addressed as inp + n * M + m:
    # out[m] = sum over n in [0, N). Program i owns the BLOCK_M positions
    # m in [i * BLOCK_M, (i + 1) * BLOCK_M) and walks N in BLOCK_N-row steps.
    # float16/bfloat16 accumulate in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = inp.dtype.element_ty

    m_idx = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[None, :]
    m_ok = m_idx < M
    col_base = inp + m_idx

    acc = tl.zeros([BLOCK_N, BLOCK_M], dtype=acc_ty)
    for n_start in range(0, N, BLOCK_N):
        n_idx = n_start + tl.arange(0, BLOCK_N)[:, None]
        tile = tl.load(col_base + n_idx * M, m_ok & (n_idx < N), other=0)
        acc += tile.to(acc_ty)
    totals = tl.sum(acc, axis=0)[None, :]
    tl.store(out + m_idx, totals, m_ok)
def sum(inp, *, dtype=None):
    """Reduce every element of ``inp`` to a 0-d tensor.

    Two-pass reduction: ``sum_kernel_1`` writes one partial sum per block of
    ``next_power_of_2(ceil(sqrt(numel)))`` elements into ``mid``, then
    ``sum_kernel_2`` folds the partials into the scalar result.

    Args:
        inp: input tensor; made contiguous because the kernels index it as a
            flat buffer.
        dtype: result dtype. Defaults to ``inp.dtype``; ``torch.bool`` is
            promoted to ``torch.int64`` (and the input converted to match).

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``. Empty input yields 0.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    out = torch.empty([], dtype=dtype, device=inp.device)
    if M == 0:
        # The sum over no elements is 0. Without this guard block_size would
        # be next_power_of_2(0) == 0 and triton.cdiv(0, 0) would divide by 0.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
def sum_out(inp, *, dtype=None, out):
    """Full reduction of ``inp`` written into the caller-provided ``out``.

    Same two-pass scheme as ``sum``; ``sum_kernel_2`` stores the final total
    directly through ``out``.

    Args:
        inp: input tensor. Made contiguous because ``sum_kernel_1`` reads it
            as a flat buffer (``inp + offset``) and ignores strides, so a
            strided view would otherwise be summed in the wrong order/extent.
        dtype: dtype of the intermediate partial sums; defaults to
            ``inp.dtype``. ``torch.bool`` is promoted to ``torch.int64``.
        out: destination tensor; written by the kernel (zero-filled for an
            empty input) and returned.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SUM_OUT")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64
    if M == 0:
        # The sum over no elements is 0; also avoids triton.cdiv(0, 0) below.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
@libentry()
@triton.heuristics(runtime.get_heuristic_config("sum_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Reduces the middle axis of a buffer addressed as input[m, n, k]
    # (offset m * N * K + n * K + k) into output[m, k] (offset m * K + k).
    # Grid axis 0 selects m; grid axis 1 selects a TILE_K-wide slab of k.
    # float16/bfloat16 accumulate in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    m_idx = ext.program_id(0)
    k_idx = ext.program_id(1) * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_ok = k_idx < K
    row_base = m_idx * N * K
    out_ptrs = output_ptr + (m_idx * K + k_idx)

    if ONE_TILE_PER_CTA:
        # A single TILE_N-high tile covers the whole n axis.
        n_idx = tl.arange(0, TILE_N)[:, None]
        tile = tl.load(
            input_ptr + (row_base + n_idx * K + k_idx),
            mask=(n_idx < N) & k_ok,
            other=0,
        ).to(acc_ty)
        tl.store(out_ptrs, tl.sum(tile, axis=0, keep_dims=True), mask=k_ok)
    else:
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_ty)
        # Specializing this loop did not improve performance when tested.
        for n_start in range(0, N, TILE_N):
            n_idx = n_start + tl.arange(0, TILE_N)[:, None]
            acc += tl.load(
                input_ptr + (row_base + n_idx * K + k_idx),
                mask=(n_idx < N) & k_ok,
                other=0,
            ).to(acc_ty)
        tl.store(out_ptrs, tl.sum(acc, axis=0, keep_dims=True), mask=k_ok)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("sum_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Row reduction: program `row` sums the N contiguous elements starting at
    # input_ptr + row * N and stores the result at output_ptr + row.
    # float16/bfloat16 accumulate in float32.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    row = ext.program_id(0)
    row_start = row * N
    if ONE_TILE_PER_CTA:
        # One TILE_N-wide tile covers the whole row.
        cols = tl.arange(0, TILE_N)
        vals = tl.load(input_ptr + (row_start + cols), mask=cols < N, other=0)
        tl.store(output_ptr + row, tl.sum(vals.to(acc_ty), axis=0))
    else:
        acc = tl.zeros([TILE_N], dtype=acc_ty)
        for col_start in range(0, N, TILE_N):
            cols = col_start + tl.arange(0, TILE_N)
            vals = tl.load(input_ptr + (row_start + cols), mask=cols < N, other=0)
            acc += vals.to(acc_ty)
        tl.store(output_ptr + row, tl.sum(acc, axis=0))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row reduction over a buffer addressed as inp + row * N + col: each
    # program sums BLOCK_M rows (masked past M), BLOCK_N columns at a time,
    # and stores one value per row at out + row.
    # float16/bfloat16 accumulate in float32.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Element-wise mask combination with `&`, as in the other kernels of
        # this file, rather than Python `and` on tensor operands.
        mask = row_mask & col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        acc += a
    total = tl.sum(acc, axis=1)[:, None]
    tl.store(out, total, row_mask)
def _sum_dim_write_out(result, out):
    """Return ``result``, first copying it into ``out`` when one is supplied.

    Used by the ``sum_dim_comm`` paths that compute into a fresh tensor, so
    they honour ``out=`` like the kernel paths do. ``out`` is resized only
    when its element count differs from ``result``'s.
    """
    if out is None:
        return result
    if out.numel() != result.numel():
        out.resize_(result.shape)
    out.copy_(result.reshape(out.shape))
    return out


def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Sum ``inp`` over ``dim``; shared by ``sum_dim`` and ``sum_dim_out``.

    Args:
        inp: input tensor.
        dim: ``None``, an int, or a list/tuple of ints (negative values wrap).
            ``None`` or an empty list/tuple reduces every dimension.
        keepdim: keep reduced dimensions as size 1.
        dtype: result dtype; defaults to ``inp.dtype``. ``torch.bool`` is
            promoted to ``torch.int64``.
        out: optional destination tensor.

    Returns:
        The reduced tensor: ``out`` when provided (a squeezed view of it on
        the kernel paths when ``keepdim`` is False), otherwise a new tensor.
    """
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    if inp.dtype is torch.bool:
        # Like `sum`, upcast boolean input before launching so the kernels
        # never accumulate in a 1-bit type.
        inp = inp.to(dtype)

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return _sum_dim_write_out(result, out)

    # Accept a bare int and tuples as well as lists.
    dim = [dim] if isinstance(dim, int) else list(dim)
    if len(dim) == 0 or inp.ndim == 0:
        # An empty dim list reduces everything; a 0-d tensor has no axis to
        # keep, so summing over its only dim is also the full reduction.
        result = sum(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return _sum_dim_write_out(result, out)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    if check_dim0(inp, dim):
        return _sum_dim_write_out(sum_dim0(inp, dim, keepdim, dtype), out)

    if len(dim) == 1:
        # View inp as [M, N, K] with N the reduced axis.
        dim = dim[0]
        N = inp.shape[dim]
        M = reduce(lambda x, y: x * y, shape[:dim], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
    else:
        # dim_compress moves the reduced dims to the end so each output
        # element reduces N contiguous values.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
def check_dim0(inp, dim):
    """Tell whether the reduction over ``dim`` can use ``sum_dim0``.

    Returns True only when the reduction is partial (not every dimension
    listed, and not every dimension collapsing to size 1), and every
    dimension before the largest reduced one has size <= 1 once the reduced
    dims are set to 1. In other words, the reduced dims form a leading block
    of the layout.

    Args:
        inp: object exposing ``.shape`` (a tensor).
        dim: non-empty list of already-normalized dimension indices.
    """
    sizes = list(inp.shape)
    if len(dim) == len(sizes):
        return False
    for d in dim:
        sizes[d] = 1
    if all(s == 1 for s in sizes):
        return False
    return not any(sizes[i] > 1 for i in range(max(dim)))
def sum_dim0(inp, dim, keepdim, dtype):
    """Sum over a leading block of dimensions with ``sum_kernel_dim0``.

    Called by ``sum_dim_comm`` only when ``check_dim0`` accepts ``dim``: every
    kept dimension in front of the last reduced one has size 1. A contiguous
    ``inp`` is then laid out as ``[N, M]``, where N is the product of the
    reduced sizes and M is the number of outputs.

    Args:
        inp: input tensor.
        dim: normalized (non-negative) dimensions to reduce.
        keepdim: keep reduced dimensions as size 1.
        dtype: dtype of the returned tensor.

    Returns:
        A new tensor holding the sums.
    """
    # The kernel addresses element (n, m) as inp + n * M + m and ignores
    # strides, so a non-contiguous view (e.g. a transpose) must be densified.
    inp = inp.contiguous()
    shape = list(inp.shape)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N
    out = torch.empty(shape, dtype=dtype, device=inp.device)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel_dim0[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim`` (int, list/tuple of ints, or ``None``).

    Non-empty inputs are delegated to ``sum_dim_comm``. For an empty input
    (``numel() == 0``) this returns a zero-filled tensor of the reduced
    shape. ``None`` or an empty list/tuple means every dimension is reduced.
    The dtype defaults to ``inp.dtype``, with ``torch.bool`` promoted to
    ``torch.int64``.
    """
    logger.debug("GEMS SUM_DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64

    is_seq = isinstance(dim, (list, tuple))
    if dim is None or (is_seq and len(dim) == 0):
        out_shape = [1] * inp.ndim if keepdim else []
    else:
        out_shape = list(inp.shape)
        axes = [d % inp.ndim for d in (dim if is_seq else [dim])]
        if keepdim:
            for axis in axes:
                out_shape[axis] = 1
        else:
            # Remove from the back so earlier indices stay valid.
            for axis in sorted(axes, reverse=True):
                out_shape.pop(axis)

    result = torch.empty(out_shape, dtype=dtype, device=inp.device)
    zero_(result)
    return result
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Out-variant of ``sum_dim``: forwards everything, including ``out``, to
    ``sum_dim_comm`` and returns its result."""
    logger.debug("GEMS SUM_DIM_OUT")
    return sum_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype, out=out)