Coverage for src/flag_gems/ops/searchsorted.py: 66%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device as runtime_device 

8from flag_gems.runtime import torch_device_fn 

9 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of query values handled by one Triton program instance. The host
# launcher in this file selects the Ascend size only when the vendor is
# "ascend" and the sorted sequence has a floating-point dtype; every other
# configuration uses the CUDA size.
_CUDA_BLOCK_SIZE = 256
_ASCEND_BLOCK_SIZE = 512

# Dtypes accepted for both the sorted sequence and the query values.
_SUPPORTED_INPUT_DTYPES = {
    # Integer dtypes.
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    # Floating-point dtypes.
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
}

25 

26 

# Batched binary search: every lane owns one query value and narrows a
# half-open window [low, high) over its row of `sorted_sequence`.
#
# Arguments:
#   sorted_sequence  pointer to the boundaries (contiguous; rows of
#                    `sequence_len` elements each).
#   values           pointer to the flattened, contiguous query values.
#   sorter           pointer to the int64 permutation; only dereferenced when
#                    HAS_SORTER is true (the launcher passes a placeholder
#                    pointer otherwise).
#   out              pointer to the flattened, contiguous result buffer.
#   total_values     number of query values.
#   values_per_row   number of query values that share one boundaries row.
#   sequence_len     length of one boundaries row.
#   LOG_SEQUENCE_LEN number of bisection steps to run (the launcher passes
#                    `sequence_len.bit_length()`).
#   RIGHT            False -> first index whose element is >= value,
#                    True  -> first index whose element is >  value.
#   IS_1D_SEQUENCE   all queries search the same single row.
#   USE_INT32_INDEX  keep all index arithmetic in 32 bits (the launcher only
#                    sets this when both tensors have < 2**31 - 1 elements).
#   BLOCK_SIZE       queries handled per program instance.
@triton.jit
def _searchsorted_kernel(
    sorted_sequence,
    values,
    sorter,
    out,
    total_values,
    values_per_row,
    sequence_len,
    LOG_SEQUENCE_LEN: tl.constexpr,
    RIGHT: tl.constexpr,
    HAS_SORTER: tl.constexpr,
    IS_1D_SEQUENCE: tl.constexpr,
    USE_INT32_INDEX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # program_id is 32-bit. Unless the launcher has guaranteed that 32-bit
    # indexing is safe, widen it first: otherwise `pid * BLOCK_SIZE` and the
    # row-offset product below are evaluated in int32 and wrap around once
    # values.numel() or rows * sequence_len reaches 2**31.
    if USE_INT32_INDEX:
        pid = tl.program_id(0)
    else:
        pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_values
    values_in = tl.load(values + offsets, mask=mask, other=0)

    # Element offset of the boundaries row each lane searches in.
    if IS_1D_SEQUENCE:
        if USE_INT32_INDEX:
            row_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
        else:
            row_offsets = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    else:
        # `offsets` is int64 on the non-int32 path, so this product is too.
        row_offsets = (offsets // values_per_row) * sequence_len
        if USE_INT32_INDEX:
            row_offsets = row_offsets.to(tl.int32)

    if USE_INT32_INDEX:
        low = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    else:
        low = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    high = low + sequence_len

    # Fixed trip count so the loop bound is a compile-time constant; lanes
    # whose window is already empty are masked out through `active`.
    for _ in range(LOG_SEQUENCE_LEN):
        active = mask & (low < high)
        mid = low + (high - low) // 2
        sorted_offsets = row_offsets + mid
        if HAS_SORTER:
            # Indirect access: position `mid` of the sorted order lives at
            # index sorter[mid] of the (unsorted) row.
            sorted_index = tl.load(sorter + sorted_offsets, mask=active, other=0)
            if USE_INT32_INDEX:
                sorted_index = sorted_index.to(tl.int32)
            sorted_offsets = row_offsets + sorted_index

        mid_values = tl.load(sorted_sequence + sorted_offsets, mask=active, other=0)
        if RIGHT:
            go_left = values_in < mid_values
        else:
            go_left = values_in <= mid_values

        high = tl.where(active & go_left, mid, high)
        low = tl.where(active & ~go_left, mid + 1, low)

    tl.store(out + offsets, low, mask=mask)

83 

84 

85def _normalize_right(right: bool, side: str | None) -> bool: 

86 if side is None: 

87 return bool(right) 

88 if side == "left": 

89 if right: 

90 raise RuntimeError( 

91 "torch.searchsorted(): side and right can't be set to opposites, " 

92 "got side of left while right was True" 

93 ) 

94 return False 

95 if side == "right": 

96 return True 

97 raise RuntimeError( 

98 f"torch.searchsorted(): side can only be 'left' or 'right' but got {side}" 

99 ) 

100 

101 

def _check_dtype(tensor: torch.Tensor, name: str):
    """Reject tensors whose dtype is not in ``_SUPPORTED_INPUT_DTYPES``.

    Args:
        tensor: Tensor to validate.
        name: Argument name to mention in the error message.

    Raises:
        NotImplementedError: If ``tensor.dtype`` is not supported.
    """
    dtype = tensor.dtype
    if dtype in _SUPPORTED_INPUT_DTYPES:
        return
    raise NotImplementedError(
        f"searchsorted is not implemented for {name} dtype {dtype}"
    )

107 

108 

def _check_tensor_values_shape(sorted_sequence: torch.Tensor, values: torch.Tensor):
    """Validate that ``values`` is shape-compatible with ``sorted_sequence``.

    A 1-D ``sorted_sequence`` accepts ``values`` of any shape. Otherwise both
    tensors must have the same rank and identical leading (all but last)
    dimensions.

    Raises:
        RuntimeError: If ``sorted_sequence`` is 0-dimensional, or if it is
            N-dimensional (N > 1) and the rank / leading dimensions differ.
    """
    boundaries_rank = sorted_sequence.dim()

    if boundaries_rank == 1:
        return

    if boundaries_rank == 0:
        raise RuntimeError(
            "torch.searchsorted(): boundaries tensor should be 1 dimension or "
            "the first N-1 dimensions of boundaries tensor and input value tensor "
            "must match"
        )

    if values.dim() == boundaries_rank and (
        tuple(values.shape[:-1]) == tuple(sorted_sequence.shape[:-1])
    ):
        return

    raise RuntimeError(
        "torch.searchsorted(): boundaries tensor should be 1 dimension or "
        "the first N-1 dimensions of boundaries tensor and input value tensor "
        "must match, but we got boundaries tensor "
        f"{list(sorted_sequence.shape)} and input value tensor {list(values.shape)}"
    )

127 

128 

def _check_scalar_values_shape(sorted_sequence: torch.Tensor):
    """Ensure a scalar query is only used with a 1-D ``sorted_sequence``.

    Raises:
        RuntimeError: If ``sorted_sequence`` is not exactly 1-dimensional.
    """
    boundaries_rank = sorted_sequence.dim()
    if boundaries_rank == 1:
        return
    raise RuntimeError(
        "torch.searchsorted(): input value can be a scalar only when boundaries "
        "tensor dimension is 1, but we got boundaries tensor "
        f"dim({boundaries_rank}) and input value's dim(0) numel(1)"
    )

136 

137 

138def _check_sorter(sorted_sequence: torch.Tensor, sorter: torch.Tensor | None): 

139 if sorter is None: 

140 return 

141 if tuple(sorter.shape) != tuple(sorted_sequence.shape): 

142 raise RuntimeError( 

143 "torch.searchsorted(): boundary and sorter must have the same size, " 

144 f"but got boundary tensor {list(sorted_sequence.shape)}" 

145 f"and got sorter tensor {list(sorter.shape)}" 

146 ) 

147 if sorter.dtype != torch.int64: 

148 raise RuntimeError( 

149 "torch.searchsorted(): sorter must be a tensor of long dtype but got " 

150 f"dtype {sorter.dtype}" 

151 ) 

152 if sorter.device != sorted_sequence.device: 

153 raise RuntimeError( 

154 "torch.searchsorted(): sorter and boundary tensors must be on the same device" 

155 ) 

156 sequence_len = sorted_sequence.shape[-1] 

157 if sorter.numel() != 0 and ( 

158 torch.any(sorter < 0).item() or torch.any(sorter >= sequence_len).item() 

159 ): 

160 raise RuntimeError("torch.searchsorted(): sorter index out of range") 

161 

162 

163def _prepare_out( 

164 values: torch.Tensor, 

165 out_int32: bool, 

166 out: torch.Tensor | None, 

167): 

168 out_dtype = torch.int32 if out_int32 else torch.int64 

169 if out is None: 

170 return torch.empty(values.shape, dtype=out_dtype, device=values.device) 

171 if out.dtype != out_dtype: 

172 raise RuntimeError( 

173 "torch.searchsorted(): output tensor's dtype is wrong, it can only be " 

174 "Int(int32) or Long(int64) depending on whether out_int32 flag is True" 

175 ) 

176 if out.device != values.device: 

177 raise RuntimeError( 

178 "torch.searchsorted(): output tensor must be on the same device as input" 

179 ) 

180 if tuple(out.shape) != tuple(values.shape): 

181 out.resize_(values.shape) 

182 return out 

183 

184 

def _searchsorted_impl(
    sorted_sequence: torch.Tensor,
    values: torch.Tensor,
    *,
    out_int32: bool,
    right: bool,
    side: str | None,
    sorter: torch.Tensor | None,
    out: torch.Tensor | None = None,
):
    """Validate the arguments and launch the Triton binary-search kernel.

    Args:
        sorted_sequence: Boundaries; searched along the last dimension.
        values: Query values.
        out_int32: Produce int32 indices instead of int64.
        right: Boolean form of the side option.
        side: ``"left"``, ``"right"`` or None; combined with ``right``.
        sorter: Optional int64 permutation with the shape of
            ``sorted_sequence``.
        out: Optional destination tensor.

    Returns:
        The tensor of indices (``out`` when one was supplied).

    Raises:
        RuntimeError: On inconsistent side/right, shape, sorter, device or
            output-tensor arguments.
        NotImplementedError: On an unsupported input dtype.
    """
    # ---- argument validation (order matters for which error is raised) ----
    right = _normalize_right(right, side)
    for tensor, label in ((sorted_sequence, "sorted_sequence"), (values, "values")):
        _check_dtype(tensor, label)
    _check_tensor_values_shape(sorted_sequence, values)
    _check_sorter(sorted_sequence, sorter)
    if values.device != sorted_sequence.device:
        raise RuntimeError(
            "torch.searchsorted(): sorted_sequence and values must be on the same device"
        )

    out = _prepare_out(values, out_int32, out)

    # ---- trivial cases that need no kernel launch ----
    num_values = values.numel()
    if num_values == 0:
        return out
    sequence_len = sorted_sequence.shape[-1]
    if sequence_len == 0:
        # Every insertion point into an empty row is 0.
        out.zero_()
        return out

    # ---- kernel operands ----
    boundaries = sorted_sequence.contiguous()
    queries = values.contiguous()
    permutation = None if sorter is None else sorter.contiguous()

    is_ascend = runtime_device.vendor_name == "ascend"
    if is_ascend and permutation is not None:
        # On Ascend the permutation is applied up front, so the kernel sees an
        # already-sorted sequence and runs without a sorter.
        boundaries = torch.gather(boundaries, -1, permutation)
        permutation = None
    has_sorter = permutation is not None

    # The kernel writes with flat offsets, so it needs a contiguous target;
    # a non-contiguous `out` is filled through a temporary afterwards.
    if out.is_contiguous():
        kernel_out = out
    else:
        kernel_out = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    # ---- launch configuration ----
    sequence_is_1d = sorted_sequence.dim() == 1
    values_per_row = num_values if sequence_is_1d else values.shape[-1]

    block_size = _CUDA_BLOCK_SIZE
    if is_ascend and sorted_sequence.dtype.is_floating_point:
        block_size = _ASCEND_BLOCK_SIZE

    int32_max = torch.iinfo(torch.int32).max
    use_int32_index = (
        is_ascend and num_values < int32_max and sorted_sequence.numel() < int32_max
    )

    with torch_device_fn.device(sorted_sequence.device):
        grid = (triton.cdiv(num_values, block_size),)
        _searchsorted_kernel[grid](
            boundaries,
            queries,
            # Placeholder pointer when there is no sorter; the kernel does not
            # read it unless HAS_SORTER is set.
            permutation if has_sorter else boundaries,
            kernel_out,
            num_values,
            values_per_row,
            sequence_len,
            LOG_SEQUENCE_LEN=sequence_len.bit_length(),
            RIGHT=right,
            HAS_SORTER=has_sorter,
            IS_1D_SEQUENCE=sequence_is_1d,
            USE_INT32_INDEX=use_int32_index,
            BLOCK_SIZE=block_size,
        )

    if kernel_out is not out:
        out.copy_(kernel_out)
    return out

265 

266 

def searchsorted(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
):
    """Tensor-valued ``searchsorted`` entry point.

    Logs a debug marker and forwards every argument to ``_searchsorted_impl``
    without an ``out`` tensor, returning its result.
    """
    logger.debug("GEMS SEARCHSORTED")
    options = {
        "out_int32": out_int32,
        "right": right,
        "side": side,
        "sorter": sorter,
    }
    return _searchsorted_impl(sorted_sequence, self, **options)

285 

286 

def searchsorted_out(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
    out,
):
    """Tensor-valued ``searchsorted`` entry point writing into ``out``.

    Logs a debug marker and forwards every argument, including the required
    keyword-only ``out`` tensor, to ``_searchsorted_impl``.
    """
    logger.debug("GEMS SEARCHSORTED OUT")
    options = {
        "out_int32": out_int32,
        "right": right,
        "side": side,
        "sorter": sorter,
        "out": out,
    }
    return _searchsorted_impl(sorted_sequence, self, **options)

307 

308 

def _scalar_values_tensor(sorted_sequence, scalar):
    """Wrap a Python scalar query in a 0-d tensor without losing precision.

    ``torch.scalar_tensor`` without an explicit dtype produces the default
    floating dtype (normally float32). That silently rounds integers above
    2**24 and any value that needs float64, which can move the insertion
    point. The dtype is therefore widened in exactly those situations:

    * integer (or bool) scalar with non-floating boundaries -> int64
    * float64 boundaries                                     -> float64
    * anything else                                          -> default dtype
    """
    device = sorted_sequence.device
    boundaries_dtype = sorted_sequence.dtype
    if isinstance(scalar, int) and not boundaries_dtype.is_floating_point:
        return torch.scalar_tensor(scalar, dtype=torch.int64, device=device)
    if boundaries_dtype == torch.float64:
        return torch.scalar_tensor(scalar, dtype=torch.float64, device=device)
    return torch.scalar_tensor(scalar, device=device)


def searchsorted_scalar(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
):
    """Scalar-valued ``searchsorted`` entry point.

    Args:
        sorted_sequence: 1-D boundaries tensor.
        self: Python scalar to locate.
        out_int32, right, side, sorter: Forwarded to ``_searchsorted_impl``.

    Returns:
        A 0-d index tensor.

    Raises:
        RuntimeError: If ``sorted_sequence`` is not 1-dimensional (plus any
            error raised by ``_searchsorted_impl``).
    """
    logger.debug("GEMS SEARCHSORTED SCALAR")
    _check_scalar_values_shape(sorted_sequence)
    values = _scalar_values_tensor(sorted_sequence, self)
    return _searchsorted_impl(
        sorted_sequence,
        values,
        out_int32=out_int32,
        right=right,
        side=side,
        sorter=sorter,
    )


def searchsorted_scalar_out(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
    out,
):
    """Scalar-valued ``searchsorted`` entry point writing into ``out``.

    Same contract as ``searchsorted_scalar``; the result is stored in (and
    returned as) the required keyword-only ``out`` tensor.
    """
    logger.debug("GEMS SEARCHSORTED SCALAR OUT")
    _check_scalar_values_shape(sorted_sequence)
    values = _scalar_values_tensor(sorted_sequence, self)
    return _searchsorted_impl(
        sorted_sequence,
        values,
        out_int32=out_int32,
        right=right,
        side=side,
        sorter=sorter,
        out=out,
    )