Coverage for src/flag_gems/runtime/backend/_mthreads/ops/sort.py: 0%
211 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
# Module logger, named "flag_gems.runtime.backend._mthreads.ops.<leaf module name>".
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)
def unwrap_if_constexpr(o):
    """Return the value held by a ``tl.constexpr`` wrapper; return anything else unchanged."""
    if not isinstance(o, tl.constexpr):
        return o
    return o.value
@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Return the Triton integer dtype with the given bit width and signedness.

    Both arguments may be plain Python values or ``tl.constexpr`` wrappers.
    """
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)
@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern with only the top bit set (``100...0``)."""
    top_bit_index = unwrap_if_constexpr(num_bits) - 1
    return 1 << top_bit_index
@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern with every bit except the top one set (``011...1``)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = 1 << (width - 1)
    return top_bit - 1
@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Map an unsigned integer tensor to an order-preserving unsigned key.

    Ascending order keeps ``x`` unchanged. Descending order flips every bit, so
    larger inputs produce smaller keys.
    """
    if descending:
        key = ~x
    else:
        key = x
    return key
@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Map a signed integer tensor to unsigned keys that sort in the same order.

    The bits of ``x`` are reinterpreted (bitcast) as the unsigned integer type of
    the same width and then XOR-ed with a constant mask:
      * ascending: mask ``100...0`` flips the sign bit, so negative values map
        below non-negative ones and unsigned order equals signed order;
      * descending: mask ``011...1`` produces the bitwise complement of the
        ascending key, which reverses the order.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out
@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Map a floating-point tensor to unsigned keys that sort in the same order.

    Works on the raw bit pattern of ``x`` (bitcast to same-width integers):
      * sign bit 0 (non-negative): only the sign bit is flipped on, moving these
        values above every negative value;
      * sign bit 1 (negative): all bits are flipped, so larger magnitudes give
        smaller keys.
    For ``descending`` the mask is complemented, which complements the ascending
    key and therefore reverses the order. No special casing is done for NaN;
    it is ordered by its bit pattern like any other value.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # Shifting the sign bit all the way down replicates it: all-ones for
    # negative inputs, zero otherwise; OR-ing in the sign bit gives the mask below.
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for non-negative (sign bit 0)
    # 1111111111...1 for negative (sign bit 1)
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)
@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Convert ``x`` to unsigned keys whose unsigned order matches the order of ``x``.

    Dispatches on the dtype of ``x`` to ``floating_to_uint``, ``int_to_uint`` or
    ``uint_to_uint``. With ``descending=True`` the key order is reversed, so an
    ascending sort of the keys gives a descending sort of ``x``.
    """
    # NOTE(review): no branch handles any other dtype, so ``out`` would be
    # unbound; confirm that callers only pass floating/integer tensors.
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    return out
@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
):
    """Accumulate per-pass digit histograms for every row of ``arr``.

    Program ``pid`` handles row ``pid % m`` and the ``pid // m``-th chunk of
    ``TILE_N * tiles_n_per_cta`` columns. For every radix pass it counts how many
    keys in its chunk have each digit value. It then atomically adds those counts
    into ``out_ptr``.

    arr_ptr: (m, n) row-major input.
    out_ptr: (m, num_passes, r), r = 2 ** num_bits_per_pass. Counts are
        accumulated with atomic adds, so the buffer must start zeroed.
    """
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    # tiles_n_per_cta is a runtime argument, so this product is a runtime value
    # and must not be annotated as tl.constexpr.
    cta_tile_n = TILE_N * tiles_n_per_cta
    cta_n_start = cta_tile_n * pid_n
    cta_n_end = tl.minimum(cta_n_start + cta_tile_n, n)

    for p in range(0, num_passes):  # independent iterations
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # independent iterations
            bin_indices = r_start + tl.arange(0, TILE_R)
            # When r < TILE_R (e.g. 1-bit passes for bool), the tail lanes are
            # not real bins; writing them would touch the next pass/row or run
            # past the end of out_ptr.
            bin_mask = bin_indices < r
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + tl.arange(0, TILE_N)  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                matches = tl.where(
                    mask, (bin_indices[:, None] == key), False
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                mask=bin_mask,
                sem="relaxed",
            )
@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
):
    """Run one radix pass: scatter each element to its position for this digit.

    Each program handles one row (``pid_m``), one ``TILE_N`` tile of columns
    (``pid_n``) and up to ``TILE_R`` digit values (``pid_r``). For each digit value
    it does three things:
      * counts the matching elements in its tile;
      * finds how many matching elements precede the tile by looking back over
        the status words of earlier tiles;
      * stores each matching key, and its associated value if one is given, at
        ``bin_start + rank_within_bin``.

    Status word layout (int32 storage): bit 30 means the tile's own count is
    published, bit 31 means an inclusive prefix (count of all tiles up to and
    including this one) is published, and bits 0..29 hold the count. A zero word
    means nothing has been published yet.
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: number of column tiles per row (radix_sort passes cdiv(N, TILE_N))

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # status_ptr: (m, r, OUT_N), must be zero-filled before each launch

    # grid: (m * OUT_N, cdiv(r, TILE_R)); axis 0 encodes pid_n * m + pid_m

    # load data
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # range of digit values (bins) handled by this program
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # load this tile's keys and extract the digit for this pass
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )
    if associate_arr_ptr is not None:
        associate_arr = tl.load(associate_arr_ptr + pid_m * N + n_offsets, mask=mask)

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(tl.uint32), axis=0)
        # publish this tile's own count (aggregate flag) for later tiles
        pack0 = aggregate_mask | local_sum
        status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback: walk earlier tiles of the same row and bin, summing
        # their counts, until one has published an inclusive prefix or tile 0 is
        # passed.
        # NOTE(review): the spin-wait assumes programs for earlier tiles (lower
        # program ids) make progress; confirm on the target hardware.
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = 0
            # spin until tile i_lookback has published something (0 = not yet)
            while pack1 == 0:
                # pack1 = tl.load(status_ptr + flag_offset_i, volatile=True) # uin32
                # CAS(expected=0, new=0) never changes the word; it serves as an
                # acquire load of the current value.
                pack1 = tl.atomic_cas(status_ptr + flag_offset_i, 0, 0, sem="acquire")
            exclusive_prefix += pack1 & v_mask
            if (pack1 & aggregate_mask) == aggregate_mask:
                # only that tile's own count: keep looking further back
                i_lookback -= 1
            else:
                # inclusive prefix already covers all earlier tiles: stop
                i_lookback = -1
        # publish the inclusive prefix (all tiles up to and including this one)
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        # exclusive rank of each matching element among this tile's matches
        local_ex_cumsum = (
            tl.cumsum(matches.to(tl.uint32), axis=0) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r): start offset of this bin in the output
        # row (exclusive cumsum of the global histogram)
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)
def radix_sort(arr, k_bits=8, descending=False):
    """Stable LSD radix sort of ``arr`` along its last dimension.

    Each slice along the last dimension is sorted independently. Passes proceed
    from the least significant ``k_bits`` digit upward.

    Args:
        arr: tensor to sort; made contiguous because the kernels index rows as
            ``row * n + col``.
        k_bits: digit width per pass (``2 ** k_bits`` bins per pass).
        descending: sort in descending order if True.

    Returns:
        ``(values, indices)``; ``indices`` are int64 positions along the last
        dimension of the input.

    Raises:
        ValueError: if the last dimension has ``2**30`` or more elements. The
            status words in ``sweep`` reserve only 30 bits for counts.
    """
    if arr.numel() == 0:
        # Nothing to sort. Returning early also avoids ``numel // n`` with n == 0.
        return arr.clone(), torch.empty(
            arr.shape, dtype=torch.int64, device=arr.device
        )

    n = arr.shape[-1]
    if n >= (1 << 30):
        # Raise instead of assert so the check survives ``python -O``.
        raise ValueError("we have not implemented 2**30 per launch")
    arr = arr.contiguous()  # no-op for already-contiguous input
    m = arr.numel() // n

    num_bits = 1 if arr.dtype == torch.bool else (arr.itemsize * 8)
    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)

    # Phase 1 (histogram) tiling: each program covers hist_tiles_per_cta tiles of
    # hist_tile_n columns in one row.
    hist_tile_n = 1024
    hist_tiles_per_cta = 8
    hist_tile_r = 16
    hist_grid_n = triton.cdiv(n, hist_tiles_per_cta * hist_tile_n)
    grid_for_global_hist = (m * hist_grid_n, 1, 1)

    # Phase 2 (scatter) tiling: one program per (row, column tile, bin group).
    sweep_tile_n = 2048
    sweep_tile_r = 8
    sweep_grid_n = triton.cdiv(n, sweep_tile_n)
    sweep_grid_r = triton.cdiv(num_bins, sweep_tile_r)
    grid_for_sweep = (m * sweep_grid_n, sweep_grid_r)

    with torch_device_fn.device(arr.device):
        # Per-row, per-pass digit histograms; zeroed because the kernel adds atomically.
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            hist_tiles_per_cta,
            hist_tile_n,
            hist_tile_r,
            k_bits,
            descending,
        )
        # Exclusive prefix sum over bins = start offset of each bin in the output.
        ex_cumsum_bins = (torch.cumsum(global_hist, -1) - global_hist).to(torch.int32)

        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        # Look-back status words for sweep; re-zeroed before every pass.
        status = torch.empty(
            (m, num_bins, sweep_grid_n), device=arr.device, dtype=torch.int32
        )

        for i in range(0, n_passes):
            status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                i * k_bits,
                m,
                n,
                sweep_grid_n,
                sweep_tile_n,
                sweep_tile_r,
                k_bits,
                descending,
            )
            # Ping-pong buffers: this pass's output feeds the next pass.
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in
@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one length-``N`` row per program in a single block using ``argsort``.

    Program ``i`` reads elements ``[i * N, i * N + N)`` of ``in_ptr``. This
    presumably assumes contiguous rows; TODO confirm at call sites, since none
    is visible in this file. It writes the sorted values to ``out_ptr`` and
    their original positions within the row (0..N-1) to ``out_index_ptr``.
    ``BLOCK_SIZE`` should be >= ``N``; lanes with index >= ``BLOCK_SIZE`` are
    never loaded. Lanes past ``N`` are filled with a padding value so they do
    not mix into the stored results.
    """
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    # Padding for lanes past N. return_max is True for ascending and False for
    # descending; presumably this is the dtype's max/min so padding sorts last
    # (per the _get_*info_val helpers in flag_gems.ops.topk; verify there).
    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)

    # positions within the row, permuted alongside the values by argsort
    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)
def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return ``(values, indices)``.

    Delegates to ``sort_stable``. Only a stable radix sort is implemented here,
    so the result is stable even though stability is not requested.
    """
    logger.debug("GEMS_MTHREADS SORT")
    values, indices = sort_stable(inp, descending=descending, dim=dim, stable=False)
    return values, indices
def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` with a stable LSD radix sort.

    Args:
        inp: tensor to sort.
        stable: accepted for API compatibility and ignored, because the radix
            sort used here is always stable.
        dim: dimension to sort along; negative values count from the end.
        descending: sort in descending order if True.

    Returns:
        ``(values, indices)``; ``indices`` are int64 positions along ``dim``.
    """
    logger.debug("GEMS_MTHREADS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    # Trivial cases: a 0-dim tensor (no dimension to index or move), an empty
    # tensor (radix_sort would divide by a zero-length dim), or a single element
    # along dim. Return a copy so the result never aliases the input.
    if inp.ndim == 0 or inp.numel() == 0 or inp.shape[dim] == 1:
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim += inp.ndim
    last_dim = inp.ndim - 1
    # radix_sort works on the last dimension of a contiguous tensor.
    if dim != last_dim:
        inp = torch.movedim(inp, dim, -1)
    inp = inp.contiguous()

    num_bits_per_pass = 1 if inp.dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    if dim != last_dim:
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index