Coverage for src/flag_gems/runtime/backend/_arm/ops/sort.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from .topk import _get_finfo_val, _get_iinfo_val, argsort 

9 

10 

11# @libentry() 

12@triton.jit() 

13def sort_kernel( 

14 in_ptr, 

15 out_ptr, 

16 out_index_ptr, 

17 N: tl.constexpr, 

18 BLOCK_SIZE: tl.constexpr, 

19 DESCENDING: tl.constexpr, 

20 IS_FLOAT: tl.constexpr, 

21): 

22 cols = tl.arange(0, BLOCK_SIZE) 

23 mask = cols < N 

24 offset = tl.program_id(0) * N + cols 

25 in_ptr += offset 

26 out_ptr += offset 

27 out_index_ptr += offset 

28 

29 if IS_FLOAT: 

30 mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) 

31 in_val = tl.load(in_ptr, mask=mask, other=mask_val) 

32 in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32)) 

33 else: 

34 mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) 

35 in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32) 

36 index_val = tl.arange(0, BLOCK_SIZE) 

37 

38 sorted_in_val, sorted_index_val = argsort( 

39 in_val, index_val, 0, descending=DESCENDING 

40 ) 

41 tl.store(out_ptr, sorted_in_val, mask=mask) 

42 tl.store(out_index_ptr, sorted_index_val, mask=mask) 

43 

44 

45def sort(inp, dim=-1, descending=False): 

46 logging.debug("GEMS SORT") 

47 sort_elem_cnt = inp.shape[dim] 

48 if sort_elem_cnt == 1: 

49 return inp, torch.zeros_like(inp, dtype=torch.int64) 

50 block_size = triton.next_power_of_2(sort_elem_cnt) 

51 

52 if dim < 0: 

53 dim = dim + inp.ndim 

54 if dim != inp.ndim - 1: 

55 inp = torch.movedim(inp, dim, -1).contiguous() 

56 else: 

57 inp = inp.contiguous() 

58 batch_size = math.prod(inp.shape) // sort_elem_cnt 

59 

60 out = torch.empty_like(inp) 

61 out_index = torch.empty_like(inp, dtype=torch.int64) 

62 

63 # with torch_device_fn.device(inp.device): 

64 sort_kernel[batch_size,]( 

65 inp, 

66 out, 

67 out_index, 

68 N=sort_elem_cnt, 

69 BLOCK_SIZE=block_size, 

70 DESCENDING=descending, 

71 IS_FLOAT=inp.is_floating_point(), 

72 num_warps=4, 

73 ) 

74 

75 if dim != inp.ndim - 1: 

76 out = torch.movedim(out, -1, dim) 

77 out_index = torch.movedim(out_index, -1, dim) 

78 return out, out_index