Coverage for src/flag_gems/runtime/backend/_sunrise/ops/randperm.py: 0%
266 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.topk import argsort
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils.random_utils import philox_backend_seed_offset
logger = logging.getLogger(__name__)
# Module-local alias of the backend device descriptor; randperm reads
# ``device_.name`` when the caller does not pass a device.
device_ = device

# Integer limits wrapped in ``tl.constexpr`` so that the same objects can be
# used inside @triton.jit code (_get_iinfo_val) and handed to host-side calls
# (the torch.randint bounds in randperm).
_MIN_INT8_VAL = tl.constexpr(-(1 << 7))
_MAX_INT8_VAL = tl.constexpr((1 << 7) - 1)
_MIN_INT16_VAL = tl.constexpr(-(1 << 15))
_MAX_INT16_VAL = tl.constexpr((1 << 15) - 1)
_MIN_INT32_VAL = tl.constexpr(-(1 << 31))
_MAX_INT32_VAL = tl.constexpr((1 << 31) - 1)
_MIN_INT64_VAL = tl.constexpr(-(1 << 63))
_MAX_INT64_VAL = tl.constexpr((1 << 63) - 1)
_MIN_UINT32_VAL = tl.constexpr(0)
_MAX_UINT32_VAL = tl.constexpr((1 << 32) - 1)
# Signed 24-bit range, used by randperm for 24-bit keys stored in int32.
_MIN_INT24_VAL = tl.constexpr(-(1 << 23))
_MAX_INT24_VAL = tl.constexpr((1 << 23) - 1)
# Return a min/max limit of an integer Triton dtype as a compile-time constant.
#
# ``dtype`` must be one of tl.int64, tl.int32, tl.int16, tl.int8 or tl.uint32.
# The result is the matching module-level ``tl.constexpr`` limit: the maximum
# when ``return_max`` is truthy, otherwise the minimum. Any other dtype reaches
# the ``raise ValueError`` branch.
# In this file the function supplies padding values for masked lanes in
# bitonic_sortbykey_kernel and duplicate_keys_shuffle_kernel.
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    # Dtypes are matched by identity against the Triton dtype objects.
    if dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int8:
        if return_max:
            return _MAX_INT8_VAL
        else:
            return _MIN_INT8_VAL
    elif dtype is tl.uint32:
        if return_max:
            return _MAX_UINT32_VAL
        else:
            return _MIN_UINT32_VAL
    else:
        raise ValueError("Unknown dtype")
# Sort one row of N keys per program and carry a companion tensor along.
#
# Program ``cur_batch`` works on elements [cur_batch * N, cur_batch * N + N) of
# every pointer argument. Keys (``chunk_x``) and companions (``chunk_index``)
# are loaded into a BLOCK_SIZE-wide vector; the host passes
# triton.next_power_of_2(N) as BLOCK_SIZE.
# Lanes >= N are padded with the dtype maximum (ascending) or minimum
# (descending), so they sort to the tail. Only the first N sorted lanes are
# stored: keys to ``y_ptr`` and companions to ``index_ptr``.
# N is a tl.constexpr, so a separate kernel is compiled for every distinct N.
#
# NOTE(review): padded lanes of chunk_index are loaded without ``other=``. They
# are discarded correctly only if no real key ties with the padding value.
# randperm avoids this by drawing keys strictly below the dtype maximum
# (torch.randint's ``high`` is exclusive).
@libentry()
@triton.jit
def bitonic_sortbykey_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    # Advance every pointer to this program's row.
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    index_ptr += cur_batch * N
    y_ptr += cur_batch * N

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Padding: dtype max for ascending order, dtype min for descending order,
    # so padded lanes land at the tail of the sorted vector.
    mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)

    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask)

    # Judging by its use here, argsort (flag_gems.ops.topk) acts as a key/value
    # sort: it returns the sorted keys and the companions reordered with them.
    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < N)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < N)
# Map a signed integer key to a bit pattern whose unsigned order matches the
# signed order of the input, for use as radix digits.
#
# For int8/int16/int32/int64 the value is widened to int64 and the sign bit of
# the original width is flipped:
#   * negative inputs clear it and drop the sign-extension bits;
#   * non-negative inputs set it.
# For int8, [-128, -1] maps to [0x00, 0x7F] and [0, 127] maps to [0x80, 0xFF].
# Any other dtype is returned unchanged and is not widened.
@triton.jit
def radix_type_convert(k):
    ik = k.to(tl.int64)
    if tl.constexpr(k.dtype == tl.int8):
        mask = (ik >> 7) & 0x1
        o = tl.where(mask, ik & 0x7F, ik | 0x80)
    elif tl.constexpr(k.dtype == tl.int16):
        mask = (ik >> 15) & 0x1
        o = tl.where(mask, ik & 0x7FFF, ik | 0x8000)
    elif tl.constexpr(k.dtype == tl.int32):
        mask = (ik >> 31) & 0x1
        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
    elif tl.constexpr(k.dtype == tl.int64):
        mask = (ik >> 63) & 0x1
        # NOTE(review): 0x8000000000000000 does not fit in int64. The result
        # dtype of this branch depends on Triton's literal promotion -- confirm.
        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
    else:
        o = k
    return o
# Count, for every radix pass, how many keys of each tile fall into each bin.
#
# Grid layout:
#   * axis 0: one tile of BLOCK_SIZE keys per program;
#   * axis 1: one group of ``bins_segment`` consecutive bins per program.
# For pass p and bin b, the count of tile pid0 is stored at flat index
# p * (bins + 1) * grid0 + (b + 1) * grid0 + pid0. In other words,
# ``digit_hist`` is laid out as (passes, bins + 1, grid0). Bin slot 0 is
# written as zero by the axis-1 == 0 programs, so once the host sums over
# tiles, a cumulative sum over the bin axis yields exclusive bucket offsets.
# Digits are taken ``bits_per_pass`` bits at a time from the
# radix_type_convert'ed key, starting at bit 0.
@libentry()
@triton.jit
def digit_hist_kernel(
    digit_hist,
    key,
    n_elements,
    bits_per_pass,
    bins,
    passes,
    bit_mask,
    bins_segment,
    BLOCK_SIZE: tl.constexpr,
):
    bin_segid = tl.program_id(1)
    pid0 = tl.program_id(0)
    grid0 = tl.num_programs(0)

    # Offsets are computed in int64, presumably to avoid int32 overflow on
    # very large inputs.
    key_offset = pid0.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    key_mask = key_offset < n_elements
    key_data = tl.load(key + key_offset, mask=key_mask)
    ikey_data = radix_type_convert(key_data)
    bit_offset = 0
    for p in range(passes):
        key_digit = (ikey_data >> bit_offset) & bit_mask
        blk_bin_start = bin_segid * bins_segment
        for s in range(bins_segment):
            bin_id = s + blk_bin_start
            # Out-of-range lanes are excluded through key_mask.
            digit_mask = tl.where((key_digit == bin_id) & key_mask, 1, 0)
            digit_sum = tl.sum(digit_mask)
            # +1 for exclusive
            bin_offset = p * (bins + 1) * grid0 + (bin_id + 1) * grid0 + pid0
            # reduce rather than global atomic for perf issue
            tl.store(digit_hist + bin_offset, digit_sum)
        # Zero the leading (exclusive) slot once per tile and pass.
        tl.store(digit_hist + p * (bins + 1) * grid0 + pid0, 0, mask=bin_segid == 0)
        bit_offset += bits_per_pass
# One scatter step of the LSD radix sort. It moves the keys/values of portion
# ``portion_id`` into their bucket positions for the digit at ``bit_offset``
# (pass ``p``).
#
# Grid layout:
#   * axis 0: tiles of BLOCK_SIZE elements inside the portion;
#   * axis 1: groups of ``bins_segment`` bins.
#
# For each bin, a tile does the following:
#   1. Ranks its matching elements with a cumsum, which keeps input order.
#   2. Publishes its count to ``d_lookback``, tagged LOOKBACK_PARTIAL_MASK.
#   3. Walks back over earlier tiles of the same portion/pass/bin to get the
#      inclusive count of this bin up to and including this tile. It spins
#      while a slot is still 0 and stops at the first slot tagged
#      LOOKBACK_GLOBAL_MASK (a decoupled look-back scan).
#   4. Republishes that inclusive count, tagged LOOKBACK_GLOBAL_MASK.
#   5. Adds the bin's base offset from ``digit_hist`` and stores the elements.
#
# The last tile of a portion also writes the running bin offsets into the
# ``digit_hist`` row of the next portion.
#
# Slot values live in the low 30 bits; the top two bits are flags. A value of
# 0 means "not published yet", so ``d_lookback`` must be zero-filled before
# the first launch.
# NOTE(review): the spin-wait assumes that programs with smaller pid0 make
# progress while later ones wait. Confirm this scheduling guarantee on the
# target backend.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("randperm"),
    key=["n_elements"],
)
@triton.jit
def radix_sortbykey_scatter_kernel(
    key_out,
    value_out,
    key_in,
    value_in,
    digit_hist,
    d_lookback,
    n_elements,
    bit_offset,
    passes,
    p,
    num_portions,
    portion_size,
    portion_id,
    bit_mask,
    bins_segment,
    max_tiles_per_portion,
    bins: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Status flags occupy the top two bits of each d_lookback slot.
    LOOKBACK_PARTIAL_MASK = 1 << 30
    LOOKBACK_GLOBAL_MASK = 1 << 31
    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK

    pid0 = tl.program_id(0)
    # Element offsets in int64: portion_id * portion_size can exceed int32.
    portion_id_i64 = portion_id
    portion_id_i64 = portion_id_i64.to(tl.int64)
    key_offset = (
        portion_id_i64 * portion_size
        + pid0.to(tl.int64) * BLOCK_SIZE
        + tl.arange(0, BLOCK_SIZE)
    )

    key_mask = key_offset < n_elements
    value_data = tl.load(value_in + key_offset, mask=key_mask)
    key_data = tl.load(key_in + key_offset, mask=key_mask)

    ikey_data = radix_type_convert(key_data)
    key_digit = (ikey_data >> bit_offset) & bit_mask

    blk_bin_start = tl.program_id(1) * bins_segment
    last_block = tl.program_id(0) == tl.num_programs(0) - 1
    for s in range(bins_segment):
        bin_id = s + blk_bin_start
        key_digit_mask = (key_digit == bin_id) & key_mask
        key_elem_mask = tl.where(key_digit_mask, 1, 0)
        # 0-based rank of each matching lane among this tile's matches.
        key_block_rank = tl.cumsum(key_elem_mask)
        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
        bin_of_bucket = tl.sum(key_elem_mask)
        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            partial_counter,
            cache_modifier=".cg",
        )
        # Exclusive start of this bin for this portion and pass.
        bin_offset = p * (bins + 1) + bin_id
        prefix_offsets = tl.load(
            digit_hist + bin_offset + portion_id * passes * (bins + 1)
        )
        bk = pid0 - 1
        inc_sum = bin_of_bucket
        while bk >= 0:
            rd_lbk_offset = (
                (portion_id * passes + p) * max_tiles_per_portion + bk
            ) * bins + bin_id
            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            # Spin until tile ``bk`` has published something.
            while partial_prefix == 0:
                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
            if partial_prefix & LOOKBACK_GLOBAL_MASK:
                # break: a global prefix already covers all earlier tiles.
                bk = -1
            else:
                bk -= 1
        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
        tl.store(
            d_lookback
            + ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
            + bin_id,
            global_counter,
            cache_modifier=".cg",
        )
        # One past the last output slot this tile uses for this bin.
        inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
        if last_block and portion_id < num_portions - 1:
            # Seed the next portion's base offset for this bin and pass.
            tl.store(
                digit_hist + bin_offset + (portion_id + 1) * passes * (bins + 1),
                inc_bucket_offset,
            )
        global_offsets = (
            inc_bucket_offset - bin_of_bucket.to(tl.int64) + key_block_rank.to(tl.int64)
        )
        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)
# For parallelism, shuffle each whole block rather than only runs of equal
# adjacent keys (the original comment says the PyTorch GPU backend does the
# latter).
#
# Applies an in-place random permutation to each BLOCK_SIZE-element block of
# ``value_in``:
#   * Every lane draws a Philox random number. The counter is the low 32 bits
#     of philox_offset plus the lane's global index, with the high 32 bits as
#     the second counter word; the key is philox_seed.
#   * Each number is reduced modulo BLOCK_SIZE.
#   * The block is reordered by an argsort of those numbers.
# Lanes past ``n_elements`` get the uint32 maximum, so they sort last and are
# never written.
# NOTE(review): i4 is computed as tl.program_id(0) * BLOCK_SIZE in int32 and
# may overflow for n_elements >= 2**31 -- confirm.
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def duplicate_keys_shuffle_kernel(
    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
):
    pid0 = tl.program_id(0)
    offset_range = tl.arange(0, BLOCK_SIZE)
    value_offset = pid0.to(tl.int64) * BLOCK_SIZE + offset_range
    value_mask = value_offset < n_elements
    value_data = tl.load(value_in + value_offset, mask=value_mask)

    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c0 += i4
    _O = c0 * 0
    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)

    _block_size = BLOCK_SIZE
    r1 = r0 % _block_size.to(tl.uint32)
    # Out-of-range lanes get the largest uint32 so they sort to the end.
    mask_val = _get_iinfo_val(tl.uint32, True)
    r1 = tl.where(value_offset < n_elements, r1, mask_val)
    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)
    # The value at lane j moves to lane sorted_chunk_index[j] of the same block.
    store_offset = pid0.to(tl.int64) * BLOCK_SIZE + sorted_chunk_index.to(tl.int64)
    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)
def sort_by_key(key, value, valid_bits, generator=None):
    """Reorder ``value`` by the corresponding entries of ``key``.

    This is the sorting backend of ``randperm``. Two strategies are used:

    * More than 2048 elements: an LSD radix sort over the low ``valid_bits``
      bits of the keys (after ``radix_type_convert``), 4 bits per pass. It is
      followed by an in-place random shuffle of every 512-element block of the
      result, so the output is ordered by key only at block granularity.
    * 1..2048 elements: a single-program bitonic sort in ascending key order.

    Args:
        key: tensor of sort keys with as many elements as ``value``. Presumably
            1-D and contiguous, since the kernels use flat offsets -- TODO
            confirm for callers other than randperm.
        value: tensor whose elements are reordered.
        valid_bits: number of low key bits examined by the radix path. Must be
            positive on that path.
        generator: forwarded to ``philox_backend_seed_offset`` for the block
            shuffle of the radix path; unused by the bitonic path.

    Returns:
        A newly allocated tensor shaped like ``value`` holding the reordered
        values. Empty inputs return an empty tensor without launching kernels.
    """
    n_elements = key.numel()
    if n_elements == 0:
        # triton.next_power_of_2(0) is 0, which is not a valid block size for
        # the bitonic kernel; there is nothing to reorder anyway.
        return torch.empty_like(value)
    if n_elements > 2 * 1024:
        return _radix_sort_by_key(key, value, valid_bits, generator)
    return _bitonic_sort_by_key(key, value)


def _bitonic_sort_by_key(key, value):
    """Sort a small non-empty input with a single bitonic-sort program."""
    n_elements = key.numel()
    BLOCK_SIZE = triton.next_power_of_2(n_elements)
    k_out = torch.empty_like(key)
    v_out = torch.empty_like(value)
    with torch_device_fn.device(key.device):
        bitonic_sortbykey_kernel[(1,)](
            k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
        )
    return v_out


def _radix_sort_by_key(key, value, valid_bits, generator):
    """LSD radix sort of ``value`` by ``key``, then shuffle within blocks."""
    n_elements = key.numel()
    BLOCK_SIZE = 1024
    bits_per_pass = 4
    bits_per_segment = 3
    passes = triton.cdiv(valid_bits, bits_per_pass)
    bins = 2**bits_per_pass
    bins_per_segment = 2**bits_per_segment
    bit_mask = bins - 1

    # Lookback counters keep their value in 30 bits; the top 2 bits are flags
    # (see radix_sortbykey_scatter_kernel). At most 2**30 elements are
    # therefore scattered per launch, and larger inputs are split into
    # portions.
    portion_size = 2**30
    num_portions = triton.cdiv(n_elements, portion_size)
    max_portion_items = portion_size if num_portions > 1 else n_elements
    max_tiles_per_portion = triton.cdiv(max_portion_items, BLOCK_SIZE)

    # Use a 64-bit histogram once the input spans more than one portion.
    hist_dtype = torch.int64 if num_portions > 1 else torch.int32
    # Axis 0: one program per BLOCK_SIZE tile.
    # Axis 1: one program per group of bins_per_segment bins.
    grid_hist = (triton.cdiv(n_elements, BLOCK_SIZE), bins // bins_per_segment)

    digit_hist_slice = torch.empty(
        (passes, bins + 1, grid_hist[0]), dtype=hist_dtype, device=key.device
    )
    digit_hist = torch.empty(
        (num_portions, passes, bins + 1), dtype=hist_dtype, device=key.device
    )
    # Zero means "not yet published" to the scatter kernel's lookback spin.
    d_lookback = torch.zeros(
        num_portions * passes * bins * max_tiles_per_portion,
        dtype=torch.int32,
        device=key.device,
    )

    key_out_p = torch.empty_like(key)
    key_out_q = torch.empty_like(key)
    value_out_p = torch.empty_like(value)
    value_out_q = torch.empty_like(value)

    # Step 1: per-tile digit counts for every pass.
    with torch_device_fn.device(key.device):
        digit_hist_kernel[grid_hist](
            digit_hist_slice,
            key,
            n_elements,
            bits_per_pass,
            bins,
            passes,
            bit_mask,
            bins_per_segment,
            BLOCK_SIZE,
        )

    # Step 2: reduce over tiles, then take a cumulative sum over the bins.
    # Slot 0 is zero, so this yields exclusive bucket offsets of shape
    # [passes, bins + 1].
    # NOTE(review): the cumsum runs on the host, presumably because of a
    # backend limitation; confirm before moving it back to the device.
    digit_hist_slice = torch.sum(digit_hist_slice, dim=2, keepdim=False)
    digit_hist_slice = digit_hist_slice.cpu().cumsum(dim=1).to(key.device)
    # Broadcast to every portion. Rows of later portions are overwritten by
    # the scatter launch of the preceding portion.
    digit_hist.copy_(digit_hist_slice)

    # Step 3: one order-preserving scatter per pass, ping-ponging buffers.
    bit_offset = 0
    for p in range(passes):
        if p == 0:
            k_in, v_in = key, value
        elif p % 2 == 0:
            k_in, v_in = key_out_p, value_out_p
        else:
            k_in, v_in = key_out_q, value_out_q
        if p % 2 == 0:
            k_out, v_out = key_out_q, value_out_q
        else:
            k_out, v_out = key_out_p, value_out_p
        for portion_id in range(num_portions):
            portion_items = min(n_elements - portion_id * portion_size, portion_size)
            tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE)
            grid_scatter = (tiles_per_portion, grid_hist[1])
            with torch_device_fn.device(key.device):
                radix_sortbykey_scatter_kernel[grid_scatter](
                    k_out,
                    v_out,
                    k_in,
                    v_in,
                    digit_hist,
                    d_lookback,
                    n_elements,
                    bit_offset,
                    passes,
                    p,
                    num_portions,
                    portion_size,
                    portion_id,
                    bit_mask,
                    bins_per_segment,
                    max_tiles_per_portion,
                    bins,
                    BLOCK_SIZE,
                )
        bit_offset += bits_per_pass

    # Step 4: shuffle each 512-element block of the final output in place.
    BLOCK_SIZE_SHUFFLE = 512
    grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
    philox_seed, philox_offset = philox_backend_seed_offset(
        n_elements, generator=generator
    )
    with torch_device_fn.device(key.device):
        duplicate_keys_shuffle_kernel[grid_shuffle](
            v_out,
            n_elements,
            philox_seed,
            philox_offset,
            BLOCK_SIZE_SHUFFLE,
            num_warps=4,
        )
    return v_out
def randperm(
    n,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Return a random permutation of the integers ``0 .. n - 1``.

    Replacement for ``torch.randperm``. Each element of ``arange(n)`` is
    paired with a random key, and the values are reordered by key using
    ``sort_by_key``.

    Args:
        n: number of elements.
        generator: used both to draw the random keys and, on the radix path
            of ``sort_by_key``, for its per-block shuffle.
        out: optional destination. When given, it is resized to ``(n,)``,
            filled with the permutation and returned.
        dtype: result dtype; only int16, int32 and int64 are accepted.
        device: target device; defaults to the backend device.
        layout, requires_grad, pin_memory: accepted for signature
            compatibility with ``torch.randperm``; not used here.

    Returns:
        1-D tensor of length ``n`` (``out`` itself when it is provided).
    """
    logger.debug("GEMS RANDPERM")
    assert dtype == torch.int16 or dtype == torch.int32 or dtype == torch.int64
    assert n <= _MAX_INT64_VAL, "n exceeds maximum int64"

    if device is None:
        device = torch.device(device_.name)
    in_range = torch.arange(n, dtype=dtype, device=device)

    # The key width grows with n. Wider keys mean fewer duplicate keys, but
    # also more radix passes in sort_by_key (4 bits per pass).
    if n <= 2**8:
        valid_bits = 8
        key_dtype = torch.int8
        keymin, keymax = _MIN_INT8_VAL, _MAX_INT8_VAL
    elif n <= 2**16:
        valid_bits = 16
        key_dtype = torch.int16
        keymin, keymax = _MIN_INT16_VAL, _MAX_INT16_VAL
    elif n <= 2**24:
        # 24 random bits stored in int32.
        valid_bits = 24
        key_dtype = torch.int32
        keymin, keymax = _MIN_INT24_VAL, _MAX_INT24_VAL
    elif n <= 2**32:
        valid_bits = 32
        key_dtype = torch.int32
        keymin, keymax = _MIN_INT32_VAL, _MAX_INT32_VAL
    else:
        valid_bits = 64
        key_dtype = torch.int64
        keymin, keymax = _MIN_INT64_VAL, _MAX_INT64_VAL

    # ``high`` is exclusive, so no key equals the dtype maximum that
    # bitonic_sortbykey_kernel uses as padding. Padding therefore always sorts
    # after every real key. The generator is passed here too, so a seeded
    # generator reproduces the whole permutation.
    rand_key = torch.randint(
        low=keymin,
        high=keymax,
        size=[n],
        dtype=key_dtype,
        device=device,
        generator=generator,
    )
    perm_range = sort_by_key(rand_key, in_range, valid_bits, generator=generator)
    if out is not None:
        # Honour ``out`` as torch.randperm does.
        out.resize_(perm_range.shape)
        out.copy_(perm_range)
        return out
    return perm_range