Coverage for src/flag_gems/runtime/backend/_ascend/ops/select_scatter.py: 0%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
5from flag_gems.utils.shape_utils import has_internal_overlapping
# Module-level logger, named under the Ascend ops namespace and suffixed with
# the last dotted component of this module's own name.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
def select_scatter(inp, src, dim, index):
    """Return a copy of ``inp`` with ``src`` written into a single slice.

    The slice is the one picked by ``index`` along dimension ``dim``. ``inp``
    itself is only read; the result is a freshly allocated tensor.

    Args:
        inp: Tensor providing the base values, size, dtype and device.
        src: Tensor whose shape must equal ``inp.shape`` with ``dim`` removed.
        dim: Dimension to select along. Negative values count from the end.
        index: Position along ``dim``. Negative values count from the end.

    Returns:
        A new tensor with the same size, dtype and device as ``inp``. It is
        allocated with ``inp``'s strides unless ``has_internal_overlapping``
        reports a truthy value for ``inp``, in which case it is allocated
        with ``torch.empty`` and no explicit strides.

    Raises:
        AssertionError: If ``dim`` or ``index`` is out of range, or if
            ``src``'s shape does not match the selected slice.
    """
    logger.debug("GEMS_ASCEND SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    # Normalize negative dim/index to their non-negative equivalents.
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    # Shape of one slice of inp along `dim`: inp's shape with that axis dropped.
    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # NOTE(review): has_internal_overlapping is defined outside this file.
    # This branch relies on its return value being falsy for non-overlapping
    # inputs; if it returns a non-bool (e.g. an enum member), confirm that
    # plain truthiness is the intended test.
    if has_internal_overlapping(inp):
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    out.copy_(inp)
    indices = [slice(None)] * inp.ndim
    indices[dim] = index
    # Index with a tuple, not a list: PyTorch deprecates non-tuple sequences
    # for multidimensional indexing and may treat them as tensor indices.
    out[tuple(indices)].copy_(src)
    return out