Coverage for src/flag_gems/runtime/backend/_ascend/ops/index.py: 0%
116 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
# Child logger under the Ascend ops namespace, named after this module's last
# dotted component (e.g. "...ops.index").
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)
# Row-gather kernel over flat buffers.
#
# Each program handles BLOCK_SIZE consecutive output rows. Output row `row`
# (out_ptr + row * stride) receives a copy of input row index_ptr[row]
# (input_ptr + index_ptr[row] * stride). Each row holds `stride` elements and
# is copied in chunks of MAX_DATA_SIZE lanes, with the tail chunk masked.
# NOTE(review): `row` is derived from program_id and is presumably int32, so
# `row * stride` may overflow for outputs with more than 2**31 elements.
# Confirm before relying on this kernel for very large tensors.
@triton.jit
def index_kernel_func(
    input_ptr,
    stride: tl.constexpr,
    index_len,
    index_ptr,
    out_ptr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DATA_SIZE: tl.constexpr,
):
    first_row = tl.program_id(axis=0) * BLOCK_SIZE
    # ceil(stride / MAX_DATA_SIZE) chunks per row. Both operands are constexpr.
    num_chunks = (stride - 1) // MAX_DATA_SIZE + 1
    lanes = tl.arange(0, MAX_DATA_SIZE)

    for row_in_block in range(0, BLOCK_SIZE):
        row = first_row + row_in_block
        # The last program may run past the number of selected rows.
        if row < index_len:
            src_row_ptr = input_ptr + tl.load(index_ptr + row) * stride
            dst_row_ptr = out_ptr + row * stride
            for chunk in range(0, num_chunks):
                cols = chunk * MAX_DATA_SIZE + lanes
                in_row = cols < stride
                chunk_vals = tl.load(src_row_ptr + cols, mask=in_row)
                tl.store(dst_row_ptr + cols, chunk_vals, mask=in_row)
def index_wrapper(input, indices, out):
    """
    Simple kernel wrapper for contiguous tensor indices starting from dim 0.

    ``indices[k]`` indexes dim ``k`` of ``input``. The indexed leading dims are
    collapsed into a single row index. ``index_kernel_func`` then copies, for
    every selected row, the ``prod(input.shape[len(indices):])`` trailing
    elements into ``out``.

    Args:
        input: tensor to read from. It is made contiguous here because the
            kernel addresses it as a flat row-major buffer.
        indices: non-empty list of integer index tensors sharing one
            (already broadcast) shape. Values are expected to lie in
            ``[0, input.shape[k])``; wrapping negative indices is the
            caller's job.
        out: preallocated output that the kernel writes as a flat buffer of
            ``indices[0].numel() * prod(input.shape[len(indices):])`` elements
            (assumed contiguous; TODO confirm at call sites).

    Returns:
        None. Returns early, without launching, when nothing is selected.
    """
    index_len = indices[0].numel()
    if index_len <= 0:
        return

    input_shape = input.shape
    num_indexed = len(indices)

    # Elements copied per selected row: product of the un-indexed trailing dims.
    stride = 1
    for size in input_shape[num_indexed:]:
        stride *= size

    # Row-major linearization of the indexed leading dims. Promote to int64 so
    # the multiply-add cannot overflow for int32 indices. The result is then
    # made dense: the kernel reads it as a flat array, and with a single index
    # the caller's tensor (possibly expanded or transposed) would otherwise be
    # passed through unchanged.
    actual_index = indices[0].to(torch.int64)
    for dim in range(1, num_indexed):
        actual_index = actual_index * input_shape[dim] + indices[dim].to(torch.int64)
    actual_index = actual_index.contiguous()

    # The kernel computes input offsets as row * stride + col, i.e. it assumes
    # a dense row-major layout.
    input = input.contiguous()

    BLOCK_SIZE = 32
    MAX_DATA_SIZE = 8 * 1024

    grid = lambda meta: (triton.cdiv(index_len, meta["BLOCK_SIZE"]),)

    index_kernel_func[grid](
        input,
        stride,
        index_len,
        actual_index,
        out,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DATA_SIZE=MAX_DATA_SIZE,
    )
def index(inp, indices):
    """Advanced indexing ``inp[indices]`` (``aten::index.Tensor``) on Ascend.

    Args:
        inp: tensor to index.
        indices: sequence with one optional entry per leading dim of ``inp``.
            ``None`` keeps the whole dim (like ``:``). int32/int64 tensors
            select positions, with negative values counting from the end of
            that dim. bool/uint8/int8 tensors are masks, expanded via
            ``nonzero()`` into one integer index per masked dim. Index
            tensors on another device are first moved to ``inp.device``.

    Returns:
        A new tensor. The index tensors are broadcast to a common shape ``B``.
        If the indexed dims are adjacent, ``B`` replaces them in place.
        Otherwise ``B`` comes first, followed by the un-indexed dims in their
        original order (PyTorch/NumPy advanced-indexing placement).

    Raises:
        ValueError: ``indices`` is empty.
        IndexError: more indices than dims, or a mask whose shape does not
            match the dims it covers.
        TypeError: an index tensor has an unsupported dtype.

    Note:
        Out-of-range index values are not validated here.
    """
    logger.debug("GEMS_ASCEND INDEX")
    indices = list(indices)
    if not indices:
        raise ValueError("at least one index must be provided")

    indices = [
        idx.to(inp.device) if idx is not None and idx.device != inp.device else idx
        for idx in indices
    ]

    # Step 1: expand masks into integer indices, following PyTorch's meta
    # implementation. uint8 ("byte") masks are accepted too, as eager PyTorch
    # does and as the error message below advertises.
    expanded = []
    for i, idx in enumerate(indices):
        if idx is None:
            expanded.append(None)
        elif idx.dtype in (torch.bool, torch.uint8, torch.int8):
            k = len(expanded)
            if k + idx.ndim > inp.ndim:
                raise IndexError(
                    f"too many indices for tensor of dimension {inp.ndim}"
                )
            for j in range(idx.ndim):
                if idx.shape[j] != inp.shape[k + j]:
                    raise IndexError(
                        f"The shape of the mask {idx.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                    )
            nonzero = idx.nonzero()
            expanded.extend(nonzero.select(1, j) for j in range(idx.ndim))
        elif idx.dtype in (torch.int32, torch.int64):
            expanded.append(idx)
        else:
            raise TypeError(
                "tensors used as indices must be long, int, byte or bool tensors"
            )

    if len(expanded) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(expanded)})"
        )

    # Pad with None so there is exactly one entry per dim of inp.
    expanded += [None] * (inp.ndim - len(expanded))

    tensor_dims = [d for d, idx in enumerate(expanded) if idx is not None]
    if not tensor_dims:
        # Nothing to gather. aten::index returns a fresh tensor, never a view.
        return inp.clone()

    # Step 2: broadcast the tensor indices together and wrap negative values
    # into [0, size) for their dim.
    broadcast = torch.broadcast_tensors(*(expanded[d] for d in tensor_dims))
    dim_indices = [
        torch.where(t < 0, t + inp.shape[d], t)
        for d, t in zip(tensor_dims, broadcast)
    ]
    replacement_shape = list(dim_indices[0].shape)
    first = tensor_dims[0]
    adjacent = tensor_dims[-1] - first + 1 == len(tensor_dims)
    rest_dims = [d for d in range(inp.ndim) if expanded[d] is None]

    # Step 3: move the indexed dims to the front explicitly, rather than
    # re-entering inp[...] (which would dispatch back into aten::index).
    perm = tensor_dims + rest_dims
    src = inp if perm == list(range(inp.ndim)) else inp.permute(perm)
    # The kernel addresses memory as a dense row-major buffer.
    src = src.contiguous()

    # Layout in the permuted frame: broadcast index shape, then un-indexed dims.
    out = torch.empty(
        replacement_shape + [inp.shape[d] for d in rest_dims],
        dtype=inp.dtype,
        device=inp.device,
    )

    if inp.numel() > 0 and out.numel() > 0:
        if inp.ndim == 1:
            # A single index into a 1-D tensor is a plain gather. gather
            # needs an int64 index of the same rank as the input, so flatten
            # the index first and restore its shape afterwards.
            flat_index = dim_indices[0].reshape(-1).to(torch.int64)
            out = torch.gather(src, 0, flat_index).view(replacement_shape)
        else:
            index_wrapper(src, dim_indices, out)

    # Step 4: when the indexed dims are adjacent, PyTorch keeps the broadcast
    # dims where the first indexed dim was. Otherwise they stay in front,
    # which is already out's layout.
    if adjacent and first > 0:
        num_repl = len(replacement_shape)
        out = out.movedim(
            tuple(range(num_repl)), tuple(range(first, first + num_repl))
        ).contiguous()
    return out