Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%

109 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module (leading dots stripped from ``__name__``).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

12 

13 

# torch.all tests whether all elements of the input evaluate to True. If the
# input dtype is not bool, an element counts as True when it is non-zero.
# Inside the Triton kernels it is therefore sufficient to test every element
# against zero, regardless of dtype.

17 

# Tuning constants for this backend.
# NOTE(review): the names suggest hardware properties (clusters, cores,
# per-core buffer length, vector width) -- confirm against the device spec.
cluster_num = 12  # divisor applied to M in heur_m_block_size
core_num = 64  # upper bound on BLOCK_M before power-of-two rounding
buf_len_per_core = 2048  # upper bound on BLOCK_SIZE chosen in all()
vector_size = 16  # not referenced in the visible part of this file

22 

23 

def heur_m_block_size(args):
    """Choose BLOCK_M (rows handled per program) for ``all_kernel_dim``.

    ``M`` is divided (rounding up) by ``cluster_num``, the quotient is capped
    at ``core_num``, and the result is rounded up to a power of two. The value
    is never allowed below 1, so a tiny or zero ``M`` still yields a valid
    block size.
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    capped = min(rows_per_cluster, core_num)
    if capped < 1:
        capped = 1
    return triton.next_power_of_2(capped)

29 

30 

def heur_n_block_size(args):
    """Choose BLOCK_N (columns loaded per loop step) for ``all_kernel_dim``.

    ``N`` is clamped into the range [1, 512] and then rounded up to a power
    of two, so the result is always between 1 and 512.
    """
    clamped = max(min(args["N"], 512), 1)
    return triton.next_power_of_2(clamped)

36 

37 

@triton.jit
def reduce_all(a, b):
    # Pairwise combine function passed to tl.reduce: logical AND of the
    # two partial results.
    both = a and b
    return both

41 

42 

@libentry()
@triton.jit
def all_global_kernel(
    inp,
    out,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce all ``n_elements`` values at ``inp`` to one flag stored at ``out``.

    C++ handler replaces with api::all<T,bool>.
    Triton fallback: a single program walks the input in chunks of BLOCK_SIZE.
    """
    # Per-lane accumulator; starts all-true and can only be cleared.
    acc = tl.full([BLOCK_SIZE], value=1, dtype=tl.int1)
    for start in range(0, n_elements, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        # Out-of-range lanes read 1.0 so they can never turn the result false.
        chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
        acc = acc and (chunk != 0)
    # Collapse the per-lane flags into a single value.
    flag = tl.reduce(acc, axis=0, combine_fn=reduce_all)
    tl.store(out, flag)

61 

62 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise "all" reduction. Element (row, col) is addressed as
    # inp + row * N + col, and one flag per row is written to out + row.
    # Each program covers BLOCK_M consecutive rows.
    pid = ext.program_id(0)
    row_idx = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_valid = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # Per-element accumulator; starts all-true and can only be cleared.
    acc = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_valid = col_idx < N
        valid = row_valid and col_valid

        # Masked-out positions read 1.0 so they cannot turn a row false.
        tile = tl.load(row_base + col_idx, valid, other=1.0)
        acc = acc and (tile != 0)
    # Collapse the column axis, leaving one flag per row.
    row_flags = tl.reduce(acc, axis=1, combine_fn=reduce_all)
    tl.store(out_ptrs, row_flags[:, None], row_valid)

96 

97 

def all(inp):
    """Return a 0-dim bool tensor: True iff every element of ``inp`` is non-zero.

    An empty ``inp`` yields True: the kernel's loop body never runs and its
    all-true accumulator is reduced unchanged.

    Args:
        inp: input tensor; the result is allocated on ``inp.device``.

    Returns:
        A 0-dim ``torch.bool`` tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ALL")
    n_elements = inp.numel()
    # BLOCK_SIZE must fit in XPU per-core local buffer so the Triton fallback
    # kernel always compiles. The C++ handler (api::all<T,bool>) ignores this
    # value and handles any n_elements internally.
    # Clamp to at least 1: triton.next_power_of_2(0) returns 0, which would
    # otherwise ask the kernel for a zero-length block on empty input.
    BLOCK_SIZE = min(
        triton.next_power_of_2(max(n_elements, 1)), buf_len_per_core
    )
    out = torch.empty([], dtype=torch.bool, device=inp.device)
    with torch_device_fn.device(inp.device):
        # Single-program launch; the kernel loops over the whole input itself.
        all_global_kernel[(1, 1)](
            inp, out, n_elements, BLOCK_SIZE, buffer_size_limit=2048
        )
    return out

111 

112 

def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements along one dimension of ``inp`` are non-zero.

    Args:
        inp: input tensor.
        dim: dimension to reduce (negative values count from the end), or
            ``None`` to reduce over every element via ``all``.
        keepdim: if True, the reduced dimension is kept with size 1 (for
            ``dim=None`` the result is reshaped to all-ones of the input rank).

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN ALL_DIM")
    shape = list(inp.shape)
    orig_ndim = inp.ndim

    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * orig_ndim)
        return out

    assert dim >= -orig_ndim and dim < orig_ndim, "Invalid dim"
    dim = dim % orig_ndim
    N = shape[dim]
    shape[dim] = 1

    if N == 0:
        # Reducing over an empty dimension is vacuously True. Handling it
        # here also avoids the division by zero in ``numel() // N`` below.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Lay the data out so the kernel can treat it as M rows of N values.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        # NOTE(review): tiny non-bool inputs are converted to bool up front;
        # presumably a backend-specific workaround -- confirm.
        if inp.dtype != torch.bool and M * N <= 64:
            inp = inp != 0

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        # Launch at least one program even when M is 0.
        grid = lambda meta: (max(triton.cdiv(M, meta["BLOCK_M"]), 1),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        # ``out`` has the input's rank and ``dim`` was validated against it,
        # so the size-1 reduced dimension can be squeezed unconditionally.
        out = out.squeeze(dim)
    return out

142 

143 

def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements along several dimensions of ``inp`` are non-zero.

    Args:
        inp: input tensor.
        dim: ``None`` or a single int (both delegated to ``all_dim``), or an
            iterable of dimensions to reduce; negative values count from the
            end.
        keepdim: if True, reduced dimensions are kept with size 1.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: if any entry of ``dim`` is outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN ALL_DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    orig_ndim = inp.ndim
    # Validate each entry explicitly. Asserting on a generator expression is
    # always truthy, and the builtin ``all`` is shadowed by this module's
    # ``all`` function, so a plain loop is used instead.
    for d in dim:
        assert -orig_ndim <= d < orig_ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % orig_ndim for d in dim]
    # N is the number of elements folded into each output value.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if N == 0:
        # Reducing over an empty extent is vacuously True. Handling it here
        # also avoids the division by zero in ``numel() // N`` below.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Lay the data out so the kernel can treat it as M rows of N values.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        # NOTE(review): tiny non-bool inputs are converted to bool up front;
        # presumably a backend-specific workaround -- confirm.
        if inp.dtype != torch.bool and M * N <= 64:
            inp = inp != 0

        out = torch.empty(shape, dtype=torch.bool, device=inp.device)
        # Launch at least one program even when M is 0.
        grid = lambda meta: (max(triton.cdiv(M, meta["BLOCK_M"]), 1),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        # Squeeze from the highest index down so removing one dimension does
        # not shift the positions of those still to be removed.
        for d in sorted(set(dim), reverse=True):
            out = out.squeeze(dim=d)
    return out