Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/select_scatter.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping
# Module logger: a child of the "flag_gems" logger, named after this module
# (with any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def scatter_slice_kernel(
    out_ptr,
    src_ptr,
    src_elements,
    dim_prod_post,
    out_stride_dim,
    index_offset,
    BLOCK_SIZE: tl.constexpr,
):
    """Write every element of ``src`` into one slice of ``out``.

    ``src`` is read as a flat buffer of ``src_elements`` values. Flat source
    position ``k`` is split into ``outer = k // dim_prod_post`` and
    ``inner = k % dim_prod_post`` and stored at flat position
    ``outer * out_stride_dim + index_offset + inner`` of ``out``.
    The host wrapper passes contiguous buffers for both pointers, which is
    what makes this flat addressing valid.

    Each program handles ``BLOCK_SIZE`` consecutive source positions; lanes
    at or beyond ``src_elements`` are masked off for both load and store.
    """
    # Flat source positions owned by this program.
    flat = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = flat < src_elements

    # Split into (leading-dims, trailing-dims) coordinates and re-linearize
    # into the destination, skipping to slice ``index`` along the scatter dim.
    outer = flat // dim_prod_post
    inner = flat % dim_prod_post
    dst = outer * out_stride_dim + index_offset + inner

    vals = tl.load(src_ptr + flat, mask=in_range)
    tl.store(out_ptr + dst, vals, mask=in_range)
def select_scatter(inp, src, dim, index):
    """Return a copy of ``inp`` whose slice at ``index`` along ``dim`` is ``src``.

    Kunlunxin implementation of ``select_scatter``: every element of the result
    equals ``inp`` except ``result.select(dim, index)``, which holds ``src``.

    Args:
        inp: Input tensor; it is never modified.
        src: Tensor whose shape equals ``inp.shape`` with ``dim`` removed.
        dim: Dimension to scatter into; negative values count from the end.
        index: Position along ``dim``; negative values count from the end.

    Returns:
        A new contiguous tensor with the same shape as ``inp``.

    Raises:
        AssertionError: If ``dim`` or ``index`` is out of range, or ``src``
            does not have the shape of the selected slice.
    """
    logger.debug("GEMS_KUNLUNXIN SELECT_SCATTER")
    # NOTE(review): these checks are asserts and are stripped under `python -O`.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # The kernel addresses ``out`` with row-major flat offsets, so it needs a
    # contiguous buffer that does not alias ``inp``. A contiguous input is
    # cloned. Any other layout (including self-overlapping views such as
    # ``expand()`` results) is densified by ``.contiguous()``, which allocates
    # a fresh tensor whenever its input is not already contiguous. Either
    # branch costs exactly one copy.
    out = inp.clone() if inp.is_contiguous() else inp.contiguous()
    src = src.contiguous()

    src_elements = src.numel()
    if src_elements == 0:
        return out

    # Product of the sizes after ``dim``: the length of one contiguous run of
    # the selected slice.
    dim_prod_post = 1
    for size in inp.shape[dim + 1 :]:
        dim_prod_post *= size

    # Flat distance between consecutive positions of the leading dims, and
    # the flat offset of slice ``index`` within each of those strides.
    out_stride_dim = inp.size(dim) * dim_prod_post
    out_offset = index * dim_prod_post

    # Larger inputs use larger blocks, i.e. fewer programs in the grid
    # (thresholds presumably tuned for this backend).
    BLOCK_SIZE = 1024
    if src_elements >= 1024 * 1024:
        BLOCK_SIZE = 4096
    elif src_elements >= 4096:
        BLOCK_SIZE = 2048

    grid = (triton.cdiv(src_elements, BLOCK_SIZE),)

    scatter_slice_kernel[grid](
        out,
        src,
        src_elements,
        dim_prod_post,
        out_stride_dim,
        out_offset,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out