Coverage for src/flag_gems/ops/roll.py: 66%

279 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from collections.abc import Sequence 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12IntOrInts = int | Sequence[int] 

13MAX_DIMS = 5 

14 

15 

def roll(inp: torch.Tensor, shifts, dims=None) -> torch.Tensor:
    """Roll ``inp`` by ``shifts`` along ``dims`` and return a new tensor.

    The arguments are checked by ``validate_inputs`` first. Tensors rejected
    by ``_can_use_triton`` are handled by ``_candidate_fallback``; all others
    are routed to the first Triton path whose predicate accepts them.

    Args:
        inp: Tensor to roll.
        shifts: An int or a sequence of ints.
        dims: An int, a sequence of ints, or ``None``/empty for a roll over
            the flattened tensor.

    Returns:
        The rolled tensor.
    """
    logger.debug("GEMS ROLL")
    validate_inputs(inp, shifts, dims)

    if not _can_use_triton(inp):
        return _candidate_fallback(inp, shifts, dims)

    # Specialised fast paths, tried in this order.
    if _can_use_flat_single_dim_triton(inp, dims):
        return _candidate_triton(inp, shifts, None)
    if _can_use_first_dim_triton(inp, dims):
        return _candidate_triton_first_dim(inp, shifts)
    if _can_use_last_dim_triton(inp, dims):
        return _candidate_triton_last_dim(inp, shifts)

    has_dims = dims is not None and not _is_empty_sequence(dims)
    if has_dims:
        dim_values = _as_tuple(dims)
        if len(dim_values) == 1:
            shift = _as_tuple(shifts)[0]
            # For empty tensors the dim is wrapped instead of range-checked.
            dim = _canonicalize_dim(
                dim_values[0],
                inp.dim(),
                allow_empty_wrap=inp.numel() == 0,
            )
            return _candidate_triton_single_dim(inp, shift, dim)

    # Flattened roll (no dims) or a roll over several dims.
    return _candidate_triton(inp, shifts, dims)

41 

42 

def _candidate_triton(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> torch.Tensor:
    """Roll ``inp`` with the Triton kernels.

    With ``dims`` given, the work is delegated to
    ``_candidate_triton_multi_dim``. Without ``dims`` (``None`` or an empty
    sequence) the tensor is flattened, rolled by the first shift with the flat
    kernel, and reshaped back to ``inp.shape``.
    """
    shift_values = _as_tuple(shifts)

    if dims is not None and not _is_empty_sequence(dims):
        return _candidate_triton_multi_dim(inp, shift_values, _as_tuple(dims))

    flat_in = inp.reshape(-1).contiguous()
    flat_out = torch.empty_like(flat_in)
    launch_block = _select_flat_block(flat_in)
    # max(..., 1) keeps the modulo defined for an empty tensor.
    flat_shift = shift_values[0] % max(flat_in.numel(), 1)
    _launch_roll_flat_kernel(flat_in, flat_out, flat_shift, block=launch_block)
    return flat_out.reshape(inp.shape)

61 

62 

def _candidate_triton_last_dim(inp: torch.Tensor, shifts: IntOrInts) -> torch.Tensor:
    """Roll ``inp`` along its last dimension with the last-dim kernel.

    Only the first entry of ``shifts`` is used, reduced modulo the size of the
    last dimension. A zero effective shift returns a contiguous copy.
    """
    width = inp.shape[-1]
    shift = _as_tuple(shifts)[0] % width
    if shift != 0:
        result = torch.empty_like(inp)
        _launch_roll_last_dim_kernel(inp, result, shift)
        return result
    return inp.contiguous().clone()

71 

72 

def _candidate_triton_first_dim(inp: torch.Tensor, shifts: IntOrInts) -> torch.Tensor:
    """Roll ``inp`` along dimension 0 by running the flat kernel.

    The first shift is reduced modulo ``inp.shape[0]`` and multiplied by
    ``inp.stride(0)`` so that it becomes a shift of the flattened buffer.
    Both callers visible in this file check ``inp.is_contiguous()`` before
    taking this path.
    """
    rows = _as_tuple(shifts)[0] % inp.shape[0]
    flat_shift = rows * inp.stride(0)
    if flat_shift == 0:
        return inp.contiguous().clone()

    result = torch.empty_like(inp)
    _launch_roll_flat_kernel(
        inp.reshape(-1), result.reshape(-1), flat_shift, block=1024
    )
    return result

81 

82 

83def _select_flat_block(inp: torch.Tensor) -> int: 

84 if inp.numel() <= 2048: 

85 return 128 

86 if inp.dtype is torch.float32 and inp.numel() >= (1 << 20): 

87 return 1024 

88 return 512 

89 

90 

91def _candidate_triton_single_dim( 

92 inp: torch.Tensor, shift: int, dim: int 

93) -> torch.Tensor: 

94 size = inp.size(dim) 

95 if size == 0: 

96 return inp.clone() 

97 

98 shift %= size 

99 if shift == 0: 

100 return inp.clone() 

101 

102 inp_contig = inp.contiguous() 

103 out = torch.empty_like(inp_contig) 

104 dim_stride = inp_contig.stride(dim) 

105 _launch_roll_single_dim_kernel(inp_contig, out, size, shift, dim_stride) 

106 return out 

107 

108 

109def _candidate_triton_multi_dim( 

110 inp: torch.Tensor, shifts: Sequence[int], dims: Sequence[int] 

111) -> torch.Tensor: 

112 if inp.numel() == 0: 

113 return inp.clone() 

114 

115 effective_shifts = _normalize_roll_dims(inp.shape, shifts, dims) 

116 active_dims = [ 

117 (dim, shift) 

118 for dim, (size, shift) in enumerate(zip(inp.shape, effective_shifts)) 

119 if size and shift 

120 ] 

121 if not active_dims: 

122 return inp.contiguous().clone() 

123 

124 if len(active_dims) == 1: 

125 dim, shift = active_dims[0] 

126 if inp.is_contiguous() and _can_use_first_dim_triton(inp, dim): 

127 return _candidate_triton_first_dim(inp, shift) 

128 if inp.is_contiguous() and _can_use_last_dim_triton(inp, dim): 

129 return _candidate_triton_last_dim(inp, shift) 

130 return _candidate_triton_single_dim(inp, shift, dim) 

131 

132 inp_contig = inp.contiguous() 

133 out = torch.empty_like(inp_contig) 

134 sizes = [inp_contig.size(dim) for dim, _ in active_dims] 

135 strides = [inp_contig.stride(dim) for dim, _ in active_dims] 

136 active_shifts = [shift for _, shift in active_dims] 

137 _launch_roll_multi_dim_kernel(inp_contig, out, sizes, strides, active_shifts) 

138 return out 

139 

140 

def _candidate_fallback(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> torch.Tensor:
    """Roll ``inp`` with plain torch operations (no Triton kernel).

    Without ``dims`` the tensor is flattened, rolled by the first shift, and
    reshaped back. With ``dims`` the shifts are applied one dimension after
    another through ``_roll_along_dim``.
    """
    shift_values = _as_tuple(shifts)

    if dims is not None and not _is_empty_sequence(dims):
        ndim = inp.dim()
        # Empty tensors get their dims wrapped instead of range-checked.
        wrap_dims = inp.numel() == 0
        rolled = inp
        for amount, axis in zip(shift_values, _as_tuple(dims)):
            target = _canonicalize_dim(axis, ndim, allow_empty_wrap=wrap_dims)
            rolled = _roll_along_dim(rolled, amount, target)
        return rolled

    flat = inp.reshape(-1)
    return _roll_along_dim(flat, shift_values[0], 0).reshape(inp.shape)

158 

159 

def validate_inputs(
    inp: torch.Tensor, shifts: IntOrInts, dims: IntOrInts | None = None
) -> None:
    """Check the argument types and the shifts/dims pairing for ``roll``.

    Args:
        inp: Must be a ``torch.Tensor``.
        shifts: An int or a non-empty sequence of ints.
        dims: ``None``, an empty sequence, an int, or a sequence of ints with
            as many entries as ``shifts``.

    Raises:
        TypeError: If ``inp`` is not a tensor, or ``shifts``/``dims`` is
            neither an int nor a sequence of ints.
        RuntimeError: If ``shifts`` is empty or the number of shifts does not
            match the number of dims.
    """
    if not isinstance(inp, torch.Tensor):
        raise TypeError("roll(): argument 'input' must be Tensor")
    if not _is_int_or_int_sequence(shifts):
        raise TypeError("roll(): argument 'shifts' must be int or tuple of ints")
    shift_count = 1 if isinstance(shifts, int) else len(shifts)
    if shift_count == 0:
        raise RuntimeError("`shifts` required")

    if dims is None or _is_empty_sequence(dims):
        # Without dims only a single (flattened) shift makes sense.
        if shift_count > 1:
            raise RuntimeError(
                f"shifts and dimensions must align. shifts: {shift_count}, dims:0"
            )
        return

    if not _is_int_or_int_sequence(dims):
        raise TypeError("roll(): argument 'dims' must be int or tuple of ints")
    dim_count = 1 if isinstance(dims, int) else len(dims)
    if shift_count != dim_count:
        # Report both counts, in the same format as the dims-less branch above.
        raise RuntimeError(
            "shifts and dimensions must align. "
            f"shifts: {shift_count}, dims:{dim_count}"
        )

183 

184 

185def _roll_along_dim(inp: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 

186 size = inp.size(dim) 

187 if size == 0: 

188 return inp.clone(memory_format=torch.preserve_format) 

189 

190 shift %= size 

191 if shift == 0: 

192 return inp.clone(memory_format=torch.preserve_format) 

193 

194 split = size - shift 

195 return torch.cat( 

196 (inp.narrow(dim, split, shift), inp.narrow(dim, 0, split)), dim=dim 

197 ) 

198 

199 

200def _canonicalize_dim(dim: int, ndim: int, allow_empty_wrap: bool = False) -> int: 

201 if ndim == 0: 

202 raise IndexError(f"Dimension specified as {dim} but tensor has no dimensions") 

203 if allow_empty_wrap: 

204 return dim % ndim 

205 if dim < -ndim or dim >= ndim: 

206 raise IndexError( 

207 f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})" 

208 ) 

209 return dim % ndim 

210 

211 

212def _as_tuple(value: IntOrInts) -> tuple[int, ...]: 

213 if isinstance(value, int): 

214 return (value,) 

215 return tuple(value) 

216 

217 

218def _can_use_triton(inp: torch.Tensor) -> bool: 

219 return inp.is_cuda and inp.dim() <= MAX_DIMS and not inp.dtype.is_complex 

220 

221 

def _can_use_first_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
    """Tell whether the first-dimension fast path applies.

    Requires a Triton-eligible, contiguous tensor with more than one
    dimension and at least 2**20 elements, and ``dims`` naming exactly one
    dimension that is the first one (``0`` or ``-inp.dim()``).
    """
    eligible = _can_use_triton(inp) and inp.is_contiguous() and inp.dim() > 1
    if not eligible:
        return False

    if isinstance(dims, int):
        dim = dims
    elif isinstance(dims, Sequence) and len(dims) == 1:
        dim = dims[0]
    else:
        return False

    return dim in {0, -inp.dim()} and inp.numel() >= (1 << 20)

234 

235 

def _can_use_flat_single_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
    """Tell whether a 1-D float32 tensor can be rolled by the flat kernel.

    Requires a Triton-eligible, one-dimensional float32 tensor and ``dims``
    naming exactly one dimension, given as ``0`` or ``-1``.
    """
    suitable = (
        _can_use_triton(inp) and inp.dim() == 1 and inp.dtype is torch.float32
    )
    if not suitable:
        return False

    if isinstance(dims, int):
        dim = dims
    elif isinstance(dims, Sequence) and len(dims) == 1:
        dim = dims[0]
    else:
        return False

    return dim in {0, -1}

248 

249 

def _can_use_last_dim_triton(inp: torch.Tensor, dims: IntOrInts | None) -> bool:
    """Tell whether the last-dimension fast path applies.

    Requires a Triton-eligible, contiguous tensor with at least one dimension
    and at least 2**20 elements, and ``dims`` naming exactly one dimension
    that is the last one (``-1`` or ``inp.dim() - 1``).
    """
    eligible = _can_use_triton(inp) and inp.is_contiguous() and inp.dim() != 0
    if not eligible:
        return False

    if isinstance(dims, int):
        dim = dims
    elif isinstance(dims, Sequence) and len(dims) == 1:
        dim = dims[0]
    else:
        return False

    return dim in {-1, inp.dim() - 1} and inp.numel() >= (1 << 20)

262 

263 

def _normalize_roll_dims(
    shape: Sequence[int], shifts: Sequence[int], dims: Sequence[int]
) -> list[int]:
    """Return one effective shift per dimension of ``shape``.

    Shifts aimed at the same (canonicalised) dimension are summed. Each total
    is then reduced modulo the dimension size; totals for zero-sized
    dimensions are left as they are.

    Raises:
        IndexError: Propagated from ``_canonicalize_dim`` for an invalid dim.
    """
    rank = len(shape)
    totals = [0] * rank
    for amount, axis in zip(shifts, dims):
        totals[_canonicalize_dim(axis, rank)] += amount
    return [
        total % extent if extent else total
        for extent, total in zip(shape, totals)
    ]

276 

277 

278def _pad_left(values: Sequence[int], total: int, fill_value: int) -> list[int]: 

279 padded = [fill_value] * (total - len(values)) 

280 padded.extend(int(value) for value in values) 

281 return padded 

282 

283 

284def _launch_roll_flat_kernel( 

285 inp: torch.Tensor, out: torch.Tensor, shift: int, block: int = 256 

286) -> None: 

287 if out.numel() == 0: 

288 return 

289 

290 numel = out.numel() 

291 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),) 

292 _roll_flat_kernel[grid](inp, out, numel, shift, BLOCK=block) 

293 

294 

295def _launch_roll_last_dim_kernel( 

296 inp: torch.Tensor, out: torch.Tensor, shift: int 

297) -> None: 

298 if out.numel() == 0: 

299 return 

300 

301 numel = out.numel() 

302 width = inp.shape[-1] 

303 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),) 

304 _roll_last_dim_kernel[grid](inp, out, numel, width, shift, BLOCK=1024) 

305 

306 

307def _launch_roll_single_dim_kernel( 

308 inp: torch.Tensor, 

309 out: torch.Tensor, 

310 dim_size: int, 

311 shift: int, 

312 dim_stride: int, 

313) -> None: 

314 if out.numel() == 0: 

315 return 

316 

317 numel = out.numel() 

318 block = 1024 

319 if inp.dtype is torch.float32 and numel <= (1 << 18): 

320 block = 512 

321 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),) 

322 _roll_single_dim_kernel[grid]( 

323 inp, 

324 out, 

325 numel, 

326 dim_size, 

327 shift, 

328 dim_stride, 

329 BLOCK=block, 

330 ) 

331 

332 

333def _launch_roll_multi_dim_kernel( 

334 inp: torch.Tensor, 

335 out: torch.Tensor, 

336 sizes: Sequence[int], 

337 strides: Sequence[int], 

338 shifts: Sequence[int], 

339) -> None: 

340 if out.numel() == 0: 

341 return 

342 

343 numel = out.numel() 

344 size_values = _pad_right(sizes, MAX_DIMS, 1) 

345 stride_values = _pad_right(strides, MAX_DIMS, 0) 

346 shift_values = _pad_right(shifts, MAX_DIMS, 0) 

347 block = 1024 

348 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"]),) 

349 _roll_multi_dim_kernel[grid]( 

350 inp, 

351 out, 

352 numel, 

353 size_values[0], 

354 stride_values[0], 

355 shift_values[0], 

356 size_values[1], 

357 stride_values[1], 

358 shift_values[1], 

359 size_values[2], 

360 stride_values[2], 

361 shift_values[2], 

362 size_values[3], 

363 stride_values[3], 

364 shift_values[3], 

365 size_values[4], 

366 stride_values[4], 

367 shift_values[4], 

368 DIMS=len(sizes), 

369 BLOCK=block, 

370 ) 

371 

372 

373def _is_int_or_int_sequence(value: object) -> bool: 

374 if isinstance(value, int): 

375 return True 

376 if not isinstance(value, Sequence): 

377 return False 

378 return all(isinstance(item, int) for item in value) 

379 

380 

381def _is_empty_sequence(value: object) -> bool: 

382 return ( 

383 isinstance(value, Sequence) and not isinstance(value, int) and len(value) == 0 

384 ) 

385 

386 

387def _pad_right(values: Sequence[int], total: int, fill_value: int) -> list[int]: 

388 padded = [int(value) for value in values] 

389 padded.extend([fill_value] * (total - len(padded))) 

390 return padded 

391 

392 

@libentry()
@triton.jit
def _roll_flat_kernel(inp_ptr, out_ptr, numel, shift, BLOCK: tl.constexpr):
    """Roll a flat buffer of ``numel`` elements by ``shift`` positions.

    Each program handles ``BLOCK`` consecutive output offsets. Output offsets
    below ``shift`` read from ``offset + numel - shift``; all others read from
    ``offset - shift``.
    """
    # NOTE(review): offsets come from program_id/arange with no explicit
    # int64 cast - confirm behaviour for buffers of 2**31 elements or more.
    dst = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = dst < numel
    wrapped = dst < shift
    src = tl.where(wrapped, dst + (numel - shift), dst - shift)
    data = tl.load(inp_ptr + src, mask=in_bounds, other=0)
    tl.store(out_ptr + dst, data, mask=in_bounds)

403 

404 

@libentry()
@triton.jit
def _roll_single_dim_kernel(
    inp_ptr,
    out_ptr,
    numel,
    dim_size,
    shift,
    dim_stride,
    BLOCK: tl.constexpr,
):
    """Roll along one dimension by scattering inputs to their new offsets.

    Each program reads ``BLOCK`` consecutive input offsets. The coordinate
    along the rolled dimension is recovered from ``dim_stride`` and
    ``dim_size``, advanced by ``shift`` with wrap-around, and the value is
    stored at the matching output offset.
    """
    # NOTE(review): offsets come from program_id/arange with no explicit
    # int64 cast - confirm behaviour for buffers of 2**31 elements or more.
    src = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = src < numel

    coord = (src // dim_stride) % dim_size
    rolled_coord = (coord + shift) % dim_size
    dst = src + (rolled_coord - coord) * dim_stride

    data = tl.load(inp_ptr + src, mask=in_bounds, other=0)
    tl.store(out_ptr + dst, data, mask=in_bounds)

425 

426 

@libentry()
@triton.jit
def _roll_last_dim_kernel(inp_ptr, out_ptr, numel, width, shift, BLOCK: tl.constexpr):
    """Roll every row of ``width`` elements by ``shift`` positions.

    Each program handles ``BLOCK`` consecutive output offsets. An offset is
    split into its row start and its column; the source column is
    ``(column + width - shift) % width`` within the same row.
    """
    # NOTE(review): offsets come from program_id/arange with no explicit
    # int64 cast - confirm behaviour for buffers of 2**31 elements or more.
    dst = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = dst < numel
    col = dst % width
    row_base = dst - col
    src_col = (col + width - shift) % width
    data = tl.load(inp_ptr + row_base + src_col, mask=in_bounds, other=0)
    tl.store(out_ptr + dst, data, mask=in_bounds)

437 

438 

@libentry()
@triton.jit
def _roll_multi_dim_kernel(
    inp_ptr,
    out_ptr,
    numel,
    size0,
    stride0,
    shift0,
    size1,
    stride1,
    shift1,
    size2,
    stride2,
    shift2,
    size3,
    stride3,
    shift3,
    size4,
    stride4,
    shift4,
    DIMS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Roll along up to five dimensions by gathering from source offsets.

    Each program handles ``BLOCK`` consecutive output offsets. For each of
    the first ``DIMS`` (size, stride, shift) triples, the coordinate along
    that dimension is recovered from the output offset, moved back by the
    shift with wrap-around, and the difference times the stride is added to
    the source offset. Triples beyond ``DIMS`` are ignored.
    """
    # NOTE(review): offsets come from program_id/arange with no explicit
    # int64 cast - confirm behaviour for buffers of 2**31 elements or more.
    dst = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = dst < numel
    src = dst

    # DIMS is a constexpr, so each branch is resolved at compile time.
    if DIMS >= 1:
        coord0 = (dst // stride0) % size0
        from0 = (coord0 + size0 - shift0) % size0
        src = src + (from0 - coord0) * stride0
    if DIMS >= 2:
        coord1 = (dst // stride1) % size1
        from1 = (coord1 + size1 - shift1) % size1
        src = src + (from1 - coord1) * stride1
    if DIMS >= 3:
        coord2 = (dst // stride2) % size2
        from2 = (coord2 + size2 - shift2) % size2
        src = src + (from2 - coord2) * stride2
    if DIMS >= 4:
        coord3 = (dst // stride3) % size3
        from3 = (coord3 + size3 - shift3) % size3
        src = src + (from3 - coord3) * stride3
    if DIMS >= 5:
        coord4 = (dst // stride4) % size4
        from4 = (coord4 + size4 - shift4) % size4
        src = src + (from4 - coord4) * stride4

    data = tl.load(inp_ptr + src, mask=in_bounds, other=0)
    tl.store(out_ptr + dst, data, mask=in_bounds)