Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Child of the shared "flag_gems" logger, named after this module with any
# leading dots stripped from the module name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# torch.all tests whether every element of the input evaluates to True. For a
# non-bool dtype that means "every element is non-zero", so the Triton kernels
# in this module simply compare each element against zero.

# Tuning constants consumed by the heuristics and launch code in this module.
cluster_num = 12  # divisor applied to M when sizing BLOCK_M
core_num = 64  # cap on BLOCK_M before it is rounded up to a power of two
buf_len_per_core = 2048  # cap on BLOCK_SIZE for the global reduction kernel
vector_size = 16  # not referenced by the code visible in this module
def heur_m_block_size(args):
    """Heuristic for ``BLOCK_M``, the number of rows handled per program.

    Args:
        args: Mapping of kernel launch arguments; only ``args["M"]`` is read.

    Returns:
        ``ceil(M / cluster_num)`` capped at ``core_num`` and then rounded up
        to a power of two via ``triton.next_power_of_2``.
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped_rows = min(rows_per_cluster, core_num)
    return triton.next_power_of_2(capped_rows)
def heur_n_block_size(args):
    """Heuristic for ``BLOCK_N``, the number of columns loaded per iteration.

    Args:
        args: Mapping of kernel launch arguments; only ``args["N"]`` is read.

    Returns:
        ``N`` capped at 512 and then rounded up to a power of two via
        ``triton.next_power_of_2``.
    """
    width = args["N"]
    if width > 512:
        width = 512
    return triton.next_power_of_2(width)
@triton.jit
def reduce_all(a, b):
    """Combine function handed to ``tl.reduce``: logical AND of two operands."""
    return a and b
@libentry()
@triton.jit
def all_global_kernel(
    inp,
    out,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce the first ``n_elements`` values at ``inp`` to one flag at ``out``.

    The original note on this kernel states that a C++ handler replaces it
    with ``api::all<T,bool>``; the body below is the Triton fallback, in which
    one program walks the input in chunks of ``BLOCK_SIZE``.
    """
    # One running flag per lane, all starting out as "true".
    acc = tl.full([BLOCK_SIZE], value=1, dtype=tl.int1)
    for start in range(0, n_elements, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        # Masked-off lanes load 1.0 so they can never turn a flag false.
        chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
        acc = acc and (chunk != 0)
    # Fold the per-lane flags into a single value and write it out.
    verdict = tl.reduce(acc, axis=0, combine_fn=reduce_all)
    tl.store(out, verdict)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise "all non-zero" reduction.

    ``inp`` is addressed as ``M`` rows of ``N`` consecutive elements (element
    ``(r, c)`` is read at ``inp + r * N + c``) and one flag per row is written
    at ``out + r``. Each program handles ``BLOCK_M`` rows and sweeps the
    columns in steps of ``BLOCK_N``.
    """
    # Rows owned by this program, as a column vector so it broadcasts
    # against the column offsets below.
    pid = ext.program_id(0)
    row_idx = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # One running flag per (row, lane), all starting out as "true".
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for col_start in range(0, N, BLOCK_N):
        col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        valid = row_ok and col_ok

        # Out-of-range positions load 1.0 so they never turn a flag false.
        tile = tl.load(row_base + col_idx, valid, other=1.0)
        acc = acc and (tile != 0)
    # Fold the lanes of each row together and store one flag per valid row.
    row_result = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out_ptrs, row_result[:, None], row_ok)
def all(inp):
    """Return a 0-dim bool tensor that is True iff every element is non-zero.

    NOTE: this module-level function shadows the builtin ``all`` for every
    other function defined in this file.

    Args:
        inp: Input tensor of any shape.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # A reduction over no elements is vacuously True (torch.all agrees);
        # answer directly instead of deriving a block size from zero elements.
        return torch.ones([], dtype=torch.bool, device=inp.device)
    # The kernel addresses elements as ``inp + offset``, i.e. linearly from the
    # data pointer, so a strided view must be materialised first. This is a
    # no-op for tensors that are already contiguous.
    inp = inp.contiguous()
    # BLOCK_SIZE must fit in XPU per-core local buffer so the Triton fallback
    # kernel always compiles. The C++ handler (api::all<T,bool>) ignores this
    # value and handles any n_elements internally.
    BLOCK_SIZE = min(triton.next_power_of_2(n_elements), buf_len_per_core)
    out = torch.empty([], dtype=torch.bool, device=inp.device)
    with torch_device_fn.device(inp.device):
        all_global_kernel[(1, 1)](
            inp, out, n_elements, BLOCK_SIZE, buffer_size_limit=2048
        )
    return out
def all_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` along one dimension with "all non-zero" semantics.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce, in ``[-inp.ndim, inp.ndim)``. ``None``
            reduces over every element via ``all``.
        keepdim: Keep the reduced dimension(s) with size 1 when True.

    Returns:
        A ``torch.bool`` tensor holding the reduction result.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ALL DIM")
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Either the reduced dimension is empty (every output is vacuously
        # True) or another dimension is empty (the output itself is empty).
        # Handling both here avoids ``numel() // N`` with N == 0 and a kernel
        # launch with nothing to process.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if N == 1:
            # N==1: each row has a single element; avoid kernel dispatch for
            # trivial case that some hardware configs cannot handle.
            out = (inp.reshape(M) != 0).reshape(shape)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` along several dimensions with "all non-zero" semantics.

    Args:
        inp: Input tensor.
        dim: ``None`` or a single int (both delegated to ``all_dim``), or an
            iterable of ints each in ``[-inp.ndim, inp.ndim)``.
        keepdim: Keep the reduced dimensions with size 1 when True.

    Returns:
        A ``torch.bool`` tensor holding the reduction result.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    # Validate each entry individually. Asserting on a generator expression is
    # always truthy and checks nothing, and the builtin ``all`` is shadowed by
    # this module's own ``all`` function, so an explicit loop is used.
    ndim = inp.ndim
    normalized = []
    for d in dim:
        assert -ndim <= d < ndim, "Invalid dim"
        normalized.append(d % ndim)
    dim = normalized

    shape = list(inp.shape)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Either a reduced dimension is empty (every output is vacuously True)
        # or another dimension is empty (the output itself is empty). Handling
        # both here avoids ``numel() // N`` with N == 0 and a kernel launch
        # with nothing to process.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if N == 1:
            # Every reduced dimension has size 1: an element-wise test is
            # enough, so no kernel is launched.
            out = (inp.reshape(M) != 0).reshape(shape)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out