Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/all.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# torch.all: tests whether all elements of the input evaluate to True. If the
# input dtype is not bool, it tests whether all elements are non-zero.
# Inside the Triton kernels below it is therefore sufficient to test every
# element with `!= 0`, which covers bool and non-bool inputs alike.

# Backend tuning constants for this Kunlunxin (XPU) implementation.
# NOTE(review): the names suggest hardware cluster / core counts — confirm.
cluster_num = 12  # divisor for M in heur_m_block_size
core_num = 64  # upper bound applied before rounding BLOCK_M to a power of 2
buf_len_per_core = 2048  # upper bound on BLOCK_SIZE chosen in `all`
vector_size = 16  # not referenced in the code visible here; TODO confirm use

22 

23 

def heur_m_block_size(args):
    """Pick BLOCK_M (rows per program) for ``all_kernel_dim``.

    Spreads the M rows over ``cluster_num`` groups, caps the result at
    ``core_num`` and rounds it up to a power of two. The value is clamped to
    at least 1 so that M == 0 never yields a zero-width block
    (``triton.next_power_of_2(0)`` returns 0, which ``tl.arange`` rejects).
    """
    rows_per_cluster = triton.cdiv(args["M"], cluster_num)
    return triton.next_power_of_2(max(1, min(rows_per_cluster, core_num)))

26 

27 

def heur_n_block_size(args):
    """Pick BLOCK_N (columns per loop step) for ``all_kernel_dim``.

    Caps N at 512 and rounds it up to a power of two. The value is clamped to
    at least 1 so that N == 0 never yields a zero-width block
    (``triton.next_power_of_2(0)`` returns 0).
    """
    return triton.next_power_of_2(max(1, min(args["N"], 512)))

30 

31 

@triton.jit
def reduce_all(a, b):
    # Combine function passed to tl.reduce: logical AND of two partial
    # results, so the reduction is True only if every input was True.
    return a and b

35 

36 

@libentry()
@triton.jit
def all_global_kernel(
    inp,
    out,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce every element of ``inp`` to one "all elements non-zero" flag.

    Global all over all elements. C++ handler replaces with api::all<T,bool>.
    Triton fallback: a single program loops over chunks of BLOCK_SIZE,
    AND-ing ``val != 0`` into a BLOCK_SIZE-wide accumulator, then reduces the
    accumulator with ``reduce_all`` and stores the single result at ``out``.

    Args:
        inp: pointer to the input, addressed as a flat buffer of
            ``n_elements`` entries (offsets ``0 .. n_elements - 1``).
        out: pointer to the single output element.
        n_elements: number of input elements to test.
        BLOCK_SIZE: number of lanes processed per loop step.
    """
    # Start from all-True so that the AND below can only clear lanes.
    _all = tl.full([BLOCK_SIZE], value=1, dtype=tl.int1)
    for off in range(0, n_elements, BLOCK_SIZE):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < n_elements
        # Out-of-range lanes load 1 (non-zero) so they never falsify the result.
        val = tl.load(inp + offset, mask=mask, other=1.0)
        _all = _all and (val != 0)
    result = tl.reduce(_all, axis=0, combine_fn=reduce_all)
    tl.store(out, result)

55 

56 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise "all elements non-zero" over an (M, N) view of ``inp``.

    Row ``r`` is read from offsets ``r * N .. r * N + N - 1``. Each program
    handles BLOCK_M consecutive rows, scans their N columns in chunks of
    BLOCK_N, and writes one flag per row to ``out[r]``. BLOCK_M and BLOCK_N
    come from ``heur_m_block_size`` / ``heur_n_block_size``.
    """
    # Map the program id to the block of rows it should compute.
    pid = ext.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + rows * N
    out = out + rows
    row_mask = rows < M

    # Start from all-True so that the AND below can only clear lanes.
    _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # Masked-off lanes load 1 (non-zero) so they never make a row False.
        a = tl.load(inp + cols, mask, other=1.0)
        _all = _all and (a != 0)
    # Named `row_all` rather than `all` to avoid shadowing the module-level
    # `all` op and the Python builtin.
    row_all = tl.reduce(_all, axis=1, combine_fn=reduce_all)
    tl.store(out, row_all[:, None], row_mask)

90 

91 

def all(inp):
    """Return a 0-d bool tensor that is True iff every element of ``inp`` is non-zero.

    Mirrors ``torch.all(inp)``: an empty input yields True (vacuous truth).

    Args:
        inp: input tensor of any shape and dtype; strided views are accepted.

    Returns:
        A 0-d ``torch.bool`` tensor on ``inp.device``.
    """
    logger.debug("GEMS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # Vacuously True, as in torch.all. Also avoids BLOCK_SIZE == 0
        # (triton.next_power_of_2(0) returns 0), which cannot compile.
        return torch.ones([], dtype=torch.bool, device=inp.device)
    # The kernel addresses the input as a flat buffer of n_elements entries,
    # so strided views must be materialized first (no-op when contiguous).
    inp = inp.contiguous()
    # BLOCK_SIZE must fit in XPU per-core local buffer so the Triton fallback
    # kernel always compiles. The C++ handler (api::all<T,bool>) ignores this
    # value and handles any n_elements internally.
    BLOCK_SIZE = min(triton.next_power_of_2(n_elements), buf_len_per_core)
    out = torch.empty([], dtype=torch.bool, device=inp.device)
    with torch_device_fn.device(inp.device):
        all_global_kernel[(1, 1)](
            inp, out, n_elements, BLOCK_SIZE, buffer_size_limit=2048
        )
    return out

105 

106 

def all_dim(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero along one dimension.

    Mirrors ``torch.all(inp, dim, keepdim)``.

    Args:
        inp: input tensor.
        dim: dimension to reduce, or None to reduce over every element.
        keepdim: if True, keep the reduced dimension(s) with size 1.

    Returns:
        A ``torch.bool`` tensor. With ``dim=None`` it is 0-d, or has shape
        ``[1] * inp.ndim`` when ``keepdim`` is set. Otherwise it has
        ``inp.shape`` with ``dim`` removed, or set to 1 when ``keepdim``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS ALL DIM")
    if dim is None:
        out = all(inp)
        if keepdim:
            out = torch.reshape(out, [1] * inp.ndim)
        return out

    if inp.ndim == 0:
        # torch accepts dim in {-1, 0} for a 0-d tensor; nothing is reduced.
        assert dim in (-1, 0), "Invalid dim"
        return inp != 0

    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)
    N = shape[dim]
    shape[dim] = 1

    if inp.numel() == 0:
        # An empty reduced dim (N == 0) is vacuously True, and an empty batch
        # (M == 0) gives an empty result. In both cases skip the kernel:
        # M = numel // N would divide by zero or produce a zero block size.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dim last so each output element is one row of N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if N == 1:
            # N==1: each row has a single element; avoid kernel dispatch for
            # trivial case that some hardware configs cannot handle.
            out = (inp.reshape(M) != 0).reshape(shape)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

135 

136 

def all_dims(inp, dim=None, keepdim=False):
    """Test whether all elements are non-zero along several dimensions.

    Mirrors ``torch.all(inp, dim, keepdim)`` where ``dim`` may be a sequence.
    A None or int ``dim`` is delegated to ``all_dim``.

    Args:
        inp: input tensor.
        dim: None, a single dimension, or a sequence of dimensions to reduce.
        keepdim: if True, keep the reduced dimensions with size 1.

    Returns:
        A ``torch.bool`` tensor with the reduced dimensions removed, or set to
        1 when ``keepdim`` is set.

    Raises:
        AssertionError: if any entry of ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    if inp.ndim == 0:
        # torch accepts dim in {-1, 0} for a 0-d tensor; nothing is reduced.
        for d in dim:
            assert d in (-1, 0), "Invalid dim"
        return inp != 0

    # Validate every entry explicitly. Asserting a generator expression is
    # always truthy, and the builtin `all` is shadowed by this module's op.
    for d in dim:
        assert -inp.ndim <= d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    N = 1
    for d in dim:
        N *= shape[d]
        shape[d] = 1

    if inp.numel() == 0:
        # An empty reduction (N == 0) is vacuously True, and an empty batch
        # (M == 0) gives an empty result. In both cases skip the kernel:
        # M = numel // N would divide by zero or produce a zero block size.
        out = torch.ones(shape, dtype=torch.bool, device=inp.device)
    else:
        # Move the reduced dims last so each output element is one row of N.
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        if N == 1:
            out = (inp.reshape(M) != 0).reshape(shape)
        else:
            out = torch.empty(shape, dtype=torch.bool, device=inp.device)
            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
            with torch_device_fn.device(inp.device):
                all_kernel_dim[grid](inp, out, M, N, buffer_size_limit=2048)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out