Coverage for src/flag_gems/ops/nonzero_numpy.py: 100%
7 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3from flag_gems.ops.nonzero import nonzero
5logger = logging.getLogger(__name__)
def nonzero_numpy(inp):
    """
    Return the indices of the non-zero elements of ``inp``, one 1D tensor
    per dimension of the input.

    This follows the ``torch.nonzero(inp, as_tuple=True)`` / ``numpy.nonzero()``
    convention, except that the result is a list rather than a tuple because
    aten::nonzero_numpy returns Tensor[].

    Args:
        inp: Input tensor.

    Returns:
        A list of 1D index tensors, one per dimension of ``inp``. A 0-dim
        input is treated as a one-element 1D tensor, so the list then holds
        exactly one tensor.
    """
    logger.debug("GEMS NONZERO_NUMPY")

    # NumPy compatibility: np.array(5).nonzero() gives (array([0]),).
    # Without this promotion the [N, ndim] index matrix has ndim == 0, so
    # unbinding along dim=1 would produce no tensors at all.
    if inp.dim() == 0:
        inp = inp.unsqueeze(0)

    # Use the existing nonzero implementation which returns shape [N, ndim]
    out = nonzero(inp, as_tuple=False)

    # Unbind along dim=1 to get ndim tensors of shape [N]
    # Convert to list since aten::nonzero_numpy returns Tensor[]
    return list(out.unbind(dim=1))