Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py: 0%
232 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
11from flag_gems.utils.shape_utils import c_contiguous_stride
12from flag_gems.utils.tensor_wrapper import StridedBuffer
# Module logger. Its name is the backend ops package path followed by this
# module's basename (the last dotted component of __name__).
logger = logging.getLogger(
    f"flag_gems.runtime.backend._mthreads.ops.{__name__.split('.')[-1]}"
)

# repeat_interleave.self_{int,Tensor} are CompositeImplicitAutograd ops.
# Overriding them directly with the Triton implementations below would break
# the gradient, so when gradients may be needed the wrappers redispatch to
# this keyset. That runs the decomposed forward (and its backward) instead.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeImplicitAutograd
)
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Element-wise identity. Wrapped by pointwise_dynamic, it acts as a strided
    # copy kernel. repeat_interleave_self_int feeds it an input view with a
    # zero stride on an inserted "repeat" axis, so each input element is
    # written to every repeated output position.
    return x
def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat every slice of ``inp`` along ``dim`` ``repeats`` times.

    Implements ``aten.repeat_interleave.self_int``. When autograd is enabled
    the call is redispatched to the CompositeImplicitAutograd decomposition so
    the gradient is preserved.

    Args:
        inp: Input tensor.
        repeats: Non-negative integer number of repetitions per slice.
        dim: Axis to repeat along. ``None`` flattens ``inp`` and uses axis 0.
        output_size: Optional expected size of the result along ``dim``.

    Returns:
        A newly allocated contiguous tensor whose size along ``dim`` is
        ``inp.shape[dim] * repeats``.

    Raises:
        IndexError: If ``dim`` is out of range for ``inp``.
        RuntimeError: If ``repeats`` is negative, or ``output_size`` does not
            match the computed size along ``dim``.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_INT")
    if torch.is_grad_enabled():
        return torch.ops.aten.repeat_interleave.self_int.redispatch(
            _FALLBACK_KEYSET, inp, repeats, dim, output_size=output_size
        )
    if dim is None:
        inp = inp.flatten()
        dim = 0
    elif (dim < -inp.ndim) or (dim >= inp.ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -inp.ndim, inp.ndim - 1, dim
            )
        )
    if dim < 0:
        dim += inp.ndim

    # A negative count would otherwise reach the strided kernel as a negative
    # view extent (or fail later with an unrelated "negative dimension" error).
    if repeats < 0:
        raise RuntimeError("repeats can not be negative")

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp_shape)
    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy (repeats == 0 or an empty input): skip the launch.
    if output.numel() == 0:
        return output

    # View the input as shape [..., n_dim, repeats, ...] with stride 0 on the
    # inserted axis (broadcast), and the output as the contiguous tensor of the
    # same shape; copying one onto the other performs the interleaved repeat.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    copy_func.instantiate(len(out_view_shape))(in_view, out0=out_view)
    return output
@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    # Builds the index vector of repeat_interleave(repeats): program `pid`
    # fills out[cumsum[pid] - repeats[pid] : cumsum[pid]] with the value pid.
    # The host wrapper launches one program per repeats entry (grid = (size,)).
    # NOTE(review): repeats_ptr is indexed as a dense 1-D array, so the caller
    # must pass a contiguous `repeats` tensor.
    pid = tle.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    # Inclusive prefix sum minus this entry's own count = segment start.
    out_offset = cumsum - repeats

    # NOTE(review): tl.device_assert is usually only active when Triton debug
    # mode is enabled; confirm negative counts are rejected on the host too.
    tl.device_assert(repeats >= 0, "repeats can not be negative")

    out_ptr += out_offset
    # Fill this program's segment BLOCK_SIZE elements at a time.
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)
def repeat_interleave_tensor(repeats, *, output_size=None):
    """Return ``arange(len(repeats))`` with element ``i`` repeated ``repeats[i]`` times.

    Implements ``aten.repeat_interleave.Tensor``; the result is the gather
    index used to repeat slices of another tensor.

    Args:
        repeats: 1-D integer tensor of non-negative counts.
        output_size: Optional expected length of the result.

    Returns:
        1-D tensor with ``repeats.dtype`` on ``repeats.device`` whose length
        is ``repeats.sum()``.

    Raises:
        AssertionError: If ``repeats`` is not 1-D.
        RuntimeError: If any count is negative, or ``output_size`` does not
            match the computed length.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    if size == 0:
        # Same as ATen: no counts -> empty index vector (cumsum[-1] would fail).
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    # The kernel indexes repeats densely, so a strided view must be compacted.
    repeats = repeats.contiguous()
    cumsum = repeats.cumsum(axis=0)

    # One host sync for both the output length and the smallest count. The sum
    # alone can hide a negative entry, which would make the kernel write at
    # negative / overlapping offsets.
    result_size, min_repeat = torch.stack(
        (cumsum[-1], repeats.min().to(cumsum.dtype))
    ).tolist()
    if min_repeat < 0:
        raise RuntimeError("repeats can not be negative")

    if output_size is not None and output_size != result_size:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                result_size, output_size
            )
        )

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    if result_size == 0:
        # Every count is zero: nothing to write.
        return out

    grid = (size,)
    BLOCK_SIZE = 32
    with torch_device_fn.device(repeats.device):
        repeat_interleave_tensor_kernel[grid](
            repeats,
            cumsum,
            out,
            size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=1,
        )
    return out
@libentry()
@triton.jit
def fused_repeat_interleave_dim0_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Input-centric kernel for repeat_interleave along dim 0.

    Each program handles one input row and copies it to every output row it
    expands to. Rows are ``row_size`` contiguous elements. ``cumsum_ptr``
    holds the inclusive prefix sum of the repeats, so input row ``pid`` maps
    to output rows ``[cumsum[pid - 1], cumsum[pid])`` (starting at 0 for
    ``pid == 0``).
    """
    pid = tle.program_id(0)

    if pid >= num_input_rows:
        return

    # Get output row range for this input row; the masked load yields 0 as the
    # start for row 0.
    row_idx_mask = pid > 0
    start_row_idx = tl.load(cumsum_ptr + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(cumsum_ptr + pid)

    # Rows with zero repeats produce no output.
    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    # Calculate input row offset
    # NOTE(review): whether this product is 64-bit depends on tle.program_id's
    # return type and Triton's integer-argument specialization; confirm there
    # is no int32 overflow for inputs with more than 2**31 elements.
    inp_row_offset = pid * row_size

    # Process columns in blocks; each block is loaded once and stored
    # num_of_rows times.
    for col_block in range(0, tl.cdiv(row_size, BLOCK_SIZE)):
        col_offsets = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )

        # Store to each output row
        for cur_row in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row
            output_row_offsets = output_row_index * row_size + col_offsets
            tl.store(out_ptr + output_row_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_output_centric_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel for repeat_interleave along dim 0.

    Uses a 2D grid ``(num_output_rows, num_col_chunks)``. Each program copies
    one BLOCK_SIZE-wide column chunk of one output row. The source input row
    is located by binary search over the inclusive prefix sum ``cumsum_ptr``,
    which requires the repeats to be non-negative so ``cumsum`` is
    non-decreasing.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Binary search to find input row index
    # Find the smallest i such that cumsum[i] > out_row_idx
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    inp_row_idx = low

    # Calculate column offsets for this chunk
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load from input
    # NOTE(review): inp_row_idx comes from the integer loop above; confirm the
    # product with row_size cannot overflow int32 for very large inputs.
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store to output
    out_row_offset = out_row_idx * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_1d_bsearch_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """1D output-centric kernel with binary search.

    Each program handles one complete output row, looping over its columns in
    BLOCK_SIZE chunks, so the binary search runs once per row instead of once
    per chunk. Intended for large row sizes. ``cumsum_ptr`` is the inclusive
    prefix sum of non-negative repeats.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Binary search to find input row index: the smallest i such that
    # cumsum[i] > out_row_idx.
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    inp_row_idx = low

    # Calculate row offsets
    # NOTE(review): confirm inp_row_idx * row_size cannot overflow int32 for
    # inputs with more than 2**31 elements.
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx * row_size

    # Process all columns in blocks
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_with_indices_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel using a precomputed output->input row mapping.

    Uses a 2D grid ``(num_output_rows, num_col_chunks)``. ``index_ptr[r]`` is
    the input row copied into output row ``r`` (presumably the result of
    repeat_interleave_tensor; verify against callers). Each program copies
    one BLOCK_SIZE-wide column chunk.

    NOTE(review): no call site for this kernel is visible in this module.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Load precomputed input row index
    inp_row_idx = tl.load(index_ptr + out_row_idx)

    # Calculate column offsets for this chunk
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load from input
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store to output
    out_row_offset = out_row_idx * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
@libentry()
@triton.jit
def fused_repeat_interleave_large_row_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-per-program kernel using a precomputed output->input row mapping.

    Each program handles one output row and loops over all of its columns in
    BLOCK_SIZE chunks; intended for large row sizes. ``index_ptr[r]`` is the
    input row copied into output row ``r``.

    NOTE(review): no call site for this kernel is visible in this module.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Load precomputed input row index
    inp_row_idx = tl.load(index_ptr + out_row_idx)

    # Calculate row offsets
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx * row_size

    # Process all columns in blocks
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input and store to output
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)
def _select_num_warps(block_size):
    """Return the warp count used for a column block of ``block_size`` elements."""
    if block_size <= 256:
        return 2
    if block_size <= 512:
        return 4
    return 8


def fused_repeat_interleave_dim0(inp, repeats, dim):
    """Fused repeat_interleave along dimension 0 with per-row counts.

    The input is made contiguous and treated as ``inp.shape[dim]`` rows of
    ``inp.numel() // inp.shape[dim]`` elements, so this is only correct for
    ``dim == 0`` (the only value the caller in this module passes).

    Kernel selection:
      1. row_size < 512 and < 512 output rows: input-centric kernel.
      2. row_size >= 16384: one program per output row with binary search.
      3. otherwise: 2D grid (output rows x column chunks) with binary search.

    Args:
        inp: Input tensor of any rank.
        repeats: 1-D tensor of non-negative counts, one per row along ``dim``
            (the caller is expected to have checked the length).
        dim: Axis being repeated; must be 0.

    Returns:
        New contiguous tensor with ``repeats.sum()`` rows along ``dim``.

    Raises:
        AssertionError: If ``repeats`` is not 1-D.
        RuntimeError: If any count is negative.
    """
    logger.debug("GEMS_MTHREADS FUSED_REPEAT_INTERLEAVE_DIM0")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    # Compute cumsum of repeats
    cumsum = repeats.cumsum(axis=0)
    if repeats.numel() == 0:
        total_output_rows, min_repeat = 0, 0
    else:
        # One host sync for the total and the smallest count. A negative count
        # breaks the monotonic cumsum the kernels rely on: the binary search
        # could return num_input_rows (out-of-bounds read) and the
        # input-centric kernel could write at negative offsets.
        total_output_rows, min_repeat = torch.stack(
            (cumsum[-1], repeats.min().to(cumsum.dtype))
        ).tolist()
    if min_repeat < 0:
        raise RuntimeError("repeats can not be negative")

    # Setup output tensor
    out_shape = list(inp.shape)
    out_shape[dim] = total_output_rows
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # No output rows, or a zero-sized non-dim axis (row_size == 0). The latter
    # would yield triton.next_power_of_2(0) == 0, an invalid BLOCK_SIZE.
    if out.numel() == 0:
        return out

    # Flatten non-dim dimensions for easier indexing
    num_input_rows = inp.shape[dim]
    row_size = inp.numel() // num_input_rows

    # Make input contiguous for efficient access
    inp_contig = inp.contiguous()

    if row_size < 512 and total_output_rows < 512:
        # Small tensor: one program per input row.
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 4096)
        grid = (num_input_rows,)
        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_dim0_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=_select_num_warps(BLOCK_SIZE),
            )
    elif row_size >= 16384:
        # Large rows: one program per output row amortizes the binary search
        # over many column blocks and keeps the program count low.
        BLOCK_SIZE = 2048
        grid = (total_output_rows,)
        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_1d_bsearch_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=16,
            )
    else:
        # Medium rows: 2D grid over (output row, column chunk).
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 1024)
        num_col_chunks = triton.cdiv(row_size, BLOCK_SIZE)
        grid = (total_output_rows, num_col_chunks)
        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_output_centric_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=_select_num_warps(BLOCK_SIZE),
            )

    return out
def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat slices of ``inp`` along ``dim`` by per-slice counts.

    Implements ``aten.repeat_interleave.self_Tensor``. When autograd is
    enabled the call is redispatched to the CompositeImplicitAutograd
    decomposition so the gradient is preserved.

    Args:
        inp: Input tensor.
        repeats: 0-dim tensor, or 1-dim tensor with one non-negative count per
            slice of ``inp`` along ``dim``. A single-element tensor is applied
            to every slice.
        dim: Axis to repeat along. ``None`` flattens ``inp`` and uses axis 0.
        output_size: Optional expected size of the result along ``dim``.

    Returns:
        New tensor containing the repeated slices.

    Raises:
        IndexError: If ``dim`` is out of range for ``inp``.
        RuntimeError: If ``repeats`` has more than one dimension, its length
            differs from ``inp.shape[dim]``, or ``output_size`` disagrees with
            the computed size along ``dim``.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_TENSOR")
    if torch.is_grad_enabled():
        return torch.ops.aten.repeat_interleave.self_Tensor.redispatch(
            _FALLBACK_KEYSET, inp, repeats, dim, output_size=output_size
        )

    if dim is None:
        inp = inp.flatten()
        dim = 0
    elif (dim < -inp.ndim) or (dim >= inp.ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -inp.ndim, inp.ndim - 1, dim
            )
        )

    # A scalar (or single-element) count is the same op as the int overload.
    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    if repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    if dim < 0:
        dim += inp.ndim

    if repeats.size(0) != inp.shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp.shape[dim]
            )
        )

    if repeats.numel() == 0:
        # inp.shape[dim] == 0 here, so the result is an (empty) copy. This is
        # checked only after flattening and validation so that shape and error
        # behavior match the non-empty path.
        res = inp.clone()
    elif dim == 0:
        # Use fused kernel for dim=0
        res = fused_repeat_interleave_dim0(inp, repeats, dim)
    else:
        # For other dimensions, gather with the expanded index vector.
        indices = repeat_interleave_tensor(repeats)
        res = torch.index_select(inp, dim, indices)

    if output_size is not None and output_size != res.shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                res.shape[dim], output_size
            )
        )
    return res