Coverage for src/flag_gems/ops/_is_all_true.py: 68%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger(__name__)
14# _is_all_true: Tests if all elements in a boolean tensor are True.
15# This is a specialized version of torch.all that only accepts bool tensors.
16# Returns a scalar boolean tensor.
@triton.jit
def reduce_all(a, b):
    # Combine function handed to tl.reduce: logical AND of two partial results.
    both_true = a and b
    return both_true
@libentry()
@triton.jit
def is_all_true_kernel_1(
    inp,
    mid,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1: each program AND-reduces one BLOCK_SIZE-wide slice of `inp`
    # and stores its partial result at `mid[program index]`.
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end are filled with 1 (True) so they cannot turn the
    # AND-reduction False.
    values = tl.load(inp + idx, mask=in_bounds, other=1).to(tl.int1)
    partial = tl.reduce(values, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def is_all_true_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2: a single program AND-reduces the first MID_SIZE entries of
    # `mid` and writes the final result to `out`.
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Padding lanes (idx >= MID_SIZE) load as 1 (True) so they do not affect
    # the AND-reduction.
    partials = tl.load(mid + idx, mask=in_bounds, other=1).to(tl.int1)
    result = tl.reduce(partials, axis=0, combine_fn=reduce_all)
    tl.store(out, result)
def _is_all_true(inp):
    """Return a 0-dim bool tensor that is True iff every element of ``inp`` is True.

    Args:
        inp: A tensor of dtype ``torch.bool``; any shape, any memory layout.

    Returns:
        A scalar (0-dim) ``torch.bool`` tensor on ``inp.device``. An empty
        input yields ``True``.

    Raises:
        AssertionError: If ``inp.dtype`` is not ``torch.bool``.
    """
    logger.debug("GEMS _IS_ALL_TRUE")
    # _is_all_true only accepts bool tensors
    assert inp.dtype == torch.bool, "Input tensor must be of type bool"

    n_elements = inp.numel()

    # Handle empty tensor case: all() of empty set is True (vacuous truth)
    if n_elements == 0:
        return torch.tensor(True, dtype=torch.bool, device=inp.device)

    # The stage-1 kernel addresses the input as `inp + linear_offset`, which
    # is only correct for contiguous storage. Materialize strided / expanded
    # views first; this returns the same tensor when it is already contiguous.
    inp = inp.contiguous()

    # Two-stage reduction: use a block size near sqrt(n) so that stage 1
    # produces roughly sqrt(n) partial results for stage 2 to combine.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        # Stage 1: one program per block writes a partial AND into `mid`.
        is_all_true_kernel_1[(mid_size, 1)](inp, mid, n_elements, block_size)
        # Stage 2: a single program folds `mid` into the scalar `out`.
        is_all_true_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out