Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isin.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3import os 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import triton_lang_extension as ext 

11from flag_gems.utils.libentry import libentry 

12 

13from .all import reduce_all 

14from .any import reduce_any 

15from .unique import _unique2 

16 

17logger = logging.getLogger(__name__) 

18 

19 

def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Return the launch tuple ``(BLOCK_M, BLOCK_N, num_warps)`` with BLOCK_N capped.

    ``BLOCK_N`` is limited to ``triton.next_power_of_2(N)`` so the tile is never
    wider than the power-of-two envelope of the problem size ``N``.
    ``BLOCK_M`` and ``num_warps`` are passed through unchanged.
    """
    n_envelope = triton.next_power_of_2(N)
    capped_block_n = min(BLOCK_N, n_envelope)
    return BLOCK_M, capped_block_n, num_warps

22 

23 

@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # Brute-force membership test for one tile of BLOCK_M elements of in0.
    # Each in0 element (a "row") is compared against every in1 element, walked
    # in BLOCK_N-wide column tiles. Per-column results are accumulated in
    # `block` and reduced along axis 1 at the end:
    #   invert == False -> out[row] = any(in0[row] == in1[col])
    #   invert == True  -> out[row] = all(in0[row] != in1[col])
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    out_ptr += rows
    # Broadcast the pointers to the [BLOCK_M, BLOCK_N] tile shape.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # The accumulator starts at the identity of the final reduction
    # (True for "all", False for "any").
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Masked-out lanes must KEEP their accumulated value. Resetting them to
        # the identity (as was done previously) discarded matches found by
        # earlier tiles whenever the last tile was only partially filled,
        # i.e. whenever N is not a multiple of BLOCK_N.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            block,
        )
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

57 

58 

@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Entry point of the comparison-based isin: every program processes
    # `tiles_per_cta` row tiles, spaced `num_programs` apart.
    program_idx = ext.program_id(0)
    program_cnt = ext.num_programs(0)
    for tile in range(tiles_per_cta):
        # Tile index handled by this program on this iteration.
        tile_id = program_idx + tile * program_cnt
        isin_by_comparation_impl(
            tile_id,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )

88 

89 

def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Compute ``isin(in0, in1)`` by comparing every element pair on device.

    Both inputs are flattened; the launch configuration is picked from the
    number of elements in ``in0``.

    Args:
        in0: Elements to test.
        in1: Values to test against.
        invert: When true, the kernel produces the negated membership mask.

    Returns:
        A boolean tensor viewed with the shape of ``in0``.
    """
    in0_ravel = in0.contiguous().ravel()
    in1_ravel = in1.contiguous().ravel()
    M = in0.numel()
    N = in1.numel()

    # (inclusive upper bound on M, (BLOCK_M, BLOCK_N, num_warps)); first hit wins.
    size_tiers = (
        (1024, (1, 256, 4)),
        (3072, (2, 256, 4)),
        (6144, (4, 128, 4)),
        (9216, (4, 256, 8)),
    )
    chosen = (4, 128, 4)  # configuration for M above every tier
    for upper_bound, config in size_tiers:
        if M <= upper_bound:
            chosen = config
            break
    rows_per_tile, cols_per_tile, warps = chosen
    BLOCK_M, BLOCK_N, num_warps = launch_arg(rows_per_tile, cols_per_tile, N, warps)

    # Cap the grid at 65536 programs; each program then loops over several tiles.
    total_tiles = triton.cdiv(M, BLOCK_M)
    ctas_num = min(65536, total_tiles)
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)

    out = torch.empty_like(in0_ravel, dtype=torch.bool)
    with torch_device_fn.device(in0_ravel.device.index):
        isin_by_comparation_kernel[grid](
            in0_ravel,
            in1_ravel,  # in
            out,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return out.view_as(in0)

127 

128 

@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Membership test for one tile of BLOCK_M elements of in0: each lane runs
    # a binary search over the N values behind in1_sorted_ptr (expected to be
    # sorted ascending by the caller) for at most `log_n` steps.
    lane = tl.arange(0, BLOCK_M)
    idx = global_pid * BLOCK_M + lane
    in_bounds = idx < M

    # Value each lane is looking for.
    needle = tl.load(in0_ravel_ptr + idx, mask=in_bounds)

    # Search window is [lo, hi); `active` marks lanes whose window is non-empty.
    found = tl.zeros_like(lane).to(tl.int1)
    lo = tl.zeros_like(lane)
    hi = lo + N
    active = lo < hi
    for step in range(log_n):
        # Inactive lanes probe index 0, and their load is masked off.
        probe = tl.where(active, lo + (hi - lo) // 2, 0)
        probe_val = tl.load(in1_sorted_ptr + probe, mask=active)
        found = tl.where(active, found or (probe_val == needle), found)
        lo = tl.where(active and (probe_val < needle), probe + 1, lo)
        hi = tl.where(active and (probe_val > needle), probe, hi)
        active = lo < hi

    # Write the (optionally negated) result for in-range lanes only.
    store_at = tl.where(in_bounds, idx, M + 1)
    tl.store(out_ptr + store_at, not found if invert else found, mask=in_bounds)

164 

165 

@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Entry point of the search-based isin: every program processes
    # `tiles_per_cta` tiles, spaced `num_programs` apart.
    program_idx = ext.program_id(0)
    program_cnt = ext.num_programs(0)
    for tile in range(tiles_per_cta):
        # Tile index handled by this program on this iteration.
        tile_id = program_idx + tile * program_cnt
        isin_by_search_impl(
            tile_id,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )

195 

196 

def isin_by_search(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Compute ``isin(in0, in1)`` with a per-element binary search on device.

    ``in1`` is flattened and sorted (or de-duplicated and sorted through
    ``_unique2`` when ``unique_in1`` is set); each element of ``in0`` is then
    looked up by ``isin_by_search_kernel``.

    Args:
        in0: Elements to test.
        in1: Values to test against; must be non-empty (``log2(N)`` is taken).
        invert: When true, the kernel stores the negated membership mask.
        unique_in0: De-duplicate ``in0`` via ``_unique2`` before searching and
            scatter the result back through the inverse indices afterwards.
        unique_in1: De-duplicate ``in1`` via ``_unique2`` instead of only
            sorting it.

    Returns:
        A boolean tensor viewed with the shape of ``in0``.
    """
    # unique or sort or ravel
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    # Number of binary-search steps performed by the kernel.
    log_n = int(math.log2(N)) + 1
    # Cap the grid at 65536 programs; each program then loops over several tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)

    # Environment switches applied only around this launch.
    # NOTE(review): their exact effect on the Kunlunxin Triton toolchain is not
    # visible in this file - confirm against the backend documentation.
    launch_env = {
        "TRITONXPU_OTHER_SIM": "1",
        "TRITONXPU_STORE_MASK_SIM": "1",
        "TRITONXPU_INTERLEAVE": "0",
    }
    with torch_device_fn.device(in0_ravel.device.index):
        # Remember whatever the caller had configured so it can be restored,
        # and restore it even when the launch raises.
        previous_env = {key: os.environ.get(key) for key in launch_env}
        os.environ.update(launch_env)
        try:
            isin_by_search_kernel[grid](
                in0_ravel,
                in1_ravel,  # in
                out,  # out
                M,
                N,
                log_n,
                BLOCK_M,
                tiles_per_cta=tiles_per_cta,
                invert=invert,
                num_warps=num_warps,
                isCloseUnrollControl=True,
            )
        finally:
            for key, old_value in previous_env.items():
                if old_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old_value

    if unique_in0:
        # Expand the per-unique-value results back to every original element.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)

268 

269 

def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of ``in0`` is contained in ``in1``.

    Either argument (but not both) may be a non-tensor value; it is converted
    to a tensor on the other argument's device.

    Args:
        in0: Elements to test (tensor or scalar/sequence).
        in1: Values to test against (tensor or scalar/sequence).
        assume_unique: When true, ``in1`` is only sorted and never
            de-duplicated before the search, regardless of its size.
        invert: When true, return the negated membership mask.

    Returns:
        A boolean tensor with the shape of ``in0``.
    """
    logger.debug("GEMS_KUNLUNXIN ISIN")
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1)
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        in1 = torch.tensor(in1, device=in0.device)

    if in0.numel() == 0 or in1.numel() == 0:
        # Nothing can be a member of an empty set, so the plain mask is all
        # False and the inverted mask is all True. (With an empty in0 the
        # result is empty either way.)
        return torch.full_like(in0, invert, dtype=torch.bool)
    if in0.numel() <= 2048 and in1.numel() <= 2048:
        # Use comparison only for very small sizes where kernel launch overhead dominates
        return isin_by_comparation(in0, in1, invert)
    # De-duplicate in1 only when it is large (> 1024 * 4096 elements) and the
    # caller did not promise uniqueness; otherwise sorting alone is used.
    dedup_in1 = not (assume_unique or in1.numel() <= 4194304)
    return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=dedup_in1)