Coverage for src/flag_gems/runtime/backend/_ascend/ops/select_backward.py: 0%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4 

logger = logging.getLogger(__name__)


def select_backward(grad, input_sizes, dim, index, out=None):
    """Compute the gradient of ``select(dim, index)`` w.r.t. its input.

    Produces a tensor of shape ``input_sizes`` that is zero everywhere except
    the slice ``select(dim, index)``, into which ``grad`` is copied.

    Args:
        grad: Gradient flowing back from the ``select`` output. It is written
            with ``Tensor.copy_`` into the selected slice, so it must be
            broadcastable to that slice's shape.
        input_sizes: Shape of the original (pre-select) input.
        dim: Dimension that was selected along; negative values count from
            the end.
        index: Position that was selected along ``dim``; negative values
            count from the end.
        out: Optional preallocated result. Must have shape ``input_sizes``
            and the same dtype and device as ``grad``; it is overwritten
            in place.

    Returns:
        ``out`` if it was given, otherwise a newly allocated tensor with
        ``grad``'s dtype and device.

    Raises:
        IndexError: If ``dim`` or ``index`` is out of range. This includes an
            empty ``input_sizes`` (0-dim input), which has no valid ``dim``.
        RuntimeError: If ``out`` has the wrong shape, dtype, or device.
    """
    logger.debug("GEMS_ASCEND SELECT_BACKWARD")
    dim = int(dim)
    index = int(index)
    sizes = list(input_sizes)
    ndim = len(sizes)

    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, and the modulo normalization below would then silently
    # wrap an out-of-range dim/index onto a valid one and return wrong data.
    if not -ndim <= dim < ndim:
        raise IndexError("Invalid dim")
    dim %= ndim  # normalize negative dim; ndim > 0 is guaranteed by the check

    dim_size = sizes[dim]
    if not -dim_size <= index < dim_size:
        raise IndexError("Invalid index")
    index %= dim_size  # normalize negative index; dim_size > 0 here

    if out is None:
        out = torch.zeros(sizes, dtype=grad.dtype, device=grad.device)
    else:
        if tuple(out.shape) != tuple(sizes):
            raise RuntimeError("out shape mismatch")
        if out.dtype != grad.dtype:
            raise RuntimeError("dtype mismatch")
        if out.device != grad.device:
            raise RuntimeError("device mismatch")
        # A caller-supplied buffer may hold stale data; only the selected
        # slice is written below, so everything else must be cleared.
        out.zero_()

    out.select(dim, index).copy_(grad)
    return out