Coverage for src/flag_gems/runtime/backend/_arm/ops/all.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import builtins 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.utils import dim_compress 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# torch.all: tests whether all elements of the input evaluate to True. For a
# non-bool input this means testing whether every element is non-zero, so the
# Triton kernels below can simply test `x != 0` regardless of the input dtype.

16 

17 

@triton.jit
def reduce_all(a, b):
    # Combine function passed to tl.reduce by the kernels below: merges two
    # partial results with `and`, so the reduction is True only if every
    # input lane was True.
    return a and b

21 

22 

# @libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise `all` over `inp` addressed as an (M, N) row-major buffer
    # (element (r, c) lives at inp + r * N + c). Each program handles BLOCK_M
    # rows, walking the N columns in BLOCK_N-wide chunks, and stores one
    # boolean per row to out[r].
    # NOTE(review): assumes the caller has arranged the reduced axis as the
    # contiguous last axis (the host wrappers call dim_compress first) -- confirm.

    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Per-lane running AND; starts all-True so lanes never loaded cannot flip it.
    _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # Out-of-bounds lanes load 1 (non-zero), which is the identity for AND.
        a = tl.load(inp + cols, mask, other=1.0)
        _all = _all and (a != 0)
    # Collapse the BLOCK_N lanes of each row. Named `row_all` rather than `all`
    # so it does not shadow the builtin or this module's `all` function.
    row_all = tl.reduce(_all, axis=1, combine_fn=reduce_all)
    tl.store(out, row_all[:, None], row_mask)

51 

52 

# @libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the two-stage full reduction used by `all`: program `pid`
    # tests the BLOCK_SIZE elements of the flat buffer `inp` starting at
    # pid * BLOCK_SIZE and stores a single result (True iff all are non-zero)
    # to mid[pid].
    # NOTE(review): `mid_size` is accepted but never read in this kernel.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # Tail lanes past n_elements load 1 (non-zero) so they cannot make the
    # AND-reduction False.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0)
    all_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_all)
    mid_ptr = mid + pid
    tl.store(mid_ptr, all_val)

70 

71 

# @libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full reduction: AND-reduces the MID_SIZE partial results
    # in `mid` and stores the scalar result to `out`. No program id is used;
    # the host wrapper `all` launches it on a (1, 1) grid.
    # Only offsets [0, BLOCK_MID) are read, so BLOCK_MID must be >= MID_SIZE.
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # Padding lanes load 1 (True), the identity for AND.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1).to(tl.int1)
    all_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_all)
    tl.store(out, all_val)

81 

82 

def all(inp):
    """Return whether every element of ``inp`` is non-zero, like ``torch.all(inp)``.

    Uses a two-stage reduction. ``all_kernel_1`` reduces blocks of roughly
    sqrt(numel) elements into a boolean ``mid`` buffer. ``all_kernel_2`` then
    reduces ``mid`` to a single value.

    Args:
        inp: input tensor. Non-bool elements count as True when non-zero.

    Returns:
        A 0-d ``torch.bool`` tensor on ``inp.device``. An empty input yields
        True (vacuous truth), matching ``torch.all``.
    """
    logging.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # all() over nothing is True. The block-size arithmetic below also
        # breaks down for 0: next_power_of_2(0) == 0, which leads to cdiv(0, 0).
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # all_kernel_1 indexes the input as a flat buffer starting at its data
    # pointer, so strided or sliced views must be materialized first. This is
    # a no-op for tensors that are already contiguous.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    # with torch_device_fn.device(inp.device):
    all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
    all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

98 

99 

def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero, optionally along one dimension.

    Args:
        inp: input tensor.
        dim: dimension to reduce (negative values count from the end). If
            None, reduce over every element via ``all``.
        keepdim: keep the reduced dimension(s) with size 1.

    Returns:
        A ``torch.bool`` tensor. Reducing over an empty slice gives True,
        matching ``torch.all``.
    """
    logging.debug("GEMS ALL DIM")
    shape = list(inp.shape)
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Every output position reduces an empty slice (or there are no
        # output positions), so the result is all-True. Handling this up
        # front also avoids dividing by N == 0 when computing M.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move `dim` to the contiguous last axis so the kernel can treat the
        # data as an (M, N) row-major matrix.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        # with torch_device_fn.device(inp.device):
        all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

123 

124 

def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero along one or more dimensions.

    Args:
        inp: input tensor.
        dim: None, a single int, or an iterable of dimensions to reduce.
            Negative values count from the end. None or an int delegates to
            ``all_dim``.
        keepdim: keep the reduced dimensions with size 1.

    Returns:
        A ``torch.bool`` tensor. Reducing over an empty slice gives True,
        matching ``torch.all``.
    """
    logging.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    assert builtins.all((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements reduced into each output position.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Every output position reduces an empty slice (or there are no
        # output positions), so the result is all-True. Handling this up
        # front also avoids dividing by N == 0 when computing M.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dims to the contiguous tail so the kernel can
        # treat the data as an (M, N) row-major matrix.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out