Coverage for src/flag_gems/runtime/backend/_hygon/ops/all.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; each public op below emits a debug line through it.
logger = logging.getLogger(name=__name__)
# torch.all: tests whether all elements of input evaluate to True. If input is not
# bool, it tests whether all elements are non-zero instead.
# In the Triton kernels, testing `x != 0` is therefore enough for every dtype
# (for bool input, True != 0 and False == 0).
@triton.jit
def reduce_all(a, b):
    # Combine function for tl.reduce: logical AND of two partial results.
    both_true = a and b
    return both_true
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Reduce each row of a row-major (M, N) view of `inp` to one bool:
    # True iff every element of that row is non-zero. Each program owns
    # BLOCK_M consecutive rows and walks the N columns BLOCK_N at a time.
    row_ids = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_valid = row_ids < M
    row_base = inp + row_ids * N

    # Running element-wise AND; starts all-True so masked lanes never flip it.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_ids = start + tl.arange(0, BLOCK_N)[None, :]
        tile_mask = rows_valid and (col_ids < N)
        # Out-of-bounds lanes load 1.0, which is non-zero and thus neutral for AND.
        tile = tl.load(row_base + col_ids, tile_mask, other=1.0)
        acc = acc and (tile != 0)

    row_result = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out + row_ids, row_result[:, None], rows_valid)
@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # First pass of the full reduction: program `block_id` ANDs together the
    # "non-zero" flags of elements [block_id * BLOCK_SIZE, (block_id + 1) * BLOCK_SIZE)
    # and writes that partial result to mid[block_id]. `mid_size` is not read here.
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Padding lanes load 1.0 (non-zero), which is the identity for AND.
    vals = tl.load(inp + idx, mask=in_bounds, other=1.0)
    nonzero = vals != 0
    block_all = tl.reduce(nonzero, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, block_all)
@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second pass: a single program ANDs the MID_SIZE partial results in `mid`
    # and stores the final flag to `out`. Only the first BLOCK_MID entries are
    # read, so BLOCK_MID must be >= MID_SIZE.
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < MID_SIZE
    partials = tl.load(mid + lanes, mask=valid, other=1)
    flags = partials.to(tl.int1)
    tl.store(out, tl.reduce(flags, axis=0, combine_fn=reduce_all))
def all(inp):
    """Return a 0-d bool tensor that is True iff every element of ``inp`` is non-zero.

    Uses a two-pass reduction: ``all_kernel_1`` reduces blocks of
    ``block_size`` elements into ``mid``, then ``all_kernel_2`` reduces
    ``mid`` into the scalar ``out``. An empty input yields True (vacuous
    truth), matching ``torch.all``.
    """
    logger.debug("GEMS_HYGON ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Without this guard, block_size would be triton.next_power_of_2(0) == 0
        # and triton.cdiv(0, 0) below would raise ZeroDivisionError.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # block_size ~ sqrt(n) keeps both passes roughly balanced.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out
def all_dim(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical AND (element != 0) along a single ``dim``.

    Args:
        inp: input tensor.
        dim: dimension to reduce, or None to reduce over all elements.
        keepdim: if True, the reduced dimension(s) are kept with size 1.

    Returns:
        A bool tensor. Reducing over a zero-length dimension yields True,
        matching ``torch.all``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS_HYGON ALL_DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if N == 0:
        # Every output is an AND over zero elements (vacuously True); computing
        # M = numel // N below would otherwise divide by zero.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move `dim` last so each output element is one contiguous row of length N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def all_dims(inp, dim=None, keepdim=False):
    """Reduce ``inp`` with logical AND (element != 0) over one or more dims.

    Args:
        inp: input tensor.
        dim: None, a single int, or a sequence of ints naming the dims to reduce.
            None and int are delegated to ``all_dim``.
        keepdim: if True, reduced dimensions are kept with size 1.

    Returns:
        A bool tensor. Reducing over dims whose total size is zero yields True,
        matching ``torch.all``.

    Raises:
        AssertionError: if any entry of ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS_HYGON ALL_DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    # Check each entry explicitly: asserting a bare generator expression is
    # always truthy and would never reject an invalid dim.
    for d in dim:
        assert d >= -inp.ndim and d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if N == 0:
        # Every output is an AND over zero elements (vacuously True); computing
        # M = numel // N below would otherwise divide by zero.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dims last so each output element is one contiguous row of length N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out