Coverage for src/flag_gems/ops/argsort.py: 57%

7 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3from flag_gems.ops.sort import sort_stable 

4 

# Module-level logger keyed by this module's import path; argsort emits a
# debug trace through it on every call.
logger = logging.getLogger(name=__name__)

6 

7 

def argsort(inp, dim=-1, descending=False):
    """Return the indices that would order ``inp`` along ``dim``.

    Behaves like calling ``torch.sort`` and keeping only the indices. The
    work is delegated to ``sort_stable`` with ``stable=True``. The sorted
    values it also produces are discarded.

    Args:
        inp: Tensor to be ordered (presumably a torch tensor; TODO confirm
            which dtypes ``sort_stable`` accepts).
        dim: Dimension to sort along. The default ``-1`` selects the last one.
        descending: When True, request largest-to-smallest ordering.

    Returns:
        The second element (the indices) of the pair returned by
        ``sort_stable``.
    """
    logger.debug("GEMS ARGSORT")
    # Only the permutation is needed; unpacking also enforces that
    # sort_stable hands back exactly a (values, indices) pair.
    _sorted_values, order = sort_stable(
        inp,
        stable=True,
        dim=dim,
        descending=descending,
    )
    return order