Coverage for src/flag_gems/ops/_is_all_true.py: 68%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

logger = logging.getLogger(name=__name__)

# _is_all_true checks whether every element of a bool tensor is True.
# It is a bool-only specialization of torch.all and yields a 0-dim bool
# tensor (see _is_all_true below).

17 

18 

@triton.jit
def reduce_all(a, b):
    """Combine function for ``tl.reduce``: logical AND of two partial results.

    Both kernels in this module cast their operands to ``tl.int1`` before
    reducing, so ``a`` and ``b`` are boolean flags here.
    """
    return a and b

22 

23 

@libentry()
@triton.jit
def is_all_true_kernel_1(
    inp,
    mid,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1: reduce one BLOCK_SIZE-wide chunk of ``inp`` to a single flag.

    Program ``i`` covers flat indices ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``
    (clipped to ``n_elements``) and stores the AND of those elements in
    ``mid[i]``.
    """
    block_idx = ext.program_id(0)
    idx = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < n_elements
    # Lanes past the end read True (1), the identity of AND, so they cannot
    # change the result.
    flags = tl.load(inp + idx, mask=in_range, other=1).to(tl.int1)
    tl.store(mid + block_idx, tl.reduce(flags, axis=0, combine_fn=reduce_all))

41 

42 

@libentry()
@triton.jit
def is_all_true_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Stage 2: fold the MID_SIZE partial flags in ``mid`` into ``out``.

    The kernel uses no program id and loads only one BLOCK_MID-wide tile, so
    it is meant to run as a single program with BLOCK_MID >= MID_SIZE.
    """
    lanes = tl.arange(0, BLOCK_MID)
    # Padding lanes read True (1) so they are neutral for the AND reduction.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=1).to(tl.int1)
    tl.store(out, tl.reduce(partials, axis=0, combine_fn=reduce_all))

52 

53 

def _is_all_true(inp):
    """Return whether every element of the bool tensor ``inp`` is True.

    Bool-only specialization of ``torch.all``. The result is computed with a
    two-stage Triton reduction: ``is_all_true_kernel_1`` folds
    ``block_size``-wide chunks into ``mid``, then ``is_all_true_kernel_2``
    folds ``mid`` into ``out``.

    Args:
        inp: tensor of dtype ``torch.bool``; any shape and any strides.

    Returns:
        A 0-dim ``torch.bool`` tensor on ``inp.device``. An empty input
        yields ``True`` (vacuous truth).

    Raises:
        AssertionError: if ``inp.dtype`` is not ``torch.bool``.
    """
    logger.debug("GEMS _IS_ALL_TRUE")
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if inp.dtype != torch.bool:
        raise AssertionError("Input tensor must be of type bool")

    n_elements = inp.numel()

    # all() over an empty set is True (vacuous truth).
    if n_elements == 0:
        return torch.tensor(True, dtype=torch.bool, device=inp.device)

    # The kernels address ``inp`` as a flat buffer of n_elements, so strided
    # or offset views (e.g. ``x[:, 1:]``, ``x[::2]``) must be compacted first.
    # This is a no-op for tensors that are already contiguous.
    inp = inp.contiguous()

    # Chunks of ~sqrt(n) elements leave ~sqrt(n) partial results for stage 2.
    # Sizes are rounded up to powers of two because tl.arange requires a
    # power-of-two extent.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        is_all_true_kernel_1[(mid_size, 1)](inp, mid, n_elements, block_size)
        is_all_true_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out