Coverage for src/flag_gems/ops/roll.py: 66%
279 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from collections.abc import Sequence
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
12IntOrInts = int | Sequence[int]
13MAX_DIMS = 5
def roll(inp: torch.Tensor, shifts, dims=None) -> torch.Tensor:
    """Roll ``inp`` by ``shifts`` along ``dims``.

    When ``dims`` is omitted (``None`` or an empty sequence), the tensor is
    rolled as if it were flattened and then restored to its original shape.
    ``validate_inputs`` checks the arguments up front.

    Tensors accepted by ``_can_use_triton`` are sent to a Triton path. The
    paths are tried from most to least specialised:

    - 1-D float32 tensors;
    - large contiguous rolls along the first or last dim;
    - any other single dim;
    - the generic flat / multi-dim path.

    Every other tensor uses the ``_candidate_fallback`` implementation.
    """
    logger.debug("GEMS ROLL")

    validate_inputs(inp, shifts, dims)
    if not _can_use_triton(inp):
        return _candidate_fallback(inp, shifts, dims)

    if _can_use_flat_single_dim_triton(inp, dims):
        # Rolling a 1-D tensor along its only dim is exactly a flat roll.
        return _candidate_triton(inp, shifts, None)
    if _can_use_first_dim_triton(inp, dims):
        return _candidate_triton_first_dim(inp, shifts)
    if _can_use_last_dim_triton(inp, dims):
        return _candidate_triton_last_dim(inp, shifts)

    dims_given = dims is not None and not _is_empty_sequence(dims)
    dim_values = _as_tuple(dims) if dims_given else ()
    if len(dim_values) == 1:
        leading_shift = _as_tuple(shifts)[0]
        target_dim = _canonicalize_dim(
            dim_values[0],
            inp.dim(),
            allow_empty_wrap=inp.numel() == 0,
        )
        return _candidate_triton_single_dim(inp, leading_shift, target_dim)
    return _candidate_triton(inp, shifts, dims)
def _candidate_triton(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> torch.Tensor:
    """Roll ``inp`` using the Triton kernels.

    With no ``dims`` (``None`` or an empty sequence), ``inp`` is rolled as a
    flat buffer by the first shift and reshaped back to ``inp.shape``.
    Otherwise the work goes to ``_candidate_triton_multi_dim``.
    """
    shift_values = _as_tuple(shifts)
    has_dims = not (dims is None or _is_empty_sequence(dims))
    if has_dims:
        return _candidate_triton_multi_dim(inp, shift_values, _as_tuple(dims))

    flat_src = inp.reshape(-1).contiguous()
    flat_dst = torch.empty_like(flat_src)
    # max(..., 1) keeps the modulo defined for empty tensors.
    normalized_shift = shift_values[0] % max(flat_src.numel(), 1)
    _launch_roll_flat_kernel(
        flat_src,
        flat_dst,
        normalized_shift,
        block=_select_flat_block(flat_src),
    )
    return flat_dst.reshape(inp.shape)
def _candidate_triton_last_dim(inp: torch.Tensor, shifts: IntOrInts) -> torch.Tensor:
    """Roll ``inp`` along its last dimension by the first entry of ``shifts``.

    The kernel indexes ``inp`` as dense rows of ``inp.shape[-1]`` elements,
    so the callers in this module only route contiguous tensors here.
    A zero effective shift returns a contiguous copy.
    """
    width = inp.shape[-1]
    shift = _as_tuple(shifts)[0] % width
    if not shift:
        return inp.contiguous().clone()

    result = torch.empty_like(inp)
    _launch_roll_last_dim_kernel(inp, result, shift)
    return result
def _candidate_triton_first_dim(inp: torch.Tensor, shifts: IntOrInts) -> torch.Tensor:
    """Roll a contiguous ``inp`` along dim 0 by the first entry of ``shifts``.

    For contiguous memory, moving whole rows by ``k`` is the same as rolling
    the flat buffer by ``k * stride(0)``. The flat kernel is therefore
    reused on 1-D views of input and output. A zero shift returns a
    contiguous copy.
    """
    row_shift = _as_tuple(shifts)[0] % inp.shape[0]
    flat_shift = row_shift * inp.stride(0)
    if flat_shift == 0:
        return inp.contiguous().clone()

    result = torch.empty_like(inp)
    _launch_roll_flat_kernel(
        inp.reshape(-1), result.reshape(-1), flat_shift, block=1024
    )
    return result
83def _select_flat_block(inp: torch.Tensor) -> int:
84 if inp.numel() <= 2048:
85 return 128
86 if inp.dtype is torch.float32 and inp.numel() >= (1 << 20):
87 return 1024
88 return 512
91def _candidate_triton_single_dim(
92 inp: torch.Tensor, shift: int, dim: int
93) -> torch.Tensor:
94 size = inp.size(dim)
95 if size == 0:
96 return inp.clone()
98 shift %= size
99 if shift == 0:
100 return inp.clone()
102 inp_contig = inp.contiguous()
103 out = torch.empty_like(inp_contig)
104 dim_stride = inp_contig.stride(dim)
105 _launch_roll_single_dim_kernel(inp_contig, out, size, shift, dim_stride)
106 return out
109def _candidate_triton_multi_dim(
110 inp: torch.Tensor, shifts: Sequence[int], dims: Sequence[int]
111) -> torch.Tensor:
112 if inp.numel() == 0:
113 return inp.clone()
115 effective_shifts = _normalize_roll_dims(inp.shape, shifts, dims)
116 active_dims = [
117 (dim, shift)
118 for dim, (size, shift) in enumerate(zip(inp.shape, effective_shifts))
119 if size and shift
120 ]
121 if not active_dims:
122 return inp.contiguous().clone()
124 if len(active_dims) == 1:
125 dim, shift = active_dims[0]
126 if inp.is_contiguous() and _can_use_first_dim_triton(inp, dim):
127 return _candidate_triton_first_dim(inp, shift)
128 if inp.is_contiguous() and _can_use_last_dim_triton(inp, dim):
129 return _candidate_triton_last_dim(inp, shift)
130 return _candidate_triton_single_dim(inp, shift, dim)
132 inp_contig = inp.contiguous()
133 out = torch.empty_like(inp_contig)
134 sizes = [inp_contig.size(dim) for dim, _ in active_dims]
135 strides = [inp_contig.stride(dim) for dim, _ in active_dims]
136 active_shifts = [shift for _, shift in active_dims]
137 _launch_roll_multi_dim_kernel(inp_contig, out, sizes, strides, active_shifts)
138 return out
def _candidate_fallback(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> torch.Tensor:
    """Pure-PyTorch roll built from ``narrow`` + ``cat`` (see ``_roll_along_dim``).

    With no ``dims``, the flattened tensor is rolled by the first shift and
    reshaped back. Otherwise each (shift, dim) pair is applied in turn.
    """
    shift_values = _as_tuple(shifts)

    if not (dims is None or _is_empty_sequence(dims)):
        # Empty tensors wrap out-of-range dims instead of raising.
        wrap_empty = inp.numel() == 0
        rank = inp.dim()
        rolled = inp
        for amount, axis in zip(shift_values, _as_tuple(dims)):
            canonical_axis = _canonicalize_dim(axis, rank, allow_empty_wrap=wrap_empty)
            rolled = _roll_along_dim(rolled, amount, canonical_axis)
        return rolled

    flat = inp.reshape(-1)
    return _roll_along_dim(flat, shift_values[0], 0).reshape(inp.shape)
def validate_inputs(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> None:
    """Check the arguments of ``roll`` before any work is done.

    Raises:
        TypeError: ``inp`` is not a tensor, or ``shifts`` / ``dims`` is neither
            an int nor a sequence of ints.
        RuntimeError: ``shifts`` is empty, or the number of shifts does not
            match the number of dims. With no dims, exactly one shift is
            allowed.
    """
    if not isinstance(inp, torch.Tensor):
        raise TypeError("roll(): argument 'input' must be Tensor")
    if not _is_int_or_int_sequence(shifts):
        raise TypeError("roll(): argument 'shifts' must be int or tuple of ints")

    shift_count = len(shifts) if not isinstance(shifts, int) else 1
    if not shift_count:
        raise RuntimeError("`shifts` required")

    if dims is None or _is_empty_sequence(dims):
        # Flattened roll: only a single shift is meaningful.
        if shift_count <= 1:
            return
        raise RuntimeError(
            f"shifts and dimensions must align. shifts: {shift_count}, dims:0"
        )

    if not _is_int_or_int_sequence(dims):
        raise TypeError("roll(): argument 'dims' must be int or tuple of ints")
    dim_count = len(dims) if not isinstance(dims, int) else 1
    if dim_count != shift_count:
        raise RuntimeError("shifts and dimensions must align")
185def _roll_along_dim(inp: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
186 size = inp.size(dim)
187 if size == 0:
188 return inp.clone(memory_format=torch.preserve_format)
190 shift %= size
191 if shift == 0:
192 return inp.clone(memory_format=torch.preserve_format)
194 split = size - shift
195 return torch.cat(
196 (inp.narrow(dim, split, shift), inp.narrow(dim, 0, split)), dim=dim
197 )
200def _canonicalize_dim(dim: int, ndim: int, allow_empty_wrap: bool = False) -> int:
201 if ndim == 0:
202 raise IndexError(f"Dimension specified as {dim} but tensor has no dimensions")
203 if allow_empty_wrap:
204 return dim % ndim
205 if dim < -ndim or dim >= ndim:
206 raise IndexError(
207 f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
208 )
209 return dim % ndim
def _as_tuple(value: IntOrInts) -> tuple[int, ...]:
    """Turn a bare int into a 1-tuple; convert any other iterable with ``tuple``."""
    return (value,) if isinstance(value, int) else tuple(value)
218def _can_use_triton(inp: torch.Tensor) -> bool:
219 return inp.is_cuda and inp.dim() <= MAX_DIMS and not inp.dtype.is_complex
222def _can_use_first_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
223 if not _can_use_triton(inp) or not inp.is_contiguous() or inp.dim() <= 1:
224 return False
226 if isinstance(dims, int):
227 dim = dims
228 elif isinstance(dims, Sequence) and not isinstance(dims, int) and len(dims) == 1:
229 dim = dims[0]
230 else:
231 return False
233 return dim in {0, -inp.dim()} and inp.numel() >= (1 << 20)
236def _can_use_flat_single_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
237 if not _can_use_triton(inp) or inp.dim() != 1 or inp.dtype is not torch.float32:
238 return False
240 if isinstance(dims, int):
241 dim = dims
242 elif isinstance(dims, Sequence) and not isinstance(dims, int) and len(dims) == 1:
243 dim = dims[0]
244 else:
245 return False
247 return dim in {0, -1}
250def _can_use_last_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
251 if not _can_use_triton(inp) or not inp.is_contiguous() or inp.dim() == 0:
252 return False
254 if isinstance(dims, int):
255 dim = dims
256 elif isinstance(dims, Sequence) and not isinstance(dims, int) and len(dims) == 1:
257 dim = dims[0]
258 else:
259 return False
261 return dim in {-1, inp.dim() - 1} and inp.numel() >= (1 << 20)
def _normalize_roll_dims(
    shape: Sequence[int], shifts: Sequence[int], dims: Sequence[int]
) -> list[int]:
    """Collapse (shift, dim) pairs into one effective shift per dimension.

    Shifts aimed at the same canonical dim are summed, and each total is
    reduced modulo its dimension size. Zero-size dims keep their raw total.
    Dims never mentioned get 0.

    Raises:
        IndexError: from ``_canonicalize_dim`` for an out-of-range dim.
    """
    rank = len(shape)
    totals = [0] * rank
    for amount, axis in zip(shifts, dims):
        totals[_canonicalize_dim(axis, rank)] += amount
    return [total % size if size else total for total, size in zip(totals, shape)]
278def _pad_left(values: Sequence[int], total: int, fill_value: int) -> list[int]:
279 padded = [fill_value] * (total - len(values))
280 padded.extend(int(value) for value in values)
281 return padded
284def _launch_roll_flat_kernel(
285 inp: torch.Tensor, out: torch.Tensor, shift: int, block: int = 256
286) -> None:
287 if out.numel() == 0:
288 return
290 numel = out.numel()
291 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),)
292 _roll_flat_kernel[grid](inp, out, numel, shift, BLOCK=block)
295def _launch_roll_last_dim_kernel(
296 inp: torch.Tensor, out: torch.Tensor, shift: int
297) -> None:
298 if out.numel() == 0:
299 return
301 numel = out.numel()
302 width = inp.shape[-1]
303 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),)
304 _roll_last_dim_kernel[grid](inp, out, numel, width, shift, BLOCK=1024)
307def _launch_roll_single_dim_kernel(
308 inp: torch.Tensor,
309 out: torch.Tensor,
310 dim_size: int,
311 shift: int,
312 dim_stride: int,
313) -> None:
314 if out.numel() == 0:
315 return
317 numel = out.numel()
318 block = 1024
319 if inp.dtype is torch.float32 and numel <= (1 << 18):
320 block = 512
321 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),)
322 _roll_single_dim_kernel[grid](
323 inp,
324 out,
325 numel,
326 dim_size,
327 shift,
328 dim_stride,
329 BLOCK=block,
330 )
333def _launch_roll_multi_dim_kernel(
334 inp: torch.Tensor,
335 out: torch.Tensor,
336 sizes: Sequence[int],
337 strides: Sequence[int],
338 shifts: Sequence[int],
339) -> None:
340 if out.numel() == 0:
341 return
343 numel = out.numel()
344 size_values = _pad_right(sizes, MAX_DIMS, 1)
345 stride_values = _pad_right(strides, MAX_DIMS, 0)
346 shift_values = _pad_right(shifts, MAX_DIMS, 0)
347 block = 1024
348 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),)
349 _roll_multi_dim_kernel[grid](
350 inp,
351 out,
352 numel,
353 size_values[0],
354 stride_values[0],
355 shift_values[0],
356 size_values[1],
357 stride_values[1],
358 shift_values[1],
359 size_values[2],
360 stride_values[2],
361 shift_values[2],
362 size_values[3],
363 stride_values[3],
364 shift_values[3],
365 size_values[4],
366 stride_values[4],
367 shift_values[4],
368 DIMS=len(sizes),
369 BLOCK=block,
370 )
373def _is_int_or_int_sequence(value: object) -> bool:
374 if isinstance(value, int):
375 return True
376 if not isinstance(value, Sequence):
377 return False
378 return all(isinstance(item, int) for item in value)
381def _is_empty_sequence(value: object) -> bool:
382 return (
383 isinstance(value, Sequence) and not isinstance(value, int) and len(value) == 0
384 )
387def _pad_right(values: Sequence[int], total: int, fill_value: int) -> list[int]:
388 padded = [int(value) for value in values]
389 padded.extend([fill_value] * (total - len(padded)))
390 return padded
@libentry()
@triton.jit
def _roll_flat_kernel(inp_ptr, out_ptr, numel, shift, BLOCK: tl.constexpr):
    # Flat roll of a dense buffer, in gather form:
    #   out[i] = inp[(i - shift) mod numel].
    # Each program handles one BLOCK-sized chunk of output offsets.
    # Assumes 0 <= shift < numel; the callers in this module reduce the shift
    # before launching. NOTE(review): confirm this for any new caller.
    # NOTE(review): offsets come from program_id * BLOCK, presumably int32;
    # confirm the behaviour for tensors with >= 2**31 elements.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < numel
    split = numel - shift
    src_offsets = offsets + split
    # The first `shift` outputs wrap around to the tail of the input; the
    # rest read `shift` elements back.
    src_offsets = tl.where(offsets < shift, src_offsets, offsets - shift)
    values = tl.load(inp_ptr + src_offsets, mask=mask, other=0)
    tl.store(out_ptr + offsets, values, mask=mask)
@libentry()
@triton.jit
def _roll_single_dim_kernel(
    inp_ptr,
    out_ptr,
    numel,
    dim_size,
    shift,
    dim_stride,
    BLOCK: tl.constexpr,
):
    # Roll along one dimension, in scatter form. Each program reads a
    # BLOCK-sized chunk of input elements and writes each one `shift` steps
    # further along the rolled dimension, wrapping modulo dim_size.
    # The index math treats inp/out as dense buffers in which the rolled dim
    # has stride `dim_stride`. The visible caller passes a `.contiguous()`
    # input, an `empty_like` of it, and a shift already reduced to
    # (0, dim_size).
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < numel

    # Coordinate of each element along the rolled dimension.
    dim_index = (offsets // dim_stride) % dim_size
    target_dim_index = (dim_index + shift) % dim_size
    # Move only along the rolled dimension; all other coordinates are unchanged.
    target_offsets = offsets + (target_dim_index - dim_index) * dim_stride

    values = tl.load(inp_ptr + offsets, mask=mask, other=0)
    tl.store(out_ptr + target_offsets, values, mask=mask)
@libentry()
@triton.jit
def _roll_last_dim_kernel(inp_ptr, out_ptr, numel, width, shift, BLOCK: tl.constexpr):
    # Roll every row of a dense buffer with `width` columns, in gather form:
    #   out[row, c] = inp[row, (c - shift) mod width].
    # Assumes 0 < shift < width and a contiguous layout; the visible caller
    # reduces the shift and only routes contiguous tensors here.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < numel
    column = offsets % width
    row_start = offsets - column
    # `+ width` keeps the operand of `%` non-negative.
    source_column = (column + width - shift) % width
    values = tl.load(inp_ptr + row_start + source_column, mask=mask, other=0)
    tl.store(out_ptr + offsets, values, mask=mask)
@libentry()
@triton.jit
def _roll_multi_dim_kernel(
    inp_ptr,
    out_ptr,
    numel,
    size0,
    stride0,
    shift0,
    size1,
    stride1,
    shift1,
    size2,
    stride2,
    shift2,
    size3,
    stride3,
    shift3,
    size4,
    stride4,
    shift4,
    DIMS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Roll along up to five dimensions at once, in gather form. For each
    # output offset, every active dim k contributes the correction
    #   ((idx_k - shift_k) mod size_k - idx_k) * stride_k
    # to the source offset. Dims are independent, so the corrections add up.
    # Only the first DIMS (size, stride, shift) triples are read. DIMS is a
    # tl.constexpr, so the `if DIMS >= k` checks are presumably resolved at
    # compile time. The launcher pads unused triples with stride 0, which
    # must never be used as a divisor. Assumes a dense, contiguous buffer
    # and shifts already reduced to [0, size_k).
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < numel
    source_offsets = offsets

    # In each branch, `+ size_k` keeps the operand of `%` non-negative.
    if DIMS >= 1:
        dim_index0 = (offsets // stride0) % size0
        source_dim_index0 = (dim_index0 + size0 - shift0) % size0
        source_offsets += (source_dim_index0 - dim_index0) * stride0
    if DIMS >= 2:
        dim_index1 = (offsets // stride1) % size1
        source_dim_index1 = (dim_index1 + size1 - shift1) % size1
        source_offsets += (source_dim_index1 - dim_index1) * stride1
    if DIMS >= 3:
        dim_index2 = (offsets // stride2) % size2
        source_dim_index2 = (dim_index2 + size2 - shift2) % size2
        source_offsets += (source_dim_index2 - dim_index2) * stride2
    if DIMS >= 4:
        dim_index3 = (offsets // stride3) % size3
        source_dim_index3 = (dim_index3 + size3 - shift3) % size3
        source_offsets += (source_dim_index3 - dim_index3) * stride3
    if DIMS >= 5:
        dim_index4 = (offsets // stride4) % size4
        source_dim_index4 = (dim_index4 + size4 - shift4) % size4
        source_offsets += (source_dim_index4 - dim_index4) * stride4

    values = tl.load(inp_ptr + source_offsets, mask=mask, other=0)
    tl.store(out_ptr + offsets, values, mask=mask)