Coverage for src/flag_gems/runtime/backend/_ascend/ops/select_scatter.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.utils.shape_utils import has_internal_overlapping 

6 

# Module logger namespaced under the Ascend backend ops, suffixed with the
# final component of this module's dotted name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

8 

9 

def select_scatter(inp, src, dim, index):
    """Return a copy of ``inp`` whose slice at ``index`` along ``dim`` is ``src``.

    ``inp`` itself is never modified: a new tensor is allocated, filled from
    ``inp``, and then the single slice ``out.select(dim, index)`` is
    overwritten with ``src``.

    Args:
        inp: Input tensor (must have at least one dimension).
        src: Tensor whose shape equals ``inp.shape`` with ``dim`` removed.
        dim: Dimension to select along; negative values count from the end.
        index: Position along ``dim``; negative values count from the end.

    Returns:
        A new tensor with the same size, dtype and device as ``inp``.

    Raises:
        AssertionError: If ``dim`` or ``index`` is out of range, or ``src``
            does not match the shape of the selected slice (these checks are
            ``assert`` statements and are skipped under ``python -O``).
    """
    logger.debug("GEMS_ASCEND SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    # Normalize negative dim / index into the range [0, n).
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    # src must match inp's shape with the selected dimension dropped.
    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # NOTE(review): this branch relies on has_internal_overlapping() returning
    # a falsy value when inp does not overlap. If it returns an enum member
    # (enum members are always truthy), every call takes the contiguous
    # branch — confirm against flag_gems.utils.shape_utils.
    if has_internal_overlapping(inp):
        # Reusing inp's strides here could produce an output whose elements
        # alias each other, so allocate a plain contiguous buffer instead.
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        # Preserve inp's strides so the result keeps the same memory layout.
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    out.copy_(inp)
    # select() always returns a view of out, so copy_ writes through into out.
    # Indexing with a plain list (e.g. out[[index]] for a 1-D tensor) is
    # advanced indexing, which returns a copy and would silently drop src.
    out.select(dim, index).copy_(src)

    return out