Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sort.py: 0%
241 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems
8from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def next_power_of_2(n: int) -> int:
    """Round ``n`` up to a power of two using bit smearing.

    After subtracting one, the highest set bit is propagated into every
    lower position (shifts of 1, 2, 4, ... 32 cover values up to 64 bits),
    and adding one back yields the power of two. An exact power of two is
    returned unchanged, and ``0`` maps to ``0``.
    """
    n -= 1
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift
    return n + 1
def unwrap_if_constexpr(o):
    """Return ``o.value`` when ``o`` is a ``tl.constexpr``; otherwise return ``o`` itself."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Look up the Triton integer dtype with the given bit width and signedness.

    Both arguments may arrive either as plain Python values or wrapped in
    ``tl.constexpr``; they are unwrapped before the lookup.
    """
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``100...0`` (only the top bit set)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit_position = width - 1
    return 1 << top_bit_position
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``011...1`` (every bit but the top one set)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = 1 << (width - 1)
    return top_bit - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    # Unsigned keys already compare in ascending order as-is; for a
    # descending sort every bit is inverted so that the order is reversed.
    if descending:
        out = ~x
    else:
        out = x
    return out
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    # Reinterpret a signed integer tensor as an unsigned key of the same width.
    #   ascending : XOR with 100...0 (flip only the sign bit)
    #   descending: XOR with 011...1 (flip every bit except the sign bit)
    bits: tl.constexpr = x.dtype.primitive_bitwidth
    unsigned_t = get_int_t(bits, False)
    raw = tl.cast(x, unsigned_t, bitcast=True)
    if descending:
        # 0111111....1
        low_bits: tl.constexpr = zero_ones(bits)
        flip = tl.full((), value=low_bits, dtype=unsigned_t)
    else:
        # 1000000...0
        top_bit: tl.constexpr = one_zeros(bits)
        flip = tl.full((), value=top_bit, dtype=unsigned_t)
    out = raw ^ flip
    return out
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    # Reinterpret a floating-point tensor as an unsigned integer key of the
    # same bit width.
    #   ascending : inputs with the sign bit clear get only the sign bit
    #               flipped; inputs with the sign bit set get every bit flipped.
    #   descending: the complementary mask is applied instead.
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    # Two bit-level views of the same data: signed (for the arithmetic shift
    # below) and unsigned (for the final key).
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # The arithmetic shift replicates the sign bit across the whole word.
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    # Convert `x` to an unsigned integer key by dispatching on its element
    # type: floating point, signed integer, or unsigned integer. `descending`
    # is forwarded to the selected converter.
    # NOTE(review): there is no fallback branch, so `out` is never assigned
    # for an element type that matches none of the three checks.
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    return out
@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
    USE_UINT16: tl.constexpr,
):
    # Count, for every row and every radix pass, how many keys fall into each
    # digit bin, and accumulate the counts into `out_ptr` with atomic adds.
    #
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    #
    # Program id 0 packs two coordinates: pid % m selects the row and
    # pid // m selects the chunk of TILE_N * tiles_n_per_cta columns.
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    dsize = cta_n_start + CTA_TILE_N
    # Clamp the end of this program's column range to the row length.
    cta_n_end = tl.where(dsize < n, dsize, n)

    arr_partial_ptr = arr_ptr + pid_m * n
    range_tile_r = tl.arange(0, TILE_R)
    range_tile_n = tl.arange(0, TILE_N)
    # Per-element match counters use a narrower type when USE_UINT16 is set.
    acc_type = tl.int32
    if tl.constexpr(USE_UINT16):
        acc_type = tl.uint16
    for p in range(0, num_passes):  # parallel
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # parallel
            # NOTE(review): bin_indices is not masked against r; if TILE_R > r
            # the atomic_add below also touches slots past this (row, pass)
            # histogram - confirm callers keep TILE_R <= r.
            bin_indices = r_start + range_tile_r
            acc = tl.zeros((TILE_R, TILE_N), dtype=acc_type)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + range_tile_n  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_partial_ptr + n_offsets, mask=mask)
                arr = convert_to_uint_preverse_order(arr, descending)
                # Extract the digit of the current pass.
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                # One-hot comparison of every bin against every key; lanes
                # outside the valid column range contribute zero.
                matches = tl.where(
                    mask, (bin_indices[:, None] == key).to(acc_type), 0
                )  # (TILE_R, TILE_N)
                acc += matches
            # Reduce over the column axis and widen to int32 before the
            # atomic accumulation.
            local_sum = tl.sum(acc, axis=1).to(tl.int32)
            tl.atomic_add(
                out_ptr + (pid_m * num_passes + p) * r + bin_indices,
                local_sum,
                sem="relaxed",
            )
@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
    USE_UINT16: tl.constexpr,
):
    # One radix pass: scatter every key (and, when associate_arr_ptr is given,
    # its associated value) to its destination slot for the digit that starts
    # at `bit_offset`.
    #
    # r: num_bins = 2 ** k_bits
    # OUT_N: number of column tiles per row (used as the innermost extent of
    #        status_ptr)

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # status_ptr: (m, r, OUT_N)

    # grid: axis 0 packs (pid_n, pid_m) as pid = pid_n * m + pid_m,
    #       axis 1 selects a group of TILE_R bins

    # load data
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    # A status word packs a flag in the two top bits and a count in the low
    # 30 bits: bit 30 = tile-local count published, bit 31 = inclusive prefix
    # published.
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # a status word of zero means the tile's local sum is not ready yet
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # cumsum for a bin_index
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )

    # Element type used for the in-tile sum/cumsum of the match flags.
    dt = tl.uint32
    if tl.constexpr(USE_UINT16):
        dt = tl.uint16

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(dtype=dt), axis=0).to(tl.uint32)
        # Publish this tile's local count first so that later tiles can make
        # progress while this one is still looking back.
        pack0 = aggregate_mask | local_sum
        status_offset = (pid_m * r + bin_index) * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # Spin until the predecessor tile has published something.
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            # If only the local count was published, step to the previous
            # tile; otherwise (an inclusive prefix was read) the expression
            # evaluates to -1 and the loop ends.
            i_lookback = ((pack1 & aggregate_mask) == aggregate_mask) * i_lookback - 1
        # Publish the inclusive prefix so that later tiles can stop here.
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        # Exclusive cumsum of the match flags inside this tile.
        local_ex_cumsum = (
            tl.cumsum(matches.to(dt), axis=0).to(tl.uint32) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r)
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + (pid_m * n_passes + pass_id) * r + bin_index
        )  # scalar
        # Destination column = start of the bin + rank of the element in the bin.
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            associate_arr = tl.load(
                associate_arr_ptr + pid_m * N + n_offsets, mask=mask
            )
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)
def radix_sort(arr, k_bits=8, descending=False):
    """Sort ``arr`` along its last dimension with an LSD radix sort.

    The input is treated as ``m`` rows of ``n`` elements. A histogram kernel
    counts the digit occurrences of every pass up front; then one ``sweep``
    launch per pass scatters keys and their indices between two ping-pong
    buffers.

    Args:
        arr: tensor to sort; the kernels index it as ``m`` consecutive rows
            of length ``n`` (assumes a contiguous layout - confirm at callers).
        k_bits: number of key bits consumed per pass.
        descending: sort from largest to smallest when true.

    Returns:
        ``(values, indices)``: the sorted values (same dtype as ``arr``) and
        the int32 positions of those values in the original rows. ``arr``
        itself is left unmodified.

    Raises:
        AssertionError: if the last dimension has ``2**30`` or more elements.
    """
    n = arr.shape[-1]
    if arr.numel() == 0:
        # Nothing to sort. Returning here also avoids the division by
        # n == 0 below and a kernel launch over an empty grid.
        return (
            torch.empty_like(arr),
            torch.empty(arr.shape, dtype=torch.int32, device=arr.device),
        )
    m = arr.numel() // n
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    num_bins = 1 << k_bits
    n_passes = triton.cdiv(num_bits, k_bits)

    # ---- launch configuration of the histogram kernel ----
    hist_tile_n = 512 if (m > 2048 and n < 512) else 128
    tiles_n_per_cta = 8
    cta_tile_n = tiles_n_per_cta * hist_tile_n
    # The histogram kernel does not mask its bin indices, so the bin tile
    # must never exceed the number of bins (e.g. k_bits == 1 gives 2 bins);
    # otherwise its atomic adds reach past the histogram buffer.
    hist_tile_r = min(4, num_bins)
    hist_grid_n = triton.cdiv(n, cta_tile_n)
    grid_for_global_hist = (m * hist_grid_n, 1, 1)
    hist_use_uint16 = False

    with torch_device_fn.device(arr.device):
        global_hist = torch.zeros(
            (m, n_passes, num_bins), dtype=torch.int32, device=arr.device
        )
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            hist_tile_n,
            hist_tile_r,
            k_bits,
            descending,
            hist_use_uint16,
        )
        # Exclusive prefix sum over the bins = inclusive cumsum - counts.
        ex_cumsum_bins = flag_gems.sub(
            flag_gems.cumsum(global_hist, -1), global_hist
        )  # [DIPU] note from the original author: cumsum result is wrong
        ex_cumsum_bins = ex_cumsum_bins.to(
            torch.int32
        )  # .to(torch.uint32) -> [DIPU] torch 2.2.2 does not support uint32

        # ---- sort ----
        # Work on a private copy: the ping-pong loop below writes into both
        # buffers from the second pass on, which would otherwise overwrite
        # the caller's tensor.
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int32, device=arr.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        # Same dtype as indices_in: the two buffers swap roles every pass, so
        # a float scratch buffer would round-trip indices through float32.
        indices_out = torch.empty(
            indices_in.shape, dtype=indices_in.dtype, device=indices_in.device
        ).as_strided(indices_in.shape, indices_in.stride())

        # ---- launch configuration of the sweep kernel ----
        sweep_tile_r = 2 if n > 2048 else num_bins
        grid_r = triton.cdiv(num_bins, sweep_tile_r)
        if n > 32768:
            sweep_tile_n = 2048
        elif m > 2048 and n <= 128:
            sweep_tile_n = 128
        elif m < 32 and n > 2048:
            sweep_tile_n = 256
        else:
            sweep_tile_n = 2048
        grid_n = triton.cdiv(n, sweep_tile_n)
        grid_for_sweep = (m * grid_n, grid_r)

        sweep_use_uint16 = n <= 4096

        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.int32
        )  # .to(torch.uint32) -> [DIPU] torch 2.2.2 does not support uint32

        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # The sweep kernel treats a zero status word as "not ready".
            status = status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                sweep_tile_n,
                sweep_tile_r,
                k_bits,
                descending,
                sweep_use_uint16,
            )
            # The output of this pass is the input of the next one.
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # Sort one row of N elements per program: the row is loaded into a
    # BLOCK_SIZE-wide tile, handed to `argsort` together with the lane
    # indices, and the first N sorted values/indices are written back.
    lane = tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < N
    row_offset = tl.program_id(0) * N + lane
    in_ptr += row_offset
    out_ptr += row_offset
    out_index_ptr += row_offset

    # Fill value for lanes past N: the dtype's max for an ascending sort,
    # its min for a descending one.
    if IS_FLOAT:
        pad = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
    else:
        pad = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
    values = tl.load(in_ptr, mask=in_bounds, other=pad)

    positions = tl.arange(0, BLOCK_SIZE)

    sorted_values, sorted_positions = argsort(
        values, positions, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_values, mask=in_bounds)
    tl.store(out_index_ptr, sorted_positions, mask=in_bounds)
def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return ``(values, indices)``.

    Only a stable radix sort is implemented, so this simply forwards to
    ``sort_stable`` with ``stable=False``.
    """
    logger.debug("GEMS SORT")
    result = sort_stable(inp, stable=False, dim=dim, descending=descending)
    return result
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` with the radix sort and return fresh tensors.

    Args:
        inp: tensor to sort.
        stable: accepted for interface compatibility and ignored; only a
            stable radix sort is implemented.
        dim: dimension to sort along; negative values count from the end.
        descending: sort from largest to smallest when true.

    Returns:
        ``(values, indices)`` where ``indices`` is an int64 tensor of the
        positions along ``dim``. Both are newly allocated; neither aliases
        ``inp``.
    """
    device = inp.device
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable

    if inp.ndim == 0:
        # A scalar has no dimension to index; it is trivially sorted.
        return inp.clone(), torch.zeros((), dtype=torch.int64, device=device)

    sort_elem_cnt = inp.shape[dim]
    if sort_elem_cnt <= 1 or inp.numel() == 0:
        # Zero or one element per slice (or no data at all): the input is
        # already sorted and every index is 0. Returning here also keeps
        # empty shapes away from radix_sort, which divides by the length of
        # the sorted dimension. The values are cloned so that this path,
        # like the main one, never hands back the caller's own tensor.
        return inp.clone(), torch.zeros(inp.shape, dtype=torch.int64, device=device)

    if dim < 0:
        dim = dim + inp.ndim
    if dim != inp.ndim - 1:
        # NOTE(review): the transpose is routed through host memory;
        # presumably a backend workaround - confirm before simplifying.
        inp = torch.movedim(inp.cpu(), dim, -1).contiguous().to(device=device)
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    num_bits_per_pass = 1 if dtype == torch.bool else 2
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    if dim != inp.ndim - 1:
        # Move the sorted dimension back to where the caller had it.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)

    # [sunrise fix] Reported issue: returning `out, out_index.to(torch.int64)`
    # directly apparently handed out internal memory that was overwritten
    # after the return, losing the data. Clone both results.
    return out.clone(), out_index.to(torch.int64).clone()