Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%
109 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module logger: a child of the shared "flag_gems" logger, named after this
# module with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# torch.all tests whether every element of the input evaluates to True. If the
# dtype of the input is not bool, it tests whether every element is non-zero.
# Inside the Triton kernels it is therefore sufficient to test each element for != 0.
# Tuning constants used by the block-size heuristics and the host wrappers below.
# NOTE(review): the names suggest Kunlunxin XPU hardware parameters (clusters,
# cores per cluster, per-core buffer length, vector width) -- confirm against
# the target device documentation.
cluster_num = 12  # divisor applied to M when choosing BLOCK_M
core_num = 64  # upper bound for BLOCK_M
buf_len_per_core = 2048  # upper bound for BLOCK_SIZE in the global all() launch
vector_size = 16  # not referenced by the code visible in this section
def heur_m_block_size(args):
    """Heuristic for BLOCK_M: rows handled by one program of ``all_kernel_dim``.

    ``args`` is the kernel-argument mapping supplied by ``triton.heuristics``;
    only ``args["M"]`` (number of rows) is read.  The result is M divided
    (rounding up) by ``cluster_num``, capped at ``core_num``, floored at 1 and
    rounded up to a power of two.
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped = min(rows_per_cluster, core_num)
    # For very small (or zero) M fall back to the minimum BLOCK_M of 1.
    if capped < 1:
        capped = 1
    return triton.next_power_of_2(capped)
def heur_n_block_size(args):
    """Heuristic for BLOCK_N: columns processed per loop step of ``all_kernel_dim``.

    Reads ``args["N"]`` (length of the reduced axis), clamps it to the range
    [1, 512] and rounds the result up to a power of two.
    """
    width = args["N"]
    if width > 512:
        width = 512
    # For very small (or zero) N fall back to the minimum BLOCK_N of 1.
    if width < 1:
        width = 1
    return triton.next_power_of_2(width)
@triton.jit
def reduce_all(a, b):
    """Combine function for ``tl.reduce``: the logical AND of two partial results."""
    both = a and b
    return both
@libentry()
@triton.jit
def all_global_kernel(
    inp,
    out,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce all ``n_elements`` values at ``inp`` to one boolean stored at ``out``.

    C++ handler replaces this kernel with api::all<T,bool>.  The Triton
    fallback below is a single program that walks the input in chunks of
    BLOCK_SIZE, keeping one running flag per lane, then folds the lanes.
    """
    # One running "still all non-zero" flag per lane, initially true.
    acc = tl.full([BLOCK_SIZE], value=1, dtype=tl.int1)
    for start in range(0, n_elements, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        # Out-of-range lanes load 1 so they never turn the result false.
        chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
        acc = acc and (chunk != 0)
    # Fold the per-lane flags into a single value and write it out.
    tl.store(out, tl.reduce(acc, axis=0, combine_fn=reduce_all))
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise "all non-zero" over an input addressed as M rows of N elements.

    Each program handles BLOCK_M consecutive rows, scanning every row in
    steps of BLOCK_N columns, and stores one boolean per row into ``out``.
    """
    # Map the program id to the rows of inp it should compute.
    pid = ext.program_id(0)
    row_idx = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptr = inp + row_idx * N
    out_ptr = out + row_idx
    row_ok = row_idx < M

    # One running flag per (row, lane), initially true.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        valid = row_ok and col_ok

        # Masked-off positions load 1 so they never turn a row false.
        tile = tl.load(row_ptr + col_idx, valid, other=1.0)
        acc = acc and (tile != 0)
    # Fold the lanes of each row and store only the rows that exist.
    row_result = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out_ptr, row_result[:, None], row_ok)
def all(inp):
    """Return a 0-dim bool tensor that is True when every element of ``inp`` is non-zero.

    Launches ``all_global_kernel`` as a single program over ``inp.numel()``
    elements on ``inp.device``.  An empty input yields True.

    NOTE(review): the kernel addresses ``inp`` linearly (``inp + offset``), so
    it presumably expects dense/contiguous storage -- confirm that callers
    never pass strided views.
    """
    logger.debug("GEMS_KUNLUNXIN ALL")
    n_elements = inp.numel()
    # BLOCK_SIZE must fit in XPU per-core local buffer so the Triton fallback
    # kernel always compiles. The C++ handler (api::all<T,bool>) ignores this
    # value and handles any n_elements internally.
    # triton.next_power_of_2(0) is 0, which would give the fallback kernel a
    # zero-length block and a zero loop step; clamp to 1 (same clamp as the
    # block-size heuristics) so an empty input reduces to True.
    BLOCK_SIZE = max(
        min(triton.next_power_of_2(n_elements), buf_len_per_core), 1
    )
    out = torch.empty([], dtype=torch.bool, device=inp.device)
    with torch_device_fn.device(inp.device):
        all_global_kernel[(1, 1)](
            inp, out, n_elements, BLOCK_SIZE, buffer_size_limit=2048
        )
    return out
def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements along one dimension of ``inp`` are non-zero.

    Args:
        inp: input tensor.
        dim: dimension to reduce (negative values count from the end), or
            ``None`` to reduce over every element via ``all``.
        keepdim: when True the reduced dimension is kept with size 1 (for
            ``dim=None`` the scalar result is reshaped to all-ones shape).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN ALL_DIM")
    shape = list(inp.shape)
    orig_ndim = inp.ndim

    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * orig_ndim)
        return out

    assert dim >= -orig_ndim and dim < orig_ndim, "Invalid dim"
    dim = dim % orig_ndim
    N = shape[dim]
    shape[dim] = 1

    if N == 0:
        # Reducing over an empty dimension: "all" of nothing is True.  Handled
        # on the host because M = numel // N would divide by zero.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        # NOTE(review): small non-bool inputs are converted to bool before the
        # launch; the reason is not visible here -- presumably a backend
        # workaround, confirm before changing the threshold.
        if inp.dtype != torch.bool and M * N <= 64:
            inp = inp != 0

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # Always launch at least one program, even when M == 0.
            return (max(triton.cdiv(M, meta["BLOCK_M"]), 1),)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        # out has orig_ndim >= 1 dimensions here and 0 <= dim < orig_ndim, so
        # the squeeze is always in range.
        out = out.squeeze(dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements over several dimensions of ``inp`` are non-zero.

    Args:
        inp: input tensor.
        dim: ``None`` or a single int (both delegated to ``all_dim``), or an
            iterable of ints naming the dimensions to reduce; negative values
            count from the end.
        keepdim: when True the reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN ALL_DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    orig_ndim = inp.ndim
    # Validate each entry explicitly.  Asserting on a bare generator expression
    # is always truthy, and the builtin all() is shadowed by this module's
    # all(inp), so a plain loop is used.
    for d in dim:
        assert -orig_ndim <= d < orig_ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % orig_ndim for d in dim]
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1

    if N == 0:
        # Reducing over an empty extent: "all" of nothing is True.  Handled on
        # the host because M = numel // N would divide by zero.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        # NOTE(review): small non-bool inputs are converted to bool before the
        # launch; the reason is not visible here -- presumably a backend
        # workaround, confirm before changing the threshold.
        if inp.dtype != torch.bool and M * N <= 64:
            inp = inp != 0

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            # Always launch at least one program, even when M == 0.
            return (max(triton.cdiv(M, meta["BLOCK_M"]), 1),)

        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        # Squeeze from the highest axis down: removing a lower axis first would
        # shift the positions of the remaining reduced axes.
        for d in sorted(set(dim), reverse=True):
            out = out.squeeze(dim=d)
    return out