Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/select_scatter.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping 

7 

8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

9 

10 

@triton.jit
def scatter_slice_kernel(
    out_ptr,
    src_ptr,
    src_elements,
    dim_prod_post,
    out_stride_dim,
    index_offset,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy the flat ``src`` buffer into one slice of the flat ``out`` buffer.

    Each program instance handles ``BLOCK_SIZE`` consecutive flat positions of
    ``src``. A flat position ``i`` is split into ``i // dim_prod_post`` and
    ``i % dim_prod_post``; the destination is
    ``(i // dim_prod_post) * out_stride_dim + index_offset + (i % dim_prod_post)``.

    Args:
        out_ptr: Base pointer of the destination buffer.
        src_ptr: Base pointer of the source buffer, read at flat offsets.
        src_elements: Number of source elements; positions at or beyond this
            count are masked out of both the load and the store.
        dim_prod_post: Divisor used to split a flat source position.
        out_stride_dim: Destination step applied per quotient of that split.
        index_offset: Constant offset added to every destination position.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Flat source positions covered by this program instance.
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < src_elements

    # Split every flat position and rebuild the destination offset from it.
    outer = lane // dim_prod_post
    inner = lane % dim_prod_post
    dst = outer * out_stride_dim + index_offset + inner

    values = tl.load(src_ptr + lane, mask=in_bounds)
    tl.store(out_ptr + dst, values, mask=in_bounds)

34 

35 

def select_scatter(inp, src, dim, index):
    """Return a copy of ``inp`` with the slice ``index`` along ``dim`` set to ``src``.

    ``inp`` itself is never written: the kernel runs on a freshly allocated,
    contiguous copy, and that copy is what gets returned.

    Args:
        inp: Input tensor supplying every value outside the selected slice.
        src: Values for the selected slice. Its shape must equal ``inp.shape``
            with dimension ``dim`` removed.
        dim: Dimension to select along; negative values count from the end.
        index: Position of the slice along ``dim``; negative values count
            from the end.

    Returns:
        A new contiguous tensor with the shape of ``inp``.

    Raises:
        AssertionError: If ``dim`` or ``index`` is out of range, or if
            ``src`` does not have the shape of the selected slice.
    """
    logger.debug("GEMS_KUNLUNXIN SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    # Normalise negative dim/index to their non-negative equivalents.
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # The kernel addresses the output as a flat, densely packed buffer, so the
    # result must be contiguous. Exactly one copy is made either way: clone()
    # for an already-contiguous input, contiguous() (which materialises a new
    # tensor) otherwise. The previous overlap check chose between two
    # identical clone() calls and was followed by a second copy for
    # non-contiguous inputs.
    out = inp.clone() if inp.is_contiguous() else inp.contiguous()

    src_elements = src.numel()
    if src_elements == 0:
        # Nothing to write; return the same fresh contiguous copy as the
        # non-empty path would.
        return out

    # The kernel reads src at flat offsets.
    src = src.contiguous()

    # Number of elements spanned by the dimensions after `dim`.
    dim_prod_post = 1
    for d in range(dim + 1, inp.ndim):
        dim_prod_post *= inp.size(d)

    # Flat distance between consecutive "pre" positions in the output, and the
    # flat offset of the selected slice inside each of them.
    out_stride_dim = inp.size(dim) * dim_prod_post
    out_offset = index * dim_prod_post

    # Larger inputs get larger per-program blocks.
    if src_elements >= 1024 * 1024:
        BLOCK_SIZE = 4096
    elif src_elements >= 4096:
        BLOCK_SIZE = 2048
    else:
        BLOCK_SIZE = 1024

    grid = (triton.cdiv(src_elements, BLOCK_SIZE),)

    scatter_slice_kernel[grid](
        out,
        src,
        src_elements,
        dim_prod_post,
        out_stride_dim,
        out_offset,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out