Coverage for src/flag_gems/runtime/backend/_arm/ops/sort.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from .topk import _get_finfo_val, _get_iinfo_val, argsort
# @libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one row of ``N`` elements per program and record the permutation.

    Program ``pid`` reads elements ``[pid * N, pid * N + N)`` from ``in_ptr``,
    sorts them with ``argsort`` and writes the sorted values to ``out_ptr`` and
    the matching source positions (``0 .. N-1``) to ``out_index_ptr`` at the
    same offsets.

    Args:
        in_ptr: Pointer to the input values.
        out_ptr: Pointer receiving the sorted values.
        out_index_ptr: Pointer receiving the sorted indices.
        N: Number of elements in each row (compile-time constant).
        BLOCK_SIZE: Number of lanes handled by one program; lanes with
            ``cols >= N`` are padding and are never stored.
        DESCENDING: Forwarded to ``argsort`` as ``descending``.
        IS_FLOAT: Selects the floating-point or the integer load path.
    """
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Compute the row base in 64-bit arithmetic: a 32-bit `pid * N` wraps
    # around once the flattened input holds more than 2**31 - 1 elements.
    row_start = tl.program_id(0).to(tl.int64) * N
    offset = row_start + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    if IS_FLOAT:
        # Padding lanes are filled with a sentinel from `_get_finfo_val`;
        # presumably the dtype's extreme value chosen so that padding sorts
        # behind every real element -- TODO confirm against topk.py.
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
        # The condition is a compile-time bool: fp64 input is kept as is,
        # every other float type is widened to fp32 before sorting.
        in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32))
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        # NOTE(review): every integer input is converted to int32 here, so
        # int64 values outside the int32 range would not survive -- confirm
        # which integer widths `argsort` supports.
        in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32)
    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    # Only the first N lanes are written back; padding lanes are masked out.
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)
def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return ``(values, indices)``.

    The sort dimension is moved to the last axis, the tensor is made
    contiguous, and ``sort_kernel`` is launched with one program per row.
    Both results are then moved back so they have the shape of ``inp``.

    Args:
        inp: Tensor to sort.
        dim: Dimension to sort along; negative values count from the end.
        descending: Passed to the kernel as ``DESCENDING``.

    Returns:
        A tuple ``(values, indices)``: ``values`` has the dtype of ``inp`` and
        ``indices`` is ``torch.int64``. Neither output shares storage with
        ``inp``.

    Raises:
        IndexError: If ``dim`` is out of range for a tensor with ``ndim >= 1``.
    """
    logging.debug("GEMS SORT")

    # A 0-dim tensor has no axis to index; it is trivially sorted.
    if inp.ndim == 0:
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    sort_elem_cnt = inp.shape[dim]

    # Nothing to sort. This also avoids dividing by a zero-length sort
    # dimension and launching the kernel on an empty grid.
    if inp.numel() == 0:
        return torch.empty_like(inp), torch.empty_like(inp, dtype=torch.int64)

    if sort_elem_cnt == 1:
        # Return a copy so that the result never aliases the caller's tensor.
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    block_size = triton.next_power_of_2(sort_elem_cnt)

    if dim < 0:
        dim = dim + inp.ndim
    needs_move = dim != inp.ndim - 1
    if needs_move:
        # The kernel expects each row to be sorted to be contiguous in memory.
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()
    batch_size = math.prod(inp.shape) // sort_elem_cnt

    out = torch.empty_like(inp)
    out_index = torch.empty_like(inp, dtype=torch.int64)

    # with torch_device_fn.device(inp.device):
    sort_kernel[(batch_size,)](
        inp,
        out,
        out_index,
        N=sort_elem_cnt,
        BLOCK_SIZE=block_size,
        DESCENDING=descending,
        IS_FLOAT=inp.is_floating_point(),
        num_warps=4,
    )

    if needs_move:
        # Restore the caller's dimension order.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index