Coverage for src/flag_gems/ops/searchsorted.py: 66%
136 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device as runtime_device
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
12_CUDA_BLOCK_SIZE = 256
13_ASCEND_BLOCK_SIZE = 512
14_SUPPORTED_INPUT_DTYPES = {
15 torch.uint8,
16 torch.int8,
17 torch.int16,
18 torch.int32,
19 torch.int64,
20 torch.float16,
21 torch.bfloat16,
22 torch.float32,
23 torch.float64,
24}
@triton.jit
def _searchsorted_kernel(
    sorted_sequence,
    values,
    sorter,
    out,
    total_values,
    values_per_row,
    sequence_len,
    LOG_SEQUENCE_LEN: tl.constexpr,
    RIGHT: tl.constexpr,
    HAS_SORTER: tl.constexpr,
    IS_1D_SEQUENCE: tl.constexpr,
    USE_INT32_INDEX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Binary-search each value in its row of ``sorted_sequence``.

    Every program handles ``BLOCK_SIZE`` consecutive entries of ``values`` and
    stores, for each of them, the insertion index found by the search.

    Args:
        sorted_sequence: Pointer to the sequence data, indexed as
            ``row * sequence_len + position``.
        values: Pointer to the values to locate, indexed linearly.
        sorter: Pointer to the permutation that sorts each row; only read
            when ``HAS_SORTER`` is true.
        out: Pointer receiving one index per value.
        total_values: Number of valid entries in ``values``.
        values_per_row: Number of values that share one sequence row.
        sequence_len: Length of each sequence row.
        LOG_SEQUENCE_LEN: Number of bisection steps to run (the launcher in
            this file passes ``sequence_len.bit_length()``).
        RIGHT: If true, values equal to a sequence element go to its right;
            otherwise to its left.
        HAS_SORTER: Whether positions are redirected through ``sorter``.
        IS_1D_SEQUENCE: Whether all values search the same single row.
        USE_INT32_INDEX: Do the index arithmetic in int32 instead of int64.
        BLOCK_SIZE: Number of values processed per program.
    """
    # tl.program_id is a 32-bit value. Unless the caller explicitly opted into
    # int32 indexing, widen it first so that neither the linear value offsets
    # nor the row offsets derived from them can wrap around in int32.
    if USE_INT32_INDEX:
        pid = tl.program_id(0)
    else:
        pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_values
    values_in = tl.load(values + offsets, mask=mask, other=0)

    # Start of the sequence row each value searches in.
    if IS_1D_SEQUENCE:
        if USE_INT32_INDEX:
            row_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
        else:
            row_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    else:
        row_offsets = (offsets // values_per_row) * sequence_len
        if USE_INT32_INDEX:
            row_offsets = row_offsets.to(tl.int32)

    # Half-open search interval [low, high) within the row.
    if USE_INT32_INDEX:
        low = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    else:
        low = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    high = low + sequence_len

    for _ in range(LOG_SEQUENCE_LEN):
        # Lanes whose interval is already empty stop loading and updating.
        active = mask & (low < high)
        mid = low + (high - low) // 2
        sorted_offsets = row_offsets + mid
        if HAS_SORTER:
            # Translate the sorted position into a position in the raw row.
            sorted_index = tl.load(sorter + sorted_offsets, mask=active, other=0)
            if USE_INT32_INDEX:
                sorted_index = sorted_index.to(tl.int32)
            sorted_offsets = row_offsets + sorted_index

        mid_values = tl.load(sorted_sequence + sorted_offsets, mask=active, other=0)
        if RIGHT:
            go_left = values_in < mid_values
        else:
            go_left = values_in <= mid_values

        high = tl.where(active & go_left, mid, high)
        low = tl.where(active & ~go_left, mid + 1, low)

    tl.store(out + offsets, low, mask=mask)
85def _normalize_right(right: bool, side: str | None) -> bool:
86 if side is None:
87 return bool(right)
88 if side == "left":
89 if right:
90 raise RuntimeError(
91 "torch.searchsorted(): side and right can't be set to opposites, "
92 "got side of left while right was True"
93 )
94 return False
95 if side == "right":
96 return True
97 raise RuntimeError(
98 f"torch.searchsorted(): side can only be 'left' or 'right' but got {side}"
99 )
def _check_dtype(tensor: torch.Tensor, name: str):
    """Reject tensors whose dtype is not in ``_SUPPORTED_INPUT_DTYPES``.

    Args:
        tensor: Tensor whose dtype is inspected.
        name: Argument name used in the error message.

    Raises:
        NotImplementedError: If the dtype is not supported.
    """
    if tensor.dtype in _SUPPORTED_INPUT_DTYPES:
        return
    raise NotImplementedError(
        f"searchsorted is not implemented for {name} dtype {tensor.dtype}"
    )
109def _check_tensor_values_shape(sorted_sequence: torch.Tensor, values: torch.Tensor):
110 if sorted_sequence.dim() == 0:
111 raise RuntimeError(
112 "torch.searchsorted(): boundaries tensor should be 1 dimension or "
113 "the first N-1 dimensions of boundaries tensor and input value tensor "
114 "must match"
115 )
116 if sorted_sequence.dim() == 1:
117 return
118 if values.dim() != sorted_sequence.dim() or (
119 tuple(values.shape[:-1]) != tuple(sorted_sequence.shape[:-1])
120 ):
121 raise RuntimeError(
122 "torch.searchsorted(): boundaries tensor should be 1 dimension or "
123 "the first N-1 dimensions of boundaries tensor and input value tensor "
124 "must match, but we got boundaries tensor "
125 f"{list(sorted_sequence.shape)} and input value tensor {list(values.shape)}"
126 )
129def _check_scalar_values_shape(sorted_sequence: torch.Tensor):
130 if sorted_sequence.dim() != 1:
131 raise RuntimeError(
132 "torch.searchsorted(): input value can be a scalar only when boundaries "
133 "tensor dimension is 1, but we got boundaries tensor "
134 f"dim({sorted_sequence.dim()}) and input value's dim(0) numel(1)"
135 )
138def _check_sorter(sorted_sequence: torch.Tensor, sorter: torch.Tensor | None):
139 if sorter is None:
140 return
141 if tuple(sorter.shape) != tuple(sorted_sequence.shape):
142 raise RuntimeError(
143 "torch.searchsorted(): boundary and sorter must have the same size, "
144 f"but got boundary tensor {list(sorted_sequence.shape)}"
145 f"and got sorter tensor {list(sorter.shape)}"
146 )
147 if sorter.dtype != torch.int64:
148 raise RuntimeError(
149 "torch.searchsorted(): sorter must be a tensor of long dtype but got "
150 f"dtype {sorter.dtype}"
151 )
152 if sorter.device != sorted_sequence.device:
153 raise RuntimeError(
154 "torch.searchsorted(): sorter and boundary tensors must be on the same device"
155 )
156 sequence_len = sorted_sequence.shape[-1]
157 if sorter.numel() != 0 and (
158 torch.any(sorter < 0).item() or torch.any(sorter >= sequence_len).item()
159 ):
160 raise RuntimeError("torch.searchsorted(): sorter index out of range")
163def _prepare_out(
164 values: torch.Tensor,
165 out_int32: bool,
166 out: torch.Tensor | None,
167):
168 out_dtype = torch.int32 if out_int32 else torch.int64
169 if out is None:
170 return torch.empty(values.shape, dtype=out_dtype, device=values.device)
171 if out.dtype != out_dtype:
172 raise RuntimeError(
173 "torch.searchsorted(): output tensor's dtype is wrong, it can only be "
174 "Int(int32) or Long(int64) depending on whether out_int32 flag is True"
175 )
176 if out.device != values.device:
177 raise RuntimeError(
178 "torch.searchsorted(): output tensor must be on the same device as input"
179 )
180 if tuple(out.shape) != tuple(values.shape):
181 out.resize_(values.shape)
182 return out
def _searchsorted_impl(
    sorted_sequence: torch.Tensor,
    values: torch.Tensor,
    *,
    out_int32: bool,
    right: bool,
    side: str | None,
    sorter: torch.Tensor | None,
    out: torch.Tensor | None = None,
):
    """Validate the arguments and launch the Triton search kernel.

    Args:
        sorted_sequence: Sequence tensor searched along its last dimension.
        values: Tensor of values to locate.
        out_int32: Produce int32 indices when true, int64 otherwise.
        right: Right-side flag, combined with ``side`` by ``_normalize_right``.
        side: ``"left"``, ``"right"`` or None.
        sorter: Optional int64 permutation that sorts ``sorted_sequence``.
        out: Optional destination tensor.

    Returns:
        The tensor holding one index per value (``out`` when it was given).

    Raises:
        RuntimeError: On invalid ``side``/``right``, shape, sorter, device or
            ``out`` arguments.
        NotImplementedError: On unsupported input dtypes.
    """
    # Validation runs in a fixed order so the first failing check is reported.
    search_right = _normalize_right(right, side)
    for tensor, label in ((sorted_sequence, "sorted_sequence"), (values, "values")):
        _check_dtype(tensor, label)
    _check_tensor_values_shape(sorted_sequence, values)
    _check_sorter(sorted_sequence, sorter)
    if values.device != sorted_sequence.device:
        raise RuntimeError(
            "torch.searchsorted(): sorted_sequence and values must be on the same device"
        )

    result = _prepare_out(values, out_int32, out)
    num_values = values.numel()
    row_len = sorted_sequence.shape[-1]

    # Degenerate inputs need no kernel launch.
    if num_values == 0:
        return result
    if row_len == 0:
        # Every value is inserted at index 0 of an empty row.
        result.zero_()
        return result

    on_ascend = runtime_device.vendor_name == "ascend"
    seq_is_1d = sorted_sequence.dim() == 1

    seq_c = sorted_sequence.contiguous()
    vals_c = values.contiguous()
    perm_c = None if sorter is None else sorter.contiguous()
    if on_ascend and perm_c is not None:
        # On ascend the permutation is applied up front, so the kernel runs
        # without a sorter.
        seq_c = torch.gather(seq_c, -1, perm_c)
        perm_c = None
    has_sorter = perm_c is not None
    # The kernel always receives a pointer in the sorter slot; without a
    # sorter the sequence itself is passed as a placeholder.
    sorter_arg = perm_c if has_sorter else seq_c

    # The kernel writes linearly, so a non-contiguous destination is filled
    # through a contiguous staging tensor and copied back afterwards.
    if result.is_contiguous():
        staging = result
    else:
        staging = torch.empty(result.shape, dtype=result.dtype, device=result.device)

    per_row = num_values if seq_is_1d else values.shape[-1]
    if on_ascend and sorted_sequence.dtype.is_floating_point:
        launch_block = _ASCEND_BLOCK_SIZE
    else:
        launch_block = _CUDA_BLOCK_SIZE
    int32_limit = torch.iinfo(torch.int32).max
    narrow_index = (
        on_ascend
        and num_values < int32_limit
        and sorted_sequence.numel() < int32_limit
    )

    with torch_device_fn.device(sorted_sequence.device):
        launch_grid = (triton.cdiv(num_values, launch_block),)
        _searchsorted_kernel[launch_grid](
            seq_c,
            vals_c,
            sorter_arg,
            staging,
            num_values,
            per_row,
            row_len,
            LOG_SEQUENCE_LEN=row_len.bit_length(),
            RIGHT=search_right,
            HAS_SORTER=has_sorter,
            IS_1D_SEQUENCE=seq_is_1d,
            USE_INT32_INDEX=narrow_index,
            BLOCK_SIZE=launch_block,
        )

    if staging is not result:
        result.copy_(staging)
    return result
def searchsorted(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
):
    """Tensor-valued ``searchsorted`` that allocates its own result.

    Args:
        sorted_sequence: Sequence tensor to search in.
        self: Tensor of values to locate.
        out_int32: Return int32 indices when true, int64 otherwise.
        right: Right-side flag.
        side: ``"left"``, ``"right"`` or None.
        sorter: Optional permutation that sorts ``sorted_sequence``.

    Returns:
        Whatever ``_searchsorted_impl`` returns for these arguments.
    """
    logger.debug("GEMS SEARCHSORTED")
    options = {
        "out_int32": out_int32,
        "right": right,
        "side": side,
        "sorter": sorter,
    }
    return _searchsorted_impl(sorted_sequence, self, **options)
def searchsorted_out(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
    out,
):
    """Tensor-valued ``searchsorted`` writing into a caller-provided ``out``.

    Args:
        sorted_sequence: Sequence tensor to search in.
        self: Tensor of values to locate.
        out_int32: Expect int32 indices when true, int64 otherwise.
        right: Right-side flag.
        side: ``"left"``, ``"right"`` or None.
        sorter: Optional permutation that sorts ``sorted_sequence``.
        out: Destination tensor (required keyword argument).

    Returns:
        Whatever ``_searchsorted_impl`` returns for these arguments.
    """
    logger.debug("GEMS SEARCHSORTED OUT")
    options = {
        "out_int32": out_int32,
        "right": right,
        "side": side,
        "sorter": sorter,
        "out": out,
    }
    return _searchsorted_impl(sorted_sequence, self, **options)
def searchsorted_scalar(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
):
    """Scalar-valued ``searchsorted`` that allocates its own result.

    The Python scalar is wrapped in a 0-D tensor on the sequence's device and
    then handled like a tensor of values.

    Args:
        sorted_sequence: 1-D sequence tensor to search in.
        self: Python scalar to locate.
        out_int32: Return int32 indices when true, int64 otherwise.
        right: Right-side flag.
        side: ``"left"``, ``"right"`` or None.
        sorter: Optional permutation that sorts ``sorted_sequence``.

    Returns:
        Whatever ``_searchsorted_impl`` returns for the wrapped scalar.

    Raises:
        RuntimeError: If ``sorted_sequence`` is not 1-D (plus anything raised
            by ``_searchsorted_impl``).
    """
    logger.debug("GEMS SEARCHSORTED SCALAR")
    _check_scalar_values_shape(sorted_sequence)

    # Without an explicit dtype the scalar would be stored in the default
    # floating dtype, which rounds large integers and float64-precision values
    # before they are compared. Pick a lossless dtype where one exists and keep
    # the default otherwise.
    seq_dtype = sorted_sequence.dtype
    if seq_dtype == torch.float64:
        scalar_dtype = torch.float64
    elif (
        isinstance(self, int)
        and not seq_dtype.is_floating_point
        and not seq_dtype.is_complex
    ):
        # Integer scalar against an integer sequence: compare exactly.
        scalar_dtype = torch.int64
    else:
        scalar_dtype = None

    values = torch.scalar_tensor(
        self, dtype=scalar_dtype, device=sorted_sequence.device
    )
    return _searchsorted_impl(
        sorted_sequence,
        values,
        out_int32=out_int32,
        right=right,
        side=side,
        sorter=sorter,
    )
def searchsorted_scalar_out(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
    out,
):
    """Scalar-valued ``searchsorted`` writing into a caller-provided ``out``.

    The Python scalar is wrapped in a 0-D tensor on the sequence's device and
    then handled like a tensor of values.

    Args:
        sorted_sequence: 1-D sequence tensor to search in.
        self: Python scalar to locate.
        out_int32: Expect int32 indices when true, int64 otherwise.
        right: Right-side flag.
        side: ``"left"``, ``"right"`` or None.
        sorter: Optional permutation that sorts ``sorted_sequence``.
        out: Destination tensor (required keyword argument).

    Returns:
        Whatever ``_searchsorted_impl`` returns for the wrapped scalar.

    Raises:
        RuntimeError: If ``sorted_sequence`` is not 1-D (plus anything raised
            by ``_searchsorted_impl``).
    """
    logger.debug("GEMS SEARCHSORTED SCALAR OUT")
    _check_scalar_values_shape(sorted_sequence)

    # Without an explicit dtype the scalar would be stored in the default
    # floating dtype, which rounds large integers and float64-precision values
    # before they are compared. Pick a lossless dtype where one exists and keep
    # the default otherwise.
    seq_dtype = sorted_sequence.dtype
    if seq_dtype == torch.float64:
        scalar_dtype = torch.float64
    elif (
        isinstance(self, int)
        and not seq_dtype.is_floating_point
        and not seq_dtype.is_complex
    ):
        # Integer scalar against an integer sequence: compare exactly.
        scalar_dtype = torch.int64
    else:
        scalar_dtype = None

    values = torch.scalar_tensor(
        self, dtype=scalar_dtype, device=sorted_sequence.device
    )
    return _searchsorted_impl(
        sorted_sequence,
        values,
        out_int32=out_int32,
        right=right,
        side=side,
        sorter=sorter,
        out=out,
    )