Coverage for src/flag_gems/runtime/backend/_ascend/ops/select_backward.py: 0%
22 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
logger = logging.getLogger(__name__)


def select_backward(grad, input_sizes, dim, index, out=None):
    """Compute the gradient of ``Tensor.select`` with respect to its input.

    Produces a tensor of shape ``input_sizes`` that is zero everywhere except
    the slice ``select(dim, index)``, which receives ``grad``.

    Args:
        grad: Incoming gradient for the selected slice. It is written with
            ``Tensor.copy_``, so it must be broadcastable to that slice.
        input_sizes: Shape of the original (pre-select) input.
        dim: Dimension that was selected; negative values count from the end.
        index: Position along ``dim`` that was selected; negative values
            count from the end.
        out: Optional pre-allocated result tensor. Must have shape
            ``input_sizes`` and the same dtype and device as ``grad``; its
            contents are overwritten in place.

    Returns:
        ``out`` when it is given, otherwise a newly allocated tensor.

    Raises:
        IndexError: If ``dim`` or ``index`` is out of range.
        RuntimeError: If ``out`` has the wrong shape, dtype, or device.
    """
    logger.debug("GEMS_ASCEND SELECT_BACKWARD")
    dim = int(dim)
    index = int(index)
    sizes = list(input_sizes)
    ndim = len(sizes)

    # Explicit raises rather than ``assert``: asserts are removed under
    # ``python -O``, and the ``%=`` normalisation below would then silently
    # wrap an out-of-range dim/index onto a valid one. For a 0-d input
    # (ndim == 0) this check also fires before ``dim %= ndim`` is reached.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Invalid dim: expected a value in [{-ndim}, {ndim - 1}], got {dim}"
        )
    dim %= ndim

    dim_size = sizes[dim]
    if not -dim_size <= index < dim_size:
        raise IndexError(
            f"Invalid index: expected a value in [{-dim_size}, {dim_size - 1}] "
            f"for dim {dim}, got {index}"
        )
    index %= dim_size

    if out is None:
        # Allocate zero-filled directly instead of empty() followed by zero_().
        out = torch.zeros(sizes, dtype=grad.dtype, device=grad.device)
    else:
        if tuple(out.shape) != tuple(sizes):
            raise RuntimeError(
                f"out shape mismatch: expected {tuple(sizes)}, "
                f"got {tuple(out.shape)}"
            )
        if out.dtype != grad.dtype:
            raise RuntimeError(
                f"dtype mismatch: out has {out.dtype}, grad has {grad.dtype}"
            )
        if out.device != grad.device:
            raise RuntimeError(
                f"device mismatch: out is on {out.device}, grad is on {grad.device}"
            )
        # Caller-provided buffer may hold stale data; everything outside the
        # selected slice must be zero.
        out.zero_()

    out.select(dim, index).copy_(grad)
    return out