Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sum.py: 0%
151 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, libtuner
11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM, cfggen_reduce_op
12from .zeros import zero_
# Module logger, registered as a child of the package-wide "flag_gems" logger.
_gems_logger = logging.getLogger("flag_gems")
logger = _gems_logger.getChild(__name__.lstrip("."))
# Full reduction: sums the first M elements of the flat buffer `inp` into the
# single element that `out` points to.
#
# Each program accumulates a BLOCK_SIZE-wide vector while stepping over `inp`
# in strides of num_programs * BLOCK_SIZE (a grid-stride loop). It then reduces
# that vector to one value and adds it into `out` with tl.atomic_add. Partial
# results are only ever *added*, so `out` must hold zero before the launch.
# `sum` allocates it with torch.zeros. reset_to_zero=["out"] presumably
# re-clears it between autotuning trials (TODO confirm libtuner semantics).
@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def sum_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Accumulate half-precision inputs in float32; other dtypes accumulate in
    # their own element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    # Widen the start offset to int64, presumably so offsets cannot overflow
    # int32 for very large M.
    block_start = block_start.to(tl.int64)
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        # Masked-off lanes read 0 so they do not affect the sum.
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        _tmp = inp_val + _tmp

    sum_val = tl.sum(_tmp)
    # Combine this program's partial sum with the others.
    tl.atomic_add(out, sum_val)
# Row-wise reduction: treats `inp` as a row-major M x N buffer and writes
# out[r] = sum over c < N of inp[r * N + c] for every row r < M.
#
# Rows are processed in tiles of BLOCK_M. Program p handles tiles p,
# p + num_programs, p + 2 * num_programs, ... until all cdiv(M, BLOCK_M) tiles
# are done, so the launch grid may be smaller than the number of tiles.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sum"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Accumulation dtype: float32 for half-precision inputs, int32 for bool
    # (int1) inputs, otherwise the input's own element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    elif tl.constexpr(inp.dtype.element_ty == tl.int1):
        cdtype = tl.int32
    else:
        cdtype = inp.dtype.element_ty
    # Tile bookkeeping is done in uint64, presumably to keep row * N offsets
    # from overflowing for large inputs.
    prog_num = tl.num_programs(0).to(tl.uint64)
    sub_pid = tl.program_id(0).to(tl.uint64)
    task_num = tl.cdiv(M, BLOCK_M).to(tl.uint64)
    while sub_pid < task_num:
        # Map the program id to the row of inp it should compute.
        pid = sub_pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        inp_ = inp + pid * N
        out_ = out + pid
        row_mask = pid < M

        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            # NOTE(review): `and` on tensors relies on Triton lowering it to an
            # elementwise logical_and; `&` is the more common spelling.
            mask = row_mask and col_mask

            # Out-of-range elements load as 0 and do not change the sum.
            a = tl.load(inp_ + cols, mask, other=0).to(cdtype)
            _sum += a
        sum = tl.sum(_sum, axis=1)[:, None]
        tl.store(out_, sum, row_mask)
        sub_pid += prog_num
def sum(inp, *, dtype=None):
    """Sum every element of ``inp`` into a 0-d tensor.

    Args:
        inp: Input tensor of any shape; it is made contiguous first.
        dtype: Output dtype. Defaults to ``inp.dtype``. When defaulted and the
            input is bool, the input is converted to int32 and the result is
            int32.

    Returns:
        A 0-d tensor of dtype ``dtype`` on ``inp``'s device.
    """
    logger.debug("GEMS_CAMBRICON SUM")
    inp = inp.contiguous()
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int32)
            dtype = torch.int32

    def grid(meta):
        # One program per BLOCK_SIZE chunk, capped at the number of cores.
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # sum_kernel_1 atomically adds partial sums, so the target starts at zero.
    result = torch.zeros([], dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, result, numel)
    return result.to(dtype)
def sum_out(inp, *, dtype=None, out):
    """Sum every element of ``inp`` into the caller-provided ``out`` tensor.

    Args:
        inp: Input tensor of any shape; it is made contiguous first.
        dtype: Dtype of the returned tensor. Defaults to ``inp.dtype``. When
            defaulted and the input is bool, the input is converted to int32.
        out: Destination tensor. The kernel adds its result at ``out``'s first
            element, so ``out`` is cleared before the launch.

    Returns:
        ``out.to(dtype)``.
    """
    logger.debug("GEMS_CAMBRICON SUM_OUT")
    # sum_kernel_1 indexes the input with flat offsets, so it needs a dense,
    # contiguous buffer (``sum`` does the same).
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int32)
            dtype = torch.int32

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    # The kernel accumulates with atomic_add, so ``out`` must start at zero.
    # Otherwise the sum would be added onto whatever ``out`` already held.
    zero_(out)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, M)
    return out.to(dtype)
def _copy_into_out(result, out):
    """Resize ``out`` to ``result``'s shape, copy ``result`` into it, and return ``out``."""
    out.resize_(result.shape)
    out.copy_(result)
    return out


def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Sum ``inp`` over the axes in ``dim``; shared by ``sum_dim`` and ``sum_dim_out``.

    Args:
        inp: Input tensor.
        dim: Axis, or list/tuple of axes, to reduce; negative values wrap.
            ``None`` reduces every axis through ``torch.sum``. An empty
            list/tuple reduces every axis through this module's ``sum``.
        keepdim: Keep reduced axes as size-1 dimensions.
        dtype: Output dtype. Defaults to ``inp.dtype``, or ``torch.int64`` for
            bool inputs.
        out: Optional destination. When given, it is resized to the result
            shape, filled, and returned on every path.

    Returns:
        The reduced tensor (``out`` itself when ``out`` is given).
    """
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        # Previously ``out`` was ignored here, leaving it unwritten.
        return result if out is None else _copy_into_out(result, out)

    # Accept a bare int axis, as sum_dim's empty-input path already does.
    if isinstance(dim, int):
        dim = [dim]

    # len() also catches an empty tuple, which ``dim == []`` did not.
    if len(dim) == 0:
        result = sum(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result if out is None else _copy_into_out(result, out)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N = number of elements folded into each output value. ``shape`` becomes
    # the keepdim output shape.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    out_provided = out is not None
    if out_provided:
        dim_set = set(dim)
        if keepdim:
            out.resize_(shape)
        else:
            out.resize_([s for i, s in enumerate(shape) if i not in dim_set])
    else:
        out = torch.empty(shape, dtype=dtype, device=inp.device)

    if inp.numel() == 0:
        # Every reduced slice is empty, so each sum is 0. This also avoids
        # ``numel // N`` with N == 0 and a launch over an empty grid.
        zero_(out)
    else:
        # Move the reduced axes last so the kernel sees a row-major M x N view.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        grid = lambda meta: (
            min(triton.cdiv(M, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),
        )
        with torch_device_fn.device(inp.device):
            sum_kernel[grid](inp, out, M, N)

    if not keepdim and not out_provided:
        for d in sorted(dim, reverse=True):
            out = out.squeeze(dim=d)
    return out
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim``.

    A zero-element input is answered directly with a zero-filled tensor of the
    reduced shape, matching PyTorch. Every other input goes through
    ``sum_dim_comm``.
    """
    logger.debug("GEMS_CAMBRICON SUM DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    # Zero-element input: build the reduced shape and fill it with zeros.
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    ndim = inp.ndim
    reduce_all = dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0)
    if reduce_all:
        out_shape = [1] * ndim if keepdim else []
    else:
        axes = dim if isinstance(dim, (list, tuple)) else [dim]
        out_shape = list(inp.shape)
        if keepdim:
            for axis in axes:
                out_shape[axis % ndim] = 1
        else:
            # Remove the highest index first so earlier indices stay valid.
            for axis in sorted(axes, key=lambda a: a % ndim, reverse=True):
                del out_shape[axis % ndim]

    out = torch.empty(out_shape, dtype=dtype, device=inp.device)
    zero_(out)
    return out
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """``out=`` variant of ``sum_dim``: forwards the caller's ``out`` to ``sum_dim_comm``."""
    logger.debug("GEMS_CAMBRICON SUM_DIM_OUT")
    reduced = sum_dim_comm(inp, dim, keepdim, dtype=dtype, out=out)
    return reduced