Coverage for src/flag_gems/runtime/backend/_arm/ops/all.py: 0%
96 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import builtins
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.utils import dim_compress
11from flag_gems.utils import triton_lang_extension as tle
# torch.all tests whether all elements of the input evaluate to True. If the dtype of
# the input is not bool, it tests whether all elements are non-zero.
# Inside the Triton kernels it is therefore sufficient to compare every element against zero.
@triton.jit
def reduce_all(a, b):
    """Combine two partial results of an "all" reduction with a logical AND.

    Passed as ``combine_fn`` to ``tl.reduce`` by the kernels in this file.
    """
    return a and b
23# @libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise "all non-zero" reduction.

    ``inp`` is addressed as ``M`` rows of ``N`` consecutive elements
    (element ``(r, c)`` lives at ``inp + r * N + c``); one result per row is
    written to ``out + r``. Each program instance handles ``BLOCK_M`` rows and
    walks the columns in steps of ``BLOCK_N``.
    """
    # Rows owned by this program instance, shaped [BLOCK_M, 1] for broadcasting.
    block_id = tle.program_id(0)
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_valid = row_idx < M
    row_ptrs = inp + row_idx * N
    out_ptrs = out + row_idx

    # Running element-wise AND over the column tiles; starts as all-True.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_valid = col_idx < N
        valid = row_valid and col_valid

        # Out-of-range lanes load 1.0 so they never turn the result False.
        vals = tl.load(row_ptrs + col_idx, valid, other=1.0)
        acc = acc and (vals != 0)
    # Collapse the BLOCK_N lanes of each row into a single value.
    row_all = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out_ptrs, row_all[:, None], row_valid)
53# @libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: one partial result per program.

    Each program instance reduces ``BLOCK_SIZE`` consecutive elements of
    ``inp`` (addressed by flat offset) to a single "all non-zero" value and
    stores it at ``mid + program_id``. ``mid_size`` is accepted but not read
    by this kernel.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end load 1.0 so they do not affect the AND.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0)
    nonzero = vals != 0
    block_all = tl.reduce(nonzero, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, block_all)
72# @libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: AND the partial results together.

    Loads the first ``MID_SIZE`` values of ``mid`` in a single block of
    ``BLOCK_MID`` lanes, reduces them with a logical AND and stores the
    scalar result at ``out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Padding lanes load 1 (True) so they do not affect the AND.
    partials = tl.load(mid + idx, mask=in_bounds, other=1).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_all)
    tl.store(out, result)
def all(inp):
    """Return a 0-dim bool tensor that is True iff every element of ``inp`` is non-zero.

    The reduction runs as two kernel launches: ``all_kernel_1`` reduces chunks
    of ``block_size`` elements (about ``sqrt(numel)``) into the intermediate
    buffer ``mid``, and ``all_kernel_2`` reduces ``mid`` to the scalar result.

    Args:
        inp: Input tensor.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``. An empty input yields
        ``True``, matching ``torch.all``.
    """
    logging.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # An empty reduction is vacuously True. Without this guard block_size
        # would be 0 and triton.cdiv(0, 0) would raise ZeroDivisionError.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # all_kernel_1 addresses the input by flat offset, so a strided view must
    # be materialised densely first (no-op for already contiguous tensors).
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    # with torch_device_fn.device(inp.device):
    all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
    all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements along one dimension of ``inp`` are non-zero.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce (negative values count from the end). When
            ``None`` the whole tensor is reduced via ``all``.
        keepdim: If True, the reduced dimension is kept with size 1 (or, for
            ``dim=None``, the result is reshaped to ``inp.ndim`` size-1 dims).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logging.debug("GEMS ALL DIM")
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Either the reduced dimension is empty (every output is vacuously
        # True) or the output itself is empty. Both are handled without a
        # kernel launch; this also avoids the numel() // N division by zero
        # when N == 0.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Rearrange so the kernel can address the input as M rows of N elements.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        # with torch_device_fn.device(inp.device):
        all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements along one or several dimensions are non-zero.

    Args:
        inp: Input tensor.
        dim: ``None`` or an int (both delegated to ``all_dim``), or an iterable
            of ints naming the dimensions to reduce. Negative values count
            from the end.
        keepdim: If True, reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logging.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    assert builtins.all((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements folded into each output value.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Either a reduced dimension is empty (every output is vacuously True)
        # or the output itself is empty. Both are handled without a kernel
        # launch; this also avoids the numel() // N division by zero when
        # N == 0.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Rearrange so the kernel can address the input as M rows of N elements.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        def grid(meta):
            return (triton.cdiv(M, meta["BLOCK_M"]),)

        all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out