Coverage for src/flag_gems/runtime/backend/_sunrise/ops/index_select.py: 0%
93 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module-level logger named after this module, so records from this op can be filtered.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Gather columns: out[r, c] = inp[r, index[c]].
    #
    # `inp` is addressed as M contiguous rows of length N, `out` as M rows of
    # length index_len. Launch-grid axis 0 picks a BLOCK_M-tall strip of rows,
    # axis 1 picks a BLOCK_N-wide strip of output columns (positions in `index`).
    # Lanes whose row/column is past the output bounds, or whose index value is
    # outside [0, N), are neither loaded nor stored.
    pid_m = ext.program_id(axis=0)
    pid_n = ext.program_id(axis=1)

    row = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # shape (BLOCK_M, 1)
    col = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # shape (BLOCK_N,)
    col_in_range = col < index_len

    # Source column ids; lanes past the end of `index` read 0 and are masked out below.
    src_col = tl.load(index + col, mask=col_in_range, other=0)
    src_col_ok = (src_col >= 0) & (src_col < N)

    # Broadcasts to (BLOCK_M, BLOCK_N).
    mask = (row < M) & col_in_range & src_col_ok

    values = tl.load(inp + (row * N + src_col[None, :]), mask=mask, other=0.0)
    tl.store(out + (row * index_len + col[None, :]), values, mask=mask)
def index_select_heur_block_m(args):
    """Choose the column tile width from the row length ``args["N"]``.

    Despite the ``_block_m`` suffix, this heuristic is registered as
    ``BLOCK_N`` for ``index_select_dim0_kernel``. Longer rows get wider tiles:
    N >= 8192 -> 2048, N >= 4096 -> 1024, N >= 1024 -> 512, otherwise 256.
    """
    row_len = args["N"]
    # Thresholds are checked from largest to smallest; the first match wins.
    for min_len, block in ((8192, 2048), (4096, 1024), (1024, 512)):
        if row_len >= min_len:
            return block
    return 256
@libentry()
@triton.heuristics({"BLOCK_N": index_select_heur_block_m})
@triton.jit
def index_select_dim0_kernel(inp, out, N, index, BLOCK_N: tl.constexpr):
    # Copy whole rows: output row `pid_row` receives input row `index[pid_row]`.
    #
    # Both `inp` and `out` are addressed as contiguous rows of length N.
    # Launch-grid axis 0 selects the output row, axis 1 a BLOCK_N-wide chunk of it.
    # NOTE(review): unlike index_select_kernel, the loaded index value is not
    # bounds-checked; an out-of-range index makes the load address fall outside `inp`.
    pid_row = ext.program_id(axis=0)
    pid_chunk = ext.program_id(axis=1)

    col = pid_chunk * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = col < N

    src_row = tl.load(index + pid_row)
    values = tl.load(inp + (src_row * N + col), mask=in_bounds, other=0.0)
    tl.store(out + (pid_row * N + col), values, mask=in_bounds)
67def _dim_compress(inp, dim):
68 batch_dim = [dim]
69 reduction_dim = [i for i in range(inp.ndim) if i != dim]
70 order = batch_dim + reduction_dim
71 return inp.permute(order).contiguous()
def index_select_dim0(inp, dim, index):
    """Gather along ``dim`` by moving it to the front and copying whole rows.

    ``inp`` is permuted so ``dim`` becomes axis 0 (via ``_dim_compress``). Each
    selected slice is then one contiguous row of ``N`` elements, which
    ``index_select_dim0_kernel`` copies with one program per (row, chunk). The
    result is permuted back so the selected axis returns to position ``dim``.

    Args:
        inp: input tensor.
        dim: axis to select along (expected non-negative).
        index: 1-D tensor of positions along ``dim``. It must be non-empty,
            because the row length is computed by dividing by its length.

    Returns:
        A contiguous tensor shaped like ``inp`` except with ``index.numel()``
        entries along ``dim``.
    """
    front = _dim_compress(inp, dim)
    n_sel = index.numel()
    out = torch.empty(
        [n_sel, *front.shape[1:]], dtype=front.dtype, device=front.device
    )
    # Elements per selected slice (product of all non-`dim` sizes).
    row_len = out.numel() // n_sel

    def grid(meta):
        return (n_sel, triton.cdiv(row_len, meta["BLOCK_N"]))

    index_select_dim0_kernel[grid](front, out, row_len, index)

    if dim == 0:
        return out
    # Undo the front-move: output axis 0 goes back to position `dim`.
    inverse = list(range(1, out.ndim))
    inverse.insert(dim, 0)
    return out.permute(inverse).contiguous()
def index_select(inp, dim, index):
    """Select slices of ``inp`` along ``dim`` at the positions given by ``index``.

    2-D and 3-D inputs use the row-copy path (``index_select_dim0``). Other
    ranks go through ``dim_compress`` and ``index_select_kernel``, which
    gathers columns of an (M, N) buffer.

    Args:
        inp: input tensor.
        dim: axis to select along; negative values count from the end.
        index: 0-D or 1-D tensor of positions along ``dim``. A 0-D index is
            treated as a one-element 1-D index.

    Returns:
        A contiguous tensor shaped like ``inp`` except with ``index.numel()``
        entries along ``dim``. An empty result is returned without launching
        a kernel.

    Raises:
        AssertionError: if ``dim`` is out of range or ``index`` has more than
            one dimension.
    """
    logger.debug("GEMS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    # Both kernels read `index` as a flat unit-stride buffer (`index + offset`),
    # so a strided view such as `idx[::2]` must be compacted first. This is a
    # no-op for an index that is already contiguous.
    index = index.contiguous()
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_len = index.numel()

    out_shape = list(inp_shape)
    out_shape[dim] = index_len
    # Nothing to gather for an empty result. Returning here also avoids the
    # `numel() // N` divisions below, which raise ZeroDivisionError when `dim`
    # has size 0 (e.g. an empty index on an input with a zero-length `dim`).
    if 0 in out_shape:
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    if inp.ndim == 2 or inp.ndim == 3:
        return index_select_dim0(inp, dim, index)

    # General path: the code below assumes `dim_compress` moves `dim` to the
    # last axis (the output shape and inverse permutation rely on it), so each
    # selected element is a column of an (M, N) row-major buffer.
    inp = dim_compress(inp, dim)
    N = inp_shape[dim]
    M = inp.numel() // N
    compressed_out_shape = list(inp.shape)
    compressed_out_shape[-1] = index_len
    out = torch.empty(compressed_out_shape, dtype=inp.dtype, device=inp.device)

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(index_len, meta["BLOCK_N"]),
        )

    # NOTE(review): index_select_kernel skips lanes whose index is outside
    # [0, N), so those output elements stay uninitialized (torch.empty).
    index_select_kernel[grid](inp, out, M, N, index, index_len)

    if dim == out.ndim - 1:
        return out
    # Undo the move: the last axis goes back to position `dim`.
    order = list(range(out.ndim - 1))
    order.insert(dim, out.ndim - 1)
    return out.permute(order).contiguous()