Coverage for src/flag_gems/ops/nonzero_numpy.py: 100%

7 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3from flag_gems.ops.nonzero import nonzero 

4 

5logger = logging.getLogger(__name__) 

6 

7 

def nonzero_numpy(inp):
    """
    Return the indices of the non-zero elements of ``inp``, split per dimension.

    The result is a list with one 1-D tensor per dimension of ``inp``; the
    i-th tensor holds the i-th coordinate of every non-zero element. This is
    ``list(torch.nonzero(inp).unbind(1))``, i.e. the numpy.nonzero()-style
    layout, returned as a list (not a tuple) because ``aten::nonzero_numpy``
    returns ``Tensor[]``.

    A 0-dim input is treated as a one-element 1-D tensor, mirroring ATen's
    ``nonzero_numpy``: the result then holds exactly one tensor (``[0]`` if
    the scalar is non-zero, empty otherwise) instead of being an empty list.

    Args:
        inp: Input tensor of any rank (including 0-dim).

    Returns:
        A list of index tensors, one per dimension of ``inp`` (one for a
        0-dim input).
    """
    logger.debug("GEMS NONZERO_NUMPY")

    # ATen special-cases scalars (self.dim() == 0 -> self.unsqueeze(0)).
    # Without this, nonzero() of a 0-dim tensor presumably has shape [N, 0]
    # (as torch.nonzero does), and unbind(dim=1) would yield an empty list.
    if inp.dim() == 0:
        inp = inp.unsqueeze(0)

    # Use the existing nonzero implementation which returns shape [N, ndim].
    out = nonzero(inp, as_tuple=False)

    # Unbind along dim=1 to get ndim tensors of shape [N].
    # Convert to list since aten::nonzero_numpy returns Tensor[].
    return list(out.unbind(dim=1))