Coverage for src/flag_gems/ops/argsort.py: 57%
7 statements
Report generated by coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3from flag_gems.ops.sort import sort_stable
# Module-scoped logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
def argsort(inp, dim=-1, descending=False, stable=False):
    """Return the indices that sort ``inp`` along dimension ``dim``.

    Follows the ``torch.argsort`` signature. Only the index tensor from the
    underlying sort is returned; the sorted values are discarded.

    Args:
        inp: Tensor to sort.
        dim: Dimension to sort along. Defaults to -1, the last dimension.
        descending: If True, sort from largest to smallest.
        stable: Accepted for compatibility with ``torch.argsort``. This
            implementation always requests a stable sort from ``sort_stable``.
            That is also a valid result when ``stable=False``, where the
            relative order of equal elements is unspecified.

    Returns:
        The indices tensor returned by ``sort_stable``.
    """
    logger.debug("GEMS ARGSORT")
    # ``stable`` does not change the call below. A stable sort satisfies both
    # the stable and the unspecified-tie-order contracts.
    _, indices = sort_stable(inp, stable=True, dim=dim, descending=descending)
    return indices