Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isin.py: 0%
141 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
3import os
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import triton_lang_extension as ext
11from flag_gems.utils.libentry import libentry
13from .all import reduce_all
14from .any import reduce_any
15from .unique import _unique2
# Module-level logger; `isin` emits a debug record through it on every call.
logger = logging.getLogger(__name__)
def launch_arg(BLOCK_M, BLOCK_N, N, num_warps):
    """Build the launch triple ``(BLOCK_M, BLOCK_N, num_warps)``.

    ``BLOCK_M`` and ``num_warps`` are passed through untouched.  ``BLOCK_N``
    is clamped so it never exceeds ``triton.next_power_of_2(N)``, which keeps
    the tile from being wider than the problem size requires.
    """
    clamped_block_n = min(BLOCK_N, triton.next_power_of_2(N))
    return BLOCK_M, clamped_block_n, num_warps
@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    # Brute-force membership test for one tile of in0.
    #
    # The tile covers the BLOCK_M elements of in0 starting at
    # global_pid * BLOCK_M.  Each of them is compared against all N elements
    # of in1, BLOCK_N elements at a time, and one boolean per in0 element is
    # stored to out_ptr.  With invert set, the stored value is "equal to no
    # element of in1" instead of "equal to some element of in1".
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    out_ptr += rows
    # Broadcast the pointers to the [BLOCK_M, BLOCK_N] comparison tile.
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    # Accumulator starts at the identity of the final reduction:
    # all-True for the AND-style (invert) case, all-False for the OR-style case.
    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        # Lanes outside the mask must keep their accumulated value.  Resetting
        # them (as was done before, to `invert`) discards the results gathered
        # by earlier full chunks whenever the last chunk is partial, i.e. when
        # N is not a multiple of BLOCK_N.
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            block,
        )
    # Collapse the BLOCK_N comparison lanes of every row into a single flag.
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)
@libentry()
@triton.jit
def isin_by_comparation_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Entry point of the comparison-based isin: every program handles
    # `tiles_per_cta` tiles, spaced by the number of launched programs
    # (grid-stride-loop style kernel), and delegates each tile to
    # isin_by_comparation_impl.
    num_ctas = ext.num_programs(0)
    cta_id = ext.program_id(0)
    for tile in range(tiles_per_cta):
        tile_id = cta_id + tile * num_ctas
        isin_by_comparation_impl(
            tile_id,
            in0_ravel_ptr,
            in1_ravel_ptr,  # in
            out_ptr,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            invert,
        )
def isin_by_comparation(
    in0: torch.tensor,
    in1: torch.tensor,
    invert: bool,
):
    """Compute isin by comparing every element of ``in0`` with all of ``in1``.

    Both inputs are flattened, tile sizes are picked from the number of
    elements in ``in0``, and ``isin_by_comparation_kernel`` is launched on the
    device that holds ``in0``.

    Args:
        in0: Tensor whose elements are tested.
        in1: Tensor holding the values to test against.
        invert: When true, the kernel produces "not in ``in1``" flags.

    Returns:
        A boolean tensor viewed to the shape of ``in0``.
    """
    lhs_flat = in0.contiguous().ravel()
    rhs_flat = in1.contiguous().ravel()
    M, N = in0.numel(), in1.numel()

    # Rows: (largest M served by the tier, BLOCK_M, BLOCK_N cap, num_warps).
    tiers = (
        (1024, 1, 256, 4),
        (3072, 2, 256, 4),
        (6144, 4, 128, 4),
        (9216, 4, 256, 8),
    )
    for m_limit, tile_m, tile_n_cap, warps in tiers:
        if M <= m_limit:
            break
    else:
        # Anything larger than the last tier.
        tile_m, tile_n_cap, warps = 4, 128, 4
    BLOCK_M, BLOCK_N, num_warps = launch_arg(tile_m, tile_n_cap, N, warps)

    # Cap the grid at 65536 programs; each program loops over its tiles.
    num_ctas = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * num_ctas)
    result = torch.empty_like(lhs_flat, dtype=torch.bool)
    with torch_device_fn.device(lhs_flat.device.index):
        isin_by_comparation_kernel[(num_ctas,)](
            lhs_flat,
            rhs_flat,  # in
            result,  # out
            M,
            N,
            BLOCK_M,
            BLOCK_N,
            tiles_per_cta=tiles_per_cta,
            invert=invert,
            num_warps=num_warps,
        )
    return result.view_as(in0)
@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    # Membership test for one tile of in0 via binary search.
    #
    # The tile covers the BLOCK_M elements of in0 starting at
    # global_pid * BLOCK_M.  Each element is searched for in the N values at
    # in1_sorted_ptr (the host wrapper passes sorted or sorted-unique data),
    # using exactly log_n search steps.  One boolean per element is stored to
    # out_ptr, negated when invert is set.
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)
    # Every lane searches the half-open index range [start, end).
    start = tl.zeros_like(r)
    end = start + N
    # Lanes whose range became empty are finished and no longer updated.
    while_mask = start < end
    for i in range(log_n):
        # Finished lanes get index 0; their load is masked off anyway.
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        # On a match neither bound moves, so the lane keeps probing the same
        # position and `out` stays set.
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    # Lanes past the end of in0 get offset M + 1; the store mask skips them.
    out_offset = tl.where(mask, i0, M + 1)
    tl.store(out_ptr + out_offset, not out if invert else out, mask=mask)
@libentry()
@triton.jit
def isin_by_search_kernel(
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    tiles_per_cta: int,
    invert: tl.constexpr,
):
    # Entry point of the search-based isin: every program handles
    # `tiles_per_cta` tiles, spaced by the number of launched programs
    # (grid-stride-loop style kernel), and delegates each tile to
    # isin_by_search_impl.
    num_ctas = ext.num_programs(0)
    cta_id = ext.program_id(0)
    for tile in range(tiles_per_cta):
        tile_id = cta_id + tile * num_ctas
        isin_by_search_impl(
            tile_id,
            in0_ravel_ptr,
            in1_sorted_ptr,  # in
            out_ptr,  # out
            M,
            N,
            log_n,
            BLOCK_M,
            invert,
        )
def isin_by_search(
    in0: torch.Tensor,
    in1: torch.Tensor,
    invert: bool,
    unique_in0: bool,
    unique_in1: bool,
):
    """Compute isin by binary-searching each element of ``in0`` in sorted ``in1``.

    Args:
        in0: Tensor whose elements are tested.
        in1: Tensor holding the values to test against.
        invert: When true, the kernel stores "not in ``in1``" flags.
        unique_in0: Deduplicate ``in0`` with ``_unique2`` before searching and
            scatter the per-unique-value results back afterwards.
        unique_in1: Deduplicate (and sort) ``in1`` with ``_unique2`` instead of
            only sorting it.

    Returns:
        A boolean tensor viewed to the shape of ``in0``.
    """
    # unique or sort or ravel
    if unique_in0:
        in0_ravel, unique_order, _ = _unique2(
            in0, sorted=True, return_inverse=True, return_counts=False
        )
    else:
        in0_ravel = in0.contiguous().ravel()
    if unique_in1:
        in1_ravel, _, _ = _unique2(
            in1, sorted=True, return_inverse=False, return_counts=False
        )
    else:
        in1_ravel, _ = torch.sort(in1.ravel())

    # launch kernel func
    M = in0_ravel.numel()
    N = in1_ravel.numel()
    if M <= 1048576:  # 2 ** 20 = 1024 * 1024
        _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    elif M <= 268435456:  # 2 ** 28 = 1024 * 262144
        _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)
    else:
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
    # Number of binary-search steps the kernel performs per element.
    log_n = int(math.log2(N)) + 1
    # Cap the grid at 65536 programs; each program loops over its tiles.
    ctas_num = min(65536, triton.cdiv(M, BLOCK_M))
    tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)
    grid = (ctas_num,)
    out = torch.empty_like(in0_ravel, dtype=torch.bool)

    # Backend flags that must be set while this kernel is launched.
    # NOTE(review): their exact meaning is defined by the Kunlunxin Triton
    # backend and is not visible from this file.
    env_overrides = {
        "TRITONXPU_OTHER_SIM": "1",
        "TRITONXPU_STORE_MASK_SIM": "1",
        "TRITONXPU_INTERLEAVE": "0",
    }
    with torch_device_fn.device(in0_ravel.device.index):
        # Remember what the caller had so it can be put back afterwards,
        # including when the launch raises.
        saved_env = {key: os.environ.get(key) for key in env_overrides}
        os.environ.update(env_overrides)
        try:
            isin_by_search_kernel[grid](
                in0_ravel,
                in1_ravel,  # in
                out,  # out
                M,
                N,
                log_n,
                BLOCK_M,
                tiles_per_cta=tiles_per_cta,
                invert=invert,
                num_warps=num_warps,
                isCloseUnrollControl=True,
            )
        finally:
            for key, previous in saved_env.items():
                if previous is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous

    if unique_in0:
        # Expand the per-unique-value flags back to one flag per input element.
        out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))
    return out.view_as(in0)
def isin(
    in0,
    in1,
    *,
    assume_unique: bool = False,
    invert: bool = False,
) -> torch.Tensor:
    """Test whether each element of ``in0`` occurs in ``in1``.

    Args:
        in0: Elements to test; a tensor, or a value convertible with
            ``torch.tensor`` when ``in1`` is a tensor.
        in1: Values to test against; a tensor, or a value convertible with
            ``torch.tensor`` when ``in0`` is a tensor.
        assume_unique: When true, ``in1`` is never deduplicated before the
            search-based path.
        invert: When true, the result flags elements that are *not* in
            ``in1``.

    Returns:
        A boolean tensor with the shape of ``in0``.

    Raises:
        AssertionError: If neither ``in0`` nor ``in1`` is a tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ISIN")
    if not torch.is_tensor(in0):
        assert torch.is_tensor(in1), "at least one of in0 / in1 must be a tensor"
        in0 = torch.tensor(in0, device=in1.device)
    elif not torch.is_tensor(in1):
        in1 = torch.tensor(in1, device=in0.device)
    if in0.numel() == 0 or in1.numel() == 0:
        # Nothing can be a member of an empty `in1`, so every element is
        # "not in" it: the answer is uniformly `invert`, not always False.
        return torch.full_like(in0, invert, dtype=torch.bool)
    elif in0.numel() <= 2048 and in1.numel() <= 2048:
        # Use comparison only for very small sizes where kernel launch overhead dominates
        return isin_by_comparation(in0, in1, invert)
    elif assume_unique or in1.numel() <= 4194304:  # 1024 * 4096
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=False)
    else:
        # Very large `in1`: deduplicate it before searching.
        return isin_by_search(in0, in1, invert, unique_in0=False, unique_in1=True)