Coverage for src/flag_gems/runtime/backend/_hygon/ops/all.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-level logger; each public entry point below logs its name at DEBUG level.
logger = logging.getLogger(__name__)

14 

# torch.all tests whether all elements of the input evaluate to True. For a
# non-bool input this means testing whether all elements are non-zero.
# Since a bool True is also non-zero, the Triton kernels below simply test every
# element for a non-zero value, which is correct for every input dtype.

18 

19 

@triton.jit
def reduce_all(a, b):
    # Combine function passed to tl.reduce: logical AND of two partial results.
    return a and b

23 

24 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "all non-zero" reduction over an (M, N) view of `inp`: each
    # program handles BLOCK_M consecutive rows and writes one bool per row to
    # `out`. Rows are addressed as `inp + row * N`, so this assumes `inp` is
    # dense with rows of length N (the host wrappers pass the output of
    # dim_compress) -- TODO confirm for any other caller.
    pid = tle.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Per-lane running AND. Masked-out lanes load 1, so they never turn it False.
    _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=1.0)
        _all = _all and (a != 0)
    # Collapse the BLOCK_N lanes of each row into a single value. Named
    # `row_all` so it does not shadow the builtin / this module's `all`.
    row_all = tl.reduce(_all, axis=1, combine_fn=reduce_all)
    tl.store(out, row_all[:, None], row_mask)

53 

54 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    # First pass of the full reduction. Program `pid` ANDs the "non-zero" test
    # over flat elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) of `inp`
    # and stores the single partial result at mid[pid].
    # NOTE(review): `mid_size` is accepted but not read in this kernel.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < n_elements
    # Lanes past the end load 1 so they cannot make the AND False.
    inp_val = tl.load(inp_ptrs, mask=mask, other=1.0)
    all_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_all)
    mid_ptr = mid + pid
    tl.store(mid_ptr, all_val)

72 

73 

@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second pass of the full reduction: a single program ANDs the MID_SIZE
    # partial results in `mid` and stores the scalar result to `out`.
    # BLOCK_MID must be >= MID_SIZE (the host passes next_power_of_2(MID_SIZE)).
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < MID_SIZE
    # Padding lanes load 1 (True) so they do not affect the result.
    mid_val = tl.load(mid_ptrs, mask=mask, other=1).to(tl.int1)
    all_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_all)
    tl.store(out, all_val)

83 

84 

def all(inp):
    """Return a 0-d bool tensor that is True iff every element of ``inp`` is non-zero.

    Runs a two-pass reduction: ``all_kernel_1`` reduces blocks of roughly
    sqrt(n) elements into a ``mid`` buffer, then ``all_kernel_2`` reduces
    ``mid`` to the scalar result.

    Args:
        inp: input tensor of any shape and dtype.

    Returns:
        A 0-d ``torch.bool`` tensor on ``inp.device``. An empty input yields
        True (vacuous truth), matching ``torch.all``.
    """
    logger.debug("GEMS_HYGON ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # next_power_of_2(0) is 0, so triton.cdiv below would divide by zero.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # The kernels address the input as a flat buffer, so it must be dense.
    inp = inp.contiguous()

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

100 

101 

def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero along a single dimension.

    Args:
        inp: input tensor.
        dim: dimension to reduce (may be negative). ``None`` reduces over all
            elements.
        keepdim: if True, keep the reduced dimension(s) with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS_HYGON ALL_DIM")
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    if inp.ndim == 0:
        # Like torch, a 0-d tensor accepts dim -1 or 0; reducing it is the
        # full reduction and the result stays 0-d regardless of keepdim.
        assert dim in (-1, 0), "Invalid dim"
        return all(inp)

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # Nothing to launch, and M = numel // N would divide by zero when
        # N == 0. Reducing an empty row is vacuously True.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dim last so each output element is one dense row.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

125 

126 

def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero along one or more dimensions.

    Args:
        inp: input tensor.
        dim: ``None`` (reduce everything), a single int, or a sequence of
            ints (may be negative, must be distinct after wrapping).
        keepdim: if True, keep the reduced dimensions with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if any dim is out of range or a dim is repeated.
    """
    logger.debug("GEMS_HYGON ALL_DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)
    # The builtin `all` is shadowed by this module's `all`, so check with `any`.
    # (Asserting a bare generator expression is always truthy.)
    assert not any(i < -inp.ndim or i >= inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # A repeated dim would be multiplied into N twice and give a wrong result.
    assert len(set(dim)) == len(dim), "Invalid dim: duplicate entries"
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Nothing to launch, and M = numel // N would divide by zero when
        # N == 0. Reducing an empty set of elements is vacuously True.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dims last so each output element is one dense row.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=torch.bool, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out