Coverage for src/flag_gems/runtime/backend/_mthreads/ops/index_put.py: 0%
382 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
# Module logger: named after this module's basename (the last dotted component
# of ``__name__``) inside the mthreads ops logging namespace.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)
def get_max_rank_shape(indices: List[torch.Tensor]) -> List[int]:
    """Return the per-axis maximum extent over all non-None index tensors.

    Shapes are aligned on their trailing axes; the result has as many entries
    as the highest-rank index tensor. ``None`` entries are skipped.

    Args:
        indices: Index tensors, possibly interleaved with ``None``.

    Returns:
        A list of ints, or ``[]`` when no tensor is present.
    """
    shapes = [tuple(entry.shape) for entry in indices if entry is not None]
    if not shapes:
        return []

    rank = max(len(dims) for dims in shapes)
    result = []
    # Walk axes from the outermost to the innermost, counting from the end so
    # that lower-rank shapes line up with the trailing axes.
    for from_end in range(rank, 0, -1):
        extents = [dims[-from_end] for dims in shapes if len(dims) >= from_end]
        result.append(max([0, *extents]))
    return result
def broadcast_indices(indices, target_shape):
    """Broadcast every non-None entry of ``indices`` to ``target_shape`` in place.

    Entries that already have the target shape, and ``None`` entries, are left
    as they are. The list is mutated; nothing is returned.
    """
    for position in range(len(indices)):
        current = indices[position]
        if current is None:
            continue
        if tuple(current.shape) == tuple(target_shape):
            continue
        indices[position] = torch.broadcast_to(current, target_shape)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated module into ``code``.

    Emits the triton imports, a blank line, the flag_gems imports, and two
    trailing blank lines. Returns the same buffer that was passed in.
    """
    for triton_line in ("import triton", "import triton.language as tl"):
        code.writeline(triton_line)
    code.newline()

    for project_line in (
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems import runtime",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import triton_lang_extension as ext",
    ):
        code.writeline(project_line)

    code.newline()
    code.newline()
    return code
def generate_index_put_kernel(
    inp_rank, indices_len, index_rank, kernel_name: str, code: IndentedBuffer
):
    """Write the source of the Triton index_put kernel into ``code``.

    The emitted kernel is specialised on three integers:

    Args:
        inp_rank: Number of dimensions of the destination tensor.
        indices_len: Number of index tensors; they address the leading
            ``indices_len`` dimensions of the destination.
        index_rank: Number of dimensions of each index tensor. The emitted
            code decomposes the flat row offset using ``indices0_shape*`` only,
            so all index tensors are assumed to share that shape.
        kernel_name: Name given to the emitted ``@triton.jit`` function.
        code: Buffer the source lines are appended to.

    Returns:
        The same ``code`` buffer.

    In the emitted kernel, axis 0 of the launch grid walks the ``M`` index
    elements and axis 1 walks the ``N`` elements of the un-indexed trailing
    dimensions. Index values outside ``[0, input_shape)`` are masked off.
    """
    index_ids = range(indices_len)
    values_rank = index_rank + inp_rank - indices_len

    # --- decorators -------------------------------------------------------
    code.writeline("@libentry()")
    code.writeline("@libtuner(")
    with code.indent():
        for tuner_arg in (
            'configs=runtime.get_tuned_config("index_put"),',
            'key=["M", "N"],',
            'restore_value=["input_ptr"],',
            'strategy=["align32", "align32"],',
            "warmup=5,",
            "rep=10,",
        ):
            code.writeline(tuner_arg)
    code.writeline(")")
    code.writeline("@triton.jit")

    # --- signature --------------------------------------------------------
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        params = ["input_ptr,"]
        params.extend(f"indices{k}_ptr," for k in index_ids)
        params.append("values_ptr,")
        params.extend(f"input_shape{d}," for d in range(inp_rank))
        params.extend(
            f"indices{k}_shape{d}," for k in index_ids for d in range(index_rank)
        )
        params.extend(f"input_stride{d}," for d in range(inp_rank))
        params.extend(
            f"indices{k}_stride{d}," for k in index_ids for d in range(index_rank)
        )
        params.extend(f"values_stride{d}," for d in range(values_rank))
        params.extend(
            [
                "M,",
                "N,",
                "IS_ACCUMULATE: tl.constexpr,",
                "BLOCK_SIZE0: tl.constexpr,",
                "BLOCK_SIZE1: tl.constexpr,",
            ]
        )
        code.writelines(params)
    code.writeline("):")

    # --- body -------------------------------------------------------------
    with code.indent():
        code.writeline("pid0 = ext.program_id(axis=0)")
        code.writeline("pid1 = ext.program_id(axis=1)")
        code.writeline(
            "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
        )
        # When every input dimension is indexed there is no trailing slice,
        # so the second axis degenerates to a single column.
        if inp_rank != indices_len:
            code.writeline(
                "offset1 = pid1 * BLOCK_SIZE1 + tl.arange(0, BLOCK_SIZE1)[None, :]"
            )
        else:
            code.writeline("offset1 = pid1 * 1 + tl.arange(0, 1)[None, :]")
        code.newline()

        # Unravel the flat row offset into coordinates of the index tensors.
        code.writeline("cur_idx = offset0")
        for d in reversed(range(index_rank)):
            code.writeline(f"indices_idx{d} = cur_idx % indices0_shape{d}")
            code.writeline(f"cur_idx = cur_idx // indices0_shape{d}")
        code.newline()

        # Unravel the flat column offset into the un-indexed input dimensions.
        code.writeline("cur_idx = offset1")
        for d in reversed(range(indices_len, inp_rank)):
            code.writeline(f"input_idx{d} = cur_idx % input_shape{d}")
            code.writeline(f"cur_idx = cur_idx // input_shape{d}")
        code.newline()

        code.writeline("mask0 = offset0 < M")
        for k in index_ids:
            address = " + ".join(
                f"indices_idx{d} * indices{k}_stride{d}" for d in range(index_rank)
            )
            code.writeline(
                f"cur_index{k} = tl.load(indices{k}_ptr + {address}, mask=mask0, other=0)"
            )
        code.newline()

        in_bounds = " & ".join(
            f"(cur_index{k} >= 0) & (cur_index{k} < input_shape{k})"
            for k in index_ids
        )
        code.writeline(f"index_mask = {in_bounds}")
        code.writeline("mask1 = offset1 < N")
        code.writeline("mask = index_mask & mask0 & mask1")
        code.newline()

        input_terms = [f"cur_index{k} * input_stride{k}" for k in index_ids]
        input_terms.extend(
            f"input_idx{d} * input_stride{d}" for d in range(indices_len, inp_rank)
        )
        code.writeline(f"input_offset = {' + '.join(input_terms)}")

        values_terms = [
            f"indices_idx{d} * values_stride{d}" for d in range(index_rank)
        ]
        values_terms.extend(
            f"input_idx{indices_len + t} * values_stride{index_rank + t}"
            for t in range(inp_rank - indices_len)
        )
        code.writeline(f"values_offset = {' + '.join(values_terms)}")
        code.newline()

        code.writeline("cur_value = tl.load(values_ptr + values_offset, mask=mask)")
        code.writeline("if IS_ACCUMULATE:")
        with code.indent():
            code.writeline(
                "tl.atomic_add(input_ptr + input_offset, cur_value, mask=mask)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline("tl.store(input_ptr + input_offset, cur_value, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_index_put_wrapper(
    inp_rank,
    indices_len,
    index_rank,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Write the Python launcher for the generated kernel into ``code``.

    The emitted function has the signature
    ``(input, indices, values, accumulate)``. It reads shapes and strides from
    its arguments, sets ``M`` to the element count of the first index tensor
    and ``N`` to the volume of the input dimensions past ``indices_len``,
    launches ``kernel_name`` on a 2-D grid and returns ``input``.

    Args:
        inp_rank: Number of dimensions of the destination tensor.
        indices_len: Number of index tensors passed to the kernel.
        index_rank: Number of dimensions of each index tensor.
        wrapper_name: Name of the emitted launcher function.
        kernel_name: Name of the kernel the launcher calls.
        code: Buffer the source lines are appended to.

    Returns:
        The same ``code`` buffer.
    """
    index_ids = range(indices_len)
    values_rank = index_rank + inp_rank - indices_len

    code.writeline(f"def {wrapper_name}(input, indices, values, accumulate):")
    with code.indent():
        code.writeline("input_shape = input.shape")
        code.writeline("input_stride = input.stride()")
        for k in index_ids:
            code.writeline(f"indices{k}_shape = indices[{k}].shape")
            code.writeline(f"indices{k}_stride = indices[{k}].stride()")
        code.writeline("values_shape = values.shape")
        code.writeline("values_stride = values.stride()")
        code.writeline("M = indices[0].numel()")
        code.writeline(f"N = volume(input_shape[{indices_len}: ])")
        code.newline()

        code.writeline("grid = lambda meta: (")
        with code.indent():
            for block_dim, extent in (("BLOCK_SIZE0", "M"), ("BLOCK_SIZE1", "N")):
                code.writeline(f"triton.cdiv({extent}, meta['{block_dim}']), ")
        code.writeline(")")
        code.newline()

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            call_args = ["input,"]
            call_args.extend(f"indices[{k}]," for k in index_ids)
            call_args.append("values,")
            call_args.extend(f"input_shape[{d}]," for d in range(inp_rank))
            call_args.extend(
                f"indices{k}_shape[{d}],"
                for k in index_ids
                for d in range(index_rank)
            )
            call_args.extend(f"input_stride[{d}]," for d in range(inp_rank))
            call_args.extend(
                f"indices{k}_stride[{d}],"
                for k in index_ids
                for d in range(index_rank)
            )
            call_args.extend(f"values_stride[{d}]," for d in range(values_rank))
            call_args.extend(["M,", "N,", "accumulate==True,"])
            code.writelines(call_args)
        code.writeline(")")
        code.writeline("return input")

    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
):
    """Generate the full module source (imports, kernel, launcher) into ``code``.

    Args:
        inputs: Sequence whose element 0 is the destination tensor and whose
            element 1 is the list of indices; ``None`` entries in that list
            are ignored when deriving the specialisation.
        wrapper_name: Name of the emitted launcher function.
        kernel_name: Name of the emitted Triton kernel.
        code: Buffer to write into.

    Returns:
        The buffer returned by ``generate_imports`` after the kernel and the
        launcher have been appended.

    Raises:
        ValueError: If the index list holds no non-None entry.
    """
    inp_rank = inputs[0].ndim
    # Only real tensors count; None marks a dimension that is not indexed.
    present = [entry for entry in inputs[1] if entry is not None]
    if not present:
        raise ValueError("At least one non-None index tensor is required")

    indices_len = len(present)
    index_rank = present[0].ndim

    buffer = generate_imports(code)
    generate_index_put_kernel(inp_rank, indices_len, index_rank, kernel_name, buffer)
    generate_index_put_wrapper(
        inp_rank, indices_len, index_rank, wrapper_name, kernel_name, buffer
    )
    return buffer
class IndexPutFunction:
    """Generate, cache and dispatch rank-specialised index_put launchers.

    Each distinct ``arg_key`` gets one generated Python module, written to the
    code cache directory and imported from there. The launcher it exposes is
    kept in ``self.overloads`` so later calls with the same key skip code
    generation.
    """

    def __init__(self):
        self.pid = os.getpid()
        # arg_key string -> generated ``_index_put_wrapper`` callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run index_put on ``(inp, tensor_indices, values[, accumulate])``.

        ``accumulate`` may be given as the fourth positional argument or as a
        keyword; it defaults to ``False`` when only three positionals are
        passed. Returns whatever the generated launcher returns.
        """
        if len(args) == 3:
            inp, tensor_indices, values = args
            accumulate = kwargs.get("accumulate", False)
        else:
            inp, tensor_indices, values, accumulate = args

        key = self.arg_key(inp, tensor_indices, values)
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, (inp, tensor_indices, values))
            self.overloads[key] = overload

        return overload(inp, tensor_indices, values, accumulate)

    def _build_overload(self, key, full_args):
        """Generate the module for ``key``, import it and return its launcher."""
        code = generate_code(
            full_args,
            "_index_put_wrapper",
            "_index_put_jit_function",
            IndentedBuffer(),
        )
        file_path = code_cache_dir() / f"index_put_{key}.py"
        write_atomic(file_path, code.getvalue())

        # ``importlib.util`` is a submodule and must be imported explicitly;
        # ``import importlib`` alone does not guarantee it is loaded.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_put_wrapper")

    def arg_key(self, *args, **kwargs):
        """Return the cache key for ``(inp, tensor_indices, ...)``.

        The key encodes the same three integers ``generate_code`` specialises
        on: the input rank, the number of non-None index tensors, and the rank
        of the first non-None index tensor (0 when there is none).
        """
        inp, tensor_indices = args[0], args[1]
        # Count only real tensors so the key matches what generate_code emits.
        present = [idx for idx in tensor_indices if idx is not None]
        index_rank = present[0].ndim if present else 0
        return (
            f"inp_rank_{inp.ndim}_indices_len_{len(present)}"
            f"_index_rank_{index_rank}"
        )


_index_put_func = IndexPutFunction()
# Index dtypes that are expanded with ``nonzero`` (i.e. treated as masks).
# ``torch.uint8`` is the "byte" type named in the TypeError message below;
# ``torch.int8`` is kept so existing callers that rely on it keep working.
_MASK_INDEX_DTYPES = (torch.bool, torch.uint8, torch.int8)


def _normalize_index_put_args(inp, indices, values):
    """Return ``(indices, values)`` in the canonical form used for dispatch.

    The returned index list has exactly ``inp.ndim`` entries. Every entry is
    either ``None`` or a tensor on ``inp.device``, masks have been expanded to
    coordinate tensors, and all tensor entries are broadcast to one shape.
    """
    indices = list(indices)

    # A single boolean mask: expand it here so that scalar values, or values
    # with one element per selected position, can be shaped to match.
    if (
        len(indices) == 1
        and torch.is_tensor(indices[0])
        and indices[0].dtype == torch.bool
    ):
        mask = indices[0]
        if mask.device != inp.device:
            mask = mask.to(inp.device)
        indices = list(torch.where(mask))

        K = indices[0].numel()
        target_shape = (K,) + inp.shape[len(indices) :]
        if values.numel() == 1:
            values = torch.full(
                target_shape, values.item(), dtype=inp.dtype, device=inp.device
            )
        elif values.numel() == K:
            values = values.reshape((K,)).expand(target_shape)

    if not indices:
        raise ValueError("At least one index tensor is required")

    processed = []
    for idx in indices:
        if idx is None:
            processed.append(None)
            continue
        # Check the type before touching tensor attributes so that a bad
        # entry produces this TypeError instead of an AttributeError.
        if not torch.is_tensor(idx):
            raise TypeError(
                "tensors used as indices must be long, int, byte or bool tensors"
            )
        if idx.device != inp.device:
            idx = idx.to(inp.device)
        if idx.dtype in _MASK_INDEX_DTYPES:
            processed.extend(idx.nonzero(as_tuple=True))
        else:
            processed.append(idx)

    if len(processed) > inp.ndim:
        raise IndexError("too many indices for tensor of dimension {}".format(inp.ndim))
    # Trailing dimensions without an index are taken whole.
    processed.extend([None] * (inp.ndim - len(processed)))

    tensor_pos = [i for i, x in enumerate(processed) if x is not None]
    if not tensor_pos:
        raise ValueError("At least one non-None index tensor is required")

    if len(tensor_pos) > 1:
        broadcasted = torch.broadcast_tensors(*[processed[i] for i in tensor_pos])
        for pos, tensor in zip(tensor_pos, broadcasted):
            processed[pos] = tensor

    return processed, values


def _index_put_common(inp, indices, values, accumulate, clone):
    """Shared body of ``index_put``, ``index_put_`` and ``_index_put_impl_``.

    Args:
        inp: Destination tensor.
        indices: Sequence of index tensors and/or ``None``.
        values: Values to write; broadcast to the indexed region.
        accumulate: Passed to the kernel; selects add versus overwrite.
        clone: If true, write into a clone of ``inp`` and return the clone;
            otherwise write into ``inp`` itself and return it.

    The generated kernel expects the index tensors to address the leading
    dimensions of its input, so when they do not start at dimension 0 (or are
    not adjacent) the destination is permuted to move them to the front, and
    ``values`` is laid out to match.
    """
    indices, values = _normalize_index_put_args(inp, indices, values)

    tensor_pos = [i for i, x in enumerate(indices) if x is not None]
    none_pos = [i for i, x in enumerate(indices) if x is None]
    tensors = [indices[i] for i in tensor_pos]

    is_contiguous = (tensor_pos[-1] - tensor_pos[0] + 1) == len(tensor_pos)
    need_transpose = not is_contiguous or indices[0] is None

    # Clone only after validation so a bad call does not pay for the copy.
    dest = inp.clone() if clone else inp
    dest_view = dest.permute(tensor_pos + none_pos) if need_transpose else dest

    broadcast_shape = list(tensors[0].shape)
    slice_shape = [dest.shape[i] for i in none_pos]

    values = values.to(inp.device)
    if need_transpose and is_contiguous:
        # Adjacent index tensors that do not start at dim 0: ``values`` is
        # given in the natural order (dims before, indexed dims, dims after)
        # and must be permuted to put the indexed dims first.
        num_before = tensor_pos[0]
        before_dims = slice_shape[:num_before]
        after_dims = slice_shape[num_before:]
        values = values.broadcast_to(before_dims + broadcast_shape + after_dims)

        B, T = len(before_dims), len(broadcast_shape)
        val_perm = (
            list(range(B, B + T)) + list(range(0, B)) + list(range(B + T, values.ndim))
        )
        values = values.permute(val_perm)
    else:
        values = values.broadcast_to(broadcast_shape + slice_shape)

    _index_put_func(dest_view, tensors, values, accumulate)
    return dest


def index_put(inp, indices, values, accumulate=False):
    """Out-of-place index_put: return a copy of ``inp`` with ``values`` written.

    Raises:
        ValueError: If no index tensor is given.
        TypeError: If an index entry is neither ``None`` nor a tensor.
        IndexError: If there are more indices than ``inp`` has dimensions.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT")
    return _index_put_common(inp, indices, values, accumulate, clone=True)


def index_put_(inp, indices, values, accumulate=False):
    """In-place index_put: write ``values`` into ``inp`` and return ``inp``.

    Raises the same exceptions as ``index_put``.
    """
    logger.debug("GEMS_MTHREADS INDEX PUT_")
    return _index_put_common(inp, indices, values, accumulate, clone=False)


def _index_put_impl_(inp, indices, values, accumulate=False, unsafe=False):
    """In-place index_put variant that also accepts the ``unsafe`` flag.

    Raises the same exceptions as ``index_put``.
    """
    logger.debug("GEMS_MTHREADS _INDEX_PUT_IMPL_")

    # ``unsafe`` is ignored: the generated kernel always masks off index
    # values outside [0, dim_size), regardless of this flag.
    # NOTE(review): negative indices are masked off too, not wrapped -
    # confirm that callers normalise them beforehand.
    return _index_put_common(inp, indices, values, accumulate, clone=False)