Coverage for src/flag_gems/runtime/backend/_sunrise/ops/svd.py: 0%

1842 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from collections import namedtuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import libentry, tensor_wrapper 

10 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Named (U, S, V) triple type.
SVDResult = namedtuple("SVDResult", ["U", "S", "V"])

# Upper bounds on the flattened batch count and on min(m, n) that
# _should_guard_gram_spectrum compares against.
_GRAM_CONDITION_GUARD_MAX_BATCH = 16
_GRAM_CONDITION_GUARD_MAX_K = 32
# NOTE(review): presumably an eigenvalue-ratio threshold for the Gram
# conditioning guard; its consumer is not in this part of the file - confirm.
_GRAM_CONDITION_EIGEN_RATIO = 1.0e-8
# Limits on min(m, n) and max(m, n) checked by
# _can_use_tall_wide_gram_jacobi_kernel; the K limit also gates the thin
# re-orthogonalisation launch in _gram_jacobi_svd.
_GRAM_TALL_WIDE_MAX_K = 32
_GRAM_TALL_WIDE_MAX_ROWS = 1024
# NOTE(review): not referenced in this part of the file; presumably the largest
# row block of the rank-1 kernel - confirm.
_RANK1_BLOCK_R_MAX = 1024
# Largest max(m, n) accepted by _can_use_rank2_kernel.
_RANK2_BLOCK_R_MAX = 2048
# NOTE(review): the three TSQR/Cholesky limits below are not referenced in this
# part of the file (_can_use_tsqr_cholesky_kernel returns False
# unconditionally) - confirm whether they are still needed.
_TSQR_CHOLESKY_MAX_BATCH = 32
_TSQR_CHOLESKY_MAX_K = 128
_TSQR_CHOLESKY_MAX_ROWS = 1024

25 

26 

def _unsupported_svd(input, some=True, compute_uv=True, reason=None):
    """Raise ``NotImplementedError`` for an input the native path rejects.

    The message reports the flattened batch count and matrix dimensions from
    ``_svd_shape``, the dtype and device of ``input`` and the ``some`` /
    ``compute_uv`` flags.  When ``reason`` is not ``None`` it is appended
    after a single space.

    Raises:
        NotImplementedError: always.
    """
    batch, m, n = _svd_shape(input)
    supported = (
        "FlagGems Sunrise native SVD currently supports float32 PTPU matrices with "
        "some=True, compute_uv=True, non-empty inputs, and native Triton "
        "rank/Jacobi shape coverage; "
    )
    received = (
        f"got batch={batch}, m={m}, n={n}, "
        f"dtype={input.dtype}, device={input.device}, some={some}, "
        f"compute_uv={compute_uv}."
    )
    if reason is None:
        trailer = ""
    else:
        trailer = f" {reason}"
    raise NotImplementedError(supported + received + trailer)

37 

38 

def _is_iluvatar_backend():
    """Return True when the runtime device reports the vendor name "iluvatar"."""
    vendor = device.vendor_name
    return vendor == "iluvatar"

41 

42 

def _svd_shape(input):
    """Return ``(batch, m, n)`` for a tensor whose last two dims are the matrix.

    ``batch`` is the product of every leading dimension (1 when there are
    none).  An input with fewer than two dimensions yields ``(0, 0, 0)``.
    """
    if input.dim() < 2:
        return 0, 0, 0
    *leading, m, n = input.shape
    batch = 1
    for extent in leading:
        batch *= extent
    return batch, m, n

52 

53 

def _should_guard_gram_spectrum(batch, k):
    """Return True when both ``batch`` and ``k`` are within the guard limits."""
    batch_within_limit = batch <= _GRAM_CONDITION_GUARD_MAX_BATCH
    k_within_limit = k <= _GRAM_CONDITION_GUARD_MAX_K
    return batch_within_limit and k_within_limit

56 

57 

def _is_float32_cuda_matrix(input):
    """Return True for a float32 tensor on a "ptpu" device with >= 2 dims.

    Despite the name, the device type tested is "ptpu".
    """
    if input.device.type != "ptpu":
        return False
    if input.dtype != torch.float32:
        return False
    return input.dim() >= 2

64 

65 

def _is_low_precision_cuda_matrix(input):
    """Return True for a float16/bfloat16 tensor on "ptpu" with >= 2 dims.

    Despite the name, the device type tested is "ptpu".
    """
    if input.device.type != "ptpu":
        return False
    if input.dtype not in (torch.float16, torch.bfloat16):
        return False
    return input.dim() >= 2

72 

73 

def _can_use_rank1_kernel(input, some=True, compute_uv=True):
    """Tell whether the rank-1 path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, and a
    short side of exactly 1.
    """
    _, m, n = _svd_shape(input)
    short_side_is_one = min(m, n) == 1
    return _is_float32_cuda_matrix(input) and some and compute_uv and short_side_is_one

77 

78 

def _can_use_rank2_kernel(input, some=True, compute_uv=True):
    """Tell whether the rank-2 path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    short side of exactly 2 and a long side of at most
    ``_RANK2_BLOCK_R_MAX``.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = short_side == 2 and long_side <= _RANK2_BLOCK_R_MAX
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

88 

89 

def _can_use_2x2_kernel(input):
    """Tell whether ``input`` passes the rank-2 check and is exactly 2 x 2."""
    _, m, n = _svd_shape(input)
    is_2x2 = m == 2 and n == 2
    return _can_use_rank2_kernel(input, True, True) and is_2x2

93 

94 

def _can_use_4x4_kernel(input, some=True, compute_uv=True):
    """Tell whether the 4 x 4 path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, and
    matrices that are exactly 4 x 4.
    """
    _, m, n = _svd_shape(input)
    is_4x4 = m == 4 and n == 4
    return _is_float32_cuda_matrix(input) and some and compute_uv and is_4x4

98 

99 

def _can_use_small_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the small Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    non-Iluvatar backend, a short side of at most 16 and a long side of at
    most 1024.
    """
    _, m, n = _svd_shape(input)
    shape_ok = min(m, n) <= 16 and max(m, n) <= 1024
    return (
        _is_float32_cuda_matrix(input)
        and some
        and compute_uv
        # The backend query stays inside the chain so it is only evaluated
        # when the earlier conditions already hold.
        and not _is_iluvatar_backend()
        and shape_ok
    )

110 

111 

def _can_use_cyclic_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the cyclic Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    short side in [16, 64] and a long side of at most 1024.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = 16 <= short_side <= 64 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

122 

123 

def _can_use_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the Gram/Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    short side in [16, 32] and a long side of at most 64.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = 16 <= short_side <= 32 and long_side <= 64
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

134 

135 

def _can_use_tall_wide_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the tall/wide Gram/Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, at
    least 128 matrices in the flattened batch, a short side in
    [16, ``_GRAM_TALL_WIDE_MAX_K``], and a long side that is at least twice
    the short side but no more than ``_GRAM_TALL_WIDE_MAX_ROWS``.
    """
    batch, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = (
        batch >= 128
        and 16 <= short_side <= _GRAM_TALL_WIDE_MAX_K
        and long_side <= _GRAM_TALL_WIDE_MAX_ROWS
        and long_side >= 2 * short_side
    )
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

149 

150 

def _can_use_tsqr_cholesky_kernel(input, some=True, compute_uv=True):
    """Report whether the TSQR/Cholesky path may be dispatched.

    Always returns False; the arguments are accepted but not inspected.
    """
    # Input-dependent TSQR safety needs a native device-side guard before dispatch.
    tsqr_dispatch_enabled = False
    return tsqr_dispatch_enabled

154 

155 

def _can_use_blocked_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the blocked Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    short side in (64, 512] and a long side of at most 1024.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = 64 < short_side <= 512 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

166 

167 

def _can_use_blocked_square_project_kernel(input, some=True, compute_uv=True):
    """Tell whether the blocked square-projection path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    single matrix (flattened batch of 1) that is square with side in
    [128, 512].
    """
    batch, m, n = _svd_shape(input)
    side = min(m, n)
    shape_ok = batch == 1 and m == n and 128 <= side <= 512
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

179 

180 

def _can_use_hier_block_square_project_kernel(input, some=True, compute_uv=True):
    """Tell whether the hierarchical block square-projection path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    flattened batch of at most 2, and square matrices of side 256 or 512.
    """
    batch, m, n = _svd_shape(input)
    side = min(m, n)
    shape_ok = batch <= 2 and m == n and side in (256, 512)
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

192 

193 

def _can_use_projected_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the projected Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    flattened batch in [4, 32], a short side of exactly 64 and a long side of
    at most 128.
    """
    batch, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = 4 <= batch <= 32 and short_side == 64 and long_side <= 128
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

205 

206 

def _can_use_singular_values_only(input, some=True, compute_uv=False):
    """Tell whether the singular-values-only path applies.

    Requires a float32 PTPU matrix, a falsy ``compute_uv``, a short side of at
    most 512 and a long side of at most 1024.  ``some`` is accepted but not
    consulted.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = short_side <= 512 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and not compute_uv and shape_ok

216 

217 

# One-sided Jacobi SVD; each program instance (program_id(0)) handles one matrix.
#
# A       input, addressed as a contiguous row-major M x N matrix per program.
# A_WORK  scratch, K vectors of length ROWS per program (vector j at j * ROWS).
# V_WORK  scratch, K x K per program; starts as the identity and accumulates
#         every rotation applied to the work vectors.
# U, S, V outputs; U is addressed as (M, K), V as (N, K), S as (K,) per program.
# TALL    when true the work vectors are columns of A, otherwise rows of A.
# SWEEPS  number of full passes over all (p, q) pairs; all loops are unrolled
#         at compile time through tl.static_range.
#
# Row indices are a single tl.arange(0, BLOCK_R) with no tiling loop, so
# elements at index >= BLOCK_R are never visited.
# NOTE(review): this presumably requires BLOCK_R >= ROWS, BLOCK_K >= K,
# ROWS == max(M, N) and K == min(M, N) - confirm against the launcher.
@libentry()
@triton.jit
def _small_jacobi_svd_kernel(
    A,
    A_WORK,
    V_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS
    vw_base = V_WORK + pid * K * K

    # Copy the K work vectors into A_WORK (as float32) and set V_WORK to identity.
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)
        ident_col = tl.where(cols == j, 1.0, 0.0)
        tl.store(vw_base + j * K + cols, ident_col, mask=col_mask)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip the rotation when the pair is already orthogonal
                # relative to the product of the two vector norms.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                # safe_gamma avoids dividing by a (near-)zero inner product
                # on the inactive path; the result is discarded below.
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                # Inactive pairs get the identity rotation (c = 1, s = 0).
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

                # Apply the same rotation to the accumulated basis vectors.
                vp = tl.load(vw_base + p * K + cols, mask=col_mask, other=0.0)
                vq = tl.load(vw_base + q * K + cols, mask=col_mask, other=0.0)
                new_vp = c * vp - s_rot * vq
                new_vq = s_rot * vp + c * vq
                tl.store(vw_base + p * K + cols, new_vp, mask=col_mask)
                tl.store(vw_base + q * K + cols, new_vq, mask=col_mask)

    # Singular values are the Euclidean norms of the rotated work vectors.
    s_idx = tl.arange(0, BLOCK_K)
    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # ranks[j] counts the entries that precede entry j in descending order:
    # strictly larger values, or equal values with a lower index.
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    # Scatter each vector to output column ranks[j]; vectors whose norm is
    # <= eps are written as zeros because inv_norm is forced to 0.
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        inv_norm = tl.where(norm > eps, 1.0 / norm, 0.0)
        tl.store(S + pid * K + rank, norm)

        basis = tl.load(vw_base + j * K + cols, mask=col_mask, other=0.0)
        if TALL:
            tl.store(U + pid * M * K + rows * K + rank, col * inv_norm, mask=row_mask)
            tl.store(V + pid * N * K + cols * K + rank, basis, mask=col_mask)
        else:
            tl.store(U + pid * M * K + cols * K + rank, basis, mask=col_mask)
            tl.store(V + pid * N * K + rows * K + rank, col * inv_norm, mask=row_mask)

320 

321 

# Singular-values-only one-sided Jacobi; each program instance
# (program_id(0)) handles one matrix.
#
# A       input, addressed as a contiguous row-major M x N matrix per program.
# A_WORK  scratch, K vectors of length ROWS per program (vector j at j * ROWS).
# S       output, K values per program written in descending order.
# TALL    when true the work vectors are columns of A, otherwise rows of A.
#
# No rotation basis is accumulated here; only the work vectors are rotated.
# NOTE(review): presumably requires BLOCK_R >= ROWS and BLOCK_K >= K because
# there is no tiling loop over rows - confirm against the launcher.
@libentry()
@triton.jit
def _small_jacobi_svals_kernel(
    A,
    A_WORK,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    s_idx = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS

    # Copy the K work vectors into A_WORK as float32.
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip the rotation when the pair is already orthogonal
                # relative to the product of the two vector norms.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                # safe_gamma keeps the division finite on the inactive path.
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                # Inactive pairs get the identity rotation (c = 1, s = 0).
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

    # Singular values are the Euclidean norms of the rotated work vectors.
    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # ranks[j] counts the entries that precede entry j in descending order:
    # strictly larger values, or equal values with a lower index.
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    # Scatter each norm to its sorted position.
    for j in tl.static_range(0, K):
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        tl.store(S + pid * K + rank, norm)

399 

400 

def _can_use_streaming_jacobi_kernel(input, some=True, compute_uv=True):
    """Tell whether the streaming Jacobi path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, a
    short side in (16, 64] and a long side of at most 1024.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    shape_ok = 16 < short_side <= 64 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok

410 

411 

def _can_use_gram_kernel(input, some=True, compute_uv=True):
    """Tell whether the generic Gram path applies.

    Requires a float32 PTPU matrix, truthy ``some`` and ``compute_uv``, and a
    short side of at most 1024.
    """
    _, m, n = _svd_shape(input)
    short_side_ok = min(m, n) <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and short_side_ok

415 

416 

# Batched matrix multiply: C[b] = A[b] @ B[b], accumulated in float32.
#
# A and B are addressed through explicit (batch, row, col) strides; C is
# addressed as a contiguous (batch, M, N) buffer.
# Grid: program_id(0) is a flattened (tile_m, tile_n) index with tile_n
# varying fastest; program_id(1) is the batch index.
@libentry()
@triton.jit
def _triton_bmm_kernel(
    A,
    B,
    C,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    tile = tl.program_id(0)
    batch = tl.program_id(1)
    # Recover the 2-D tile coordinates from the flattened tile index.
    tiles_n = tl.cdiv(N, BLOCK_N)
    tile_m = tile // tiles_n
    tile_n = tile - tile_m * tiles_n

    offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    a_base = A + batch * stride_ab
    b_base = B + batch * stride_bb
    # Walk the reduction dimension in BLOCK_K chunks; out-of-range elements
    # are loaded as 0.0 so they do not contribute to the product.
    for k_start in range(0, K, BLOCK_K):
        k = k_start + offs_k
        a = tl.load(
            a_base + offs_m[:, None] * stride_am + k[None, :] * stride_ak,
            mask=(offs_m[:, None] < M) & (k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            b_base + k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=(k[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        # TF32 is explicitly disabled for the dot product.
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        C + batch * M * N + offs_m[:, None] * N + offs_n[None, :],
        acc,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

468 

469 

def _triton_bmm(left, right, out_shape):
    """Multiply two batched matrices with ``_triton_bmm_kernel``.

    Args:
        left: 3-D tensor of shape ``(batch, m, k)``; its strides are forwarded
            to the kernel, so it does not have to be contiguous.
        right: 3-D tensor of shape ``(batch, k, n)``; strides are forwarded
            as well.
        out_shape: shape the ``(batch, m, n)`` result is reshaped to before
            it is returned.

    Returns:
        A new tensor with ``left``'s dtype and device, reshaped to
        ``out_shape``.

    Raises:
        ValueError: if the batch sizes or the inner dimensions disagree.
            These checks used to be ``assert`` statements, which are removed
            under ``python -O``; the kernel masks the reduction dimension
            only against ``left``'s ``k``, so a mismatch must never reach it.
    """
    batch, m, k = left.shape
    right_batch, right_k, n = right.shape
    if batch != right_batch:
        raise ValueError(
            f"Batch dim mismatch: left has {batch}, right has {right_batch}"
        )
    if k != right_k:
        raise ValueError(f"K dim mismatch: left has {k}, right has {right_k}")
    out = torch.empty((batch, m, n), dtype=left.dtype, device=left.device)
    # Small outputs use 16-wide tiles, everything else 32-wide tiles.
    block_m = 16 if m <= 16 else 32
    block_n = 16 if n <= 16 else 32
    block_k = 32
    # Axis 0 enumerates the (tile_m, tile_n) pairs, axis 1 the batch entries.
    grid = (triton.cdiv(m, block_m) * triton.cdiv(n, block_n), batch)
    with torch_device_fn.device(left.device):
        _triton_bmm_kernel[grid](
            left,
            right,
            out,
            left.stride(0),
            left.stride(1),
            left.stride(2),
            right.stride(0),
            right.stride(1),
            right.stride(2),
            M=m,
            N=n,
            K=k,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            num_warps=1 if block_m == 16 and block_n == 16 else 4,
        )
    return out.reshape(out_shape)

500 

501 

# Builds one BLOCK_I x BLOCK_J tile of the K x K Gram matrix of each input.
#
# A is addressed as a contiguous row-major M x N matrix per batch entry.
# With TALL the tile accumulates sum_r A[r, i] * A[r, j] (A^T A); otherwise it
# accumulates sum_r A[i, r] * A[j, r] (A A^T).  ROWS is the length of the
# reduction, walked in BLOCK_R chunks.
# Grid: (tile_i, tile_j, batch).  GRAM is addressed as contiguous (batch, K, K).
@libentry()
@triton.jit
def _gram_build_tiled_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_I: tl.constexpr,
    BLOCK_J: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    tile_i = tl.program_id(0)
    tile_j = tl.program_id(1)
    batch = tl.program_id(2)
    offs_i = tile_i * BLOCK_I + tl.arange(0, BLOCK_I)
    offs_j = tile_j * BLOCK_J + tl.arange(0, BLOCK_J)
    rows = tl.arange(0, BLOCK_R)
    i_mask = offs_i < K
    j_mask = offs_j < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS
        if TALL:
            # lhs[i, r] = A[r, i] and rhs[r, j] = A[r, j].
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + offs_i[:, None],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + offs_j[None, :],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r] and rhs[r, j] = A[j, r].
            lhs = tl.load(
                a_base + offs_i[:, None] * N + chunk_rows[None, :],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + offs_j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + offs_i[:, None] * K + offs_j[None, :],
        acc,
        mask=i_mask[:, None] & j_mask[None, :],
    )

559 

560 

# Row-by-row Cholesky factorisation of a K x K matrix into an upper-triangular
# factor R (so that R^T R approximates GRAM); one program per batch entry.
#
# GRAM and R are addressed as contiguous (batch, K, K); STATUS holds one int32
# per batch entry: 0 when every pivot was finite and above the tolerance and
# every produced row was finite, 1 otherwise.
# The tolerance is max(max|diag(GRAM)| * 1e-8, 1e-20).  Pivots at or below it
# are clamped to the tolerance, so R is always written even when STATUS is 1.
#
# NOTE(review): rows of R stored in an earlier iteration of the `j` loop are
# read back from memory in later iterations (including a scalar load of
# R[p, j]); confirm the target backend makes those stores visible to the loads
# without an explicit barrier.
@libentry()
@triton.jit
def _cholesky_upper_kernel(
    GRAM,
    R,
    STATUS,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    cols = tl.arange(0, BLOCK_K)
    col_mask = cols < K
    base_g = GRAM + batch * K * K
    base_r = R + batch * K * K

    diag_vals = tl.load(
        base_g + cols * K + cols,
        mask=col_mask,
        other=0.0,
    ).to(tl.float32)
    max_diag = tl.max(tl.abs(diag_vals), axis=0)
    tol = tl.maximum(max_diag * 1.0e-8, 1.0e-20)
    status = tl.full((), 0, dtype=tl.int32)
    # Largest finite float32 value; magnitudes at or above it count as overflow.
    finite_limit = 3.4028234663852886e38

    j = 0
    while j < K:
        # Only the entries on and to the right of the diagonal are computed.
        row_mask = col_mask & (cols >= j)
        gram_row = tl.load(
            base_g + j * K + cols,
            mask=row_mask,
            other=0.0,
        ).to(tl.float32)
        diag = tl.load(base_g + j * K + j).to(tl.float32)

        # Subtract the contribution of every previously computed row p < j.
        p = 0
        while p < j:
            r_pj = tl.load(base_r + p * K + j).to(tl.float32)
            r_pcols = tl.load(
                base_r + p * K + cols,
                mask=row_mask,
                other=0.0,
            ).to(tl.float32)
            gram_row -= r_pj * r_pcols
            diag -= r_pj * r_pj
            p += 1

        # `diag == diag` is false only for NaN.
        good_diag = (diag == diag) & (tl.abs(diag) < finite_limit) & (diag > tol)
        pivot = tl.sqrt(tl.maximum(diag, tol))
        r_vals = gram_row / pivot
        r_vals = tl.where(cols == j, pivot, r_vals)
        # Entries left of the diagonal are written as zeros.
        r_vals = tl.where(row_mask, r_vals, 0.0)
        bad_vals = tl.sum(
            (((r_vals != r_vals) | (tl.abs(r_vals) >= finite_limit)) & row_mask).to(
                tl.int32
            ),
            axis=0,
        )
        # The status is sticky: once set to 1 it is never reset to 0.
        status = tl.where(good_diag & (bad_vals == 0), status, 1)
        tl.store(base_r + j * K + cols, r_vals, mask=col_mask)
        j += 1

    tl.store(STATUS + batch, status)

624 

625 

def _tsqr_guard_fallback_svd(input):
    """Route ``input`` to a Jacobi SVD implementation chosen by its shape.

    With a long side of at most 1024, a short side in [16, 64] goes to
    ``_cyclic_jacobi_svd`` and a short side in (64, 512] goes to
    ``_blocked_jacobi_svd``.  Every other shape ends in ``_unsupported_svd``,
    which raises ``NotImplementedError``.
    """
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    if long_side <= 1024:
        if 16 <= short_side <= 64:
            return _cyclic_jacobi_svd(input)
        if 64 < short_side <= 512:
            return _blocked_jacobi_svd(input)
    return _unsupported_svd(
        input,
        True,
        True,
        "TSQR/Cholesky guard could not find a native Jacobi fallback.",
    )

639 

640 

def _tsqr_cholesky_svd(input):
    """Thin SVD through a Gram matrix, its Cholesky factor and a small SVD.

    Steps, in launch order:
      1. build the k x k Gram matrix of each input (k = min(m, n));
      2. factor it into an upper-triangular ``r`` with ``_cholesky_upper_kernel``;
      3. call ``svd`` on ``r`` and keep its singular values and third output;
      4. multiply the input (or its transpose when m < n) by that basis and
         divide each resulting column by the matching singular value.

    Returns:
        A plain ``(u, s, v)`` tuple reshaped to the leading dims of ``input``
        followed by ``(m, k)``, ``(k,)`` and ``(n, k)``.

    NOTE(review): ``status`` is filled by the Cholesky kernel but never read
    here, so a failed factorisation is not detected in this function.
    ``_tsqr_guard_fallback_svd`` looks like the intended consumer - confirm.
    NOTE(review): ``svd`` is not defined in the visible part of this file;
    presumably it is this module's public entry point - confirm.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    # Flatten all leading dims into one batch dim of contiguous matrices.
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    r = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    status = torch.empty((batch,), dtype=torch.int32, device=input.device)
    block_k = triton.next_power_of_2(k)
    block_tile = 32
    block_r = 64

    with torch_device_fn.device(input.device):
        # One program per (tile_i, tile_j, batch) Gram tile.
        _gram_build_tiled_kernel[
            (
                triton.cdiv(k, block_tile),
                triton.cdiv(k, block_tile),
                batch,
            )
        ](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_I=block_tile,
            BLOCK_J=block_tile,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _cholesky_upper_kernel[(batch,)](
            gram,
            r,
            status,
            K=k,
            BLOCK_K=block_k,
            num_warps=4,
        )

    _, s, basis = svd(r, some=True, compute_uv=True)
    basis = basis.reshape(batch, k, k)
    s = s.reshape(batch, k)

    # The side matching the short dimension is the basis itself; the other
    # side is obtained by projecting the input onto that basis.
    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        projected = u
        projected_rows = m
    else:
        u = basis
        v = _triton_bmm(a.transpose(1, 2).contiguous(), basis, (batch, n, k))
        projected = v
        projected_rows = n

    with torch_device_fn.device(input.device):
        # In-place division of every projected column by its singular value.
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1 if projected_rows <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

713 

714 

# Builds the whole K x K Gram matrix of one input per program (one program per
# batch entry) as a single BLOCK_K x BLOCK_K accumulator.
#
# A is addressed as a contiguous row-major M x N matrix per batch entry.
# With TALL the result is sum_r A[r, i] * A[r, j] (A^T A); otherwise it is
# sum_r A[i, r] * A[j, r] (A A^T).  ROWS is the reduction length, walked in
# BLOCK_R chunks.
# NOTE(review): indices >= BLOCK_K are never visited, so BLOCK_K >= K is
# presumably required - confirm against the launcher.
@libentry()
@triton.jit
def _gram_build_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    i = tl.arange(0, BLOCK_K)
    j = tl.arange(0, BLOCK_K)
    rows = tl.arange(0, BLOCK_R)
    k_mask = i < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_K, BLOCK_K), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS

        if TALL:
            # lhs[i, r] = A[r, i] and rhs[r, j] = A[r, j].
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + i[:, None],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + j[None, :],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r] and rhs[r, j] = A[j, r].
            lhs = tl.load(
                a_base + i[:, None] * N + chunk_rows[None, :],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)

        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + i[:, None] * K + j[None, :],
        acc,
        mask=k_mask[:, None] & (j[None, :] < K),
    )

770 

771 

# Two-sided Jacobi eigen-solver for a K x K matrix; one program per batch entry.
#
# GRAM is loaded once as a BLOCK_K x BLOCK_K block and rotated in place in
# that block value; every rotation is also applied to `v`, which starts as
# the identity.  After SWEEPS passes over all (p, q) pairs with p < q, the
# diagonal is stored to EVALS and `v` to EVECS (row-major K x K).
#
# K and SWEEPS are ordinary (non-constexpr) arguments, so the loops are
# `while` loops rather than unrolled tl.static_range loops.
# Scalars such as g[p, p] are extracted with masked sums because p and q are
# loop variables rather than compile-time constants.
# NOTE(review): the update assumes GRAM is symmetric (only g[p, q] is read for
# the off-diagonal term) - confirm the producer guarantees this.
@libentry()
@triton.jit
def _gram_jacobi_sym_kernel(
    GRAM,
    EVECS,
    EVALS,
    K,
    SWEEPS,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    r = tl.arange(0, BLOCK_K)
    cidx = tl.arange(0, BLOCK_K)
    rr = r[:, None]
    cc = cidx[None, :]
    mask = (rr < K) & (cc < K)
    base = GRAM + batch * K * K

    g = tl.load(base + rr * K + cc, mask=mask, other=0.0).to(tl.float32)
    v = tl.where((rr == cc) & mask, 1.0, 0.0)
    eps = 1.0e-20

    sweep = 0
    while sweep < SWEEPS:
        p = 0
        while p < K - 1:
            q = p + 1
            while q < K:
                diag_p = tl.sum(tl.where((rr == p) & (cc == p), g, 0.0))
                diag_q = tl.sum(tl.where((rr == q) & (cc == q), g, 0.0))
                off = tl.sum(tl.where((rr == p) & (cc == q), g, 0.0))
                abs_off = tl.abs(off)
                # Skip the rotation when the off-diagonal entry is negligible
                # relative to the two diagonal entries.
                threshold = 1.0e-7 * tl.sqrt(tl.abs(diag_p * diag_q) + eps)
                active = abs_off > threshold
                # safe_off keeps the division finite on the inactive path.
                safe_off = tl.where(active, off, 1.0)
                tau = (diag_q - diag_p) / (2.0 * safe_off)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                crot = tl.rsqrt(1.0 + t * t)
                srot = t * crot
                # Inactive pairs get the identity rotation.
                crot = tl.where(active, crot, 1.0)
                srot = tl.where(active, srot, 0.0)

                # Rows and columns p, q are all taken from g before any of
                # them is overwritten.
                col_p = tl.sum(tl.where(cc == p, g, 0.0), axis=1)
                col_q = tl.sum(tl.where(cc == q, g, 0.0), axis=1)
                row_p = tl.sum(tl.where(rr == p, g, 0.0), axis=0)
                row_q = tl.sum(tl.where(rr == q, g, 0.0), axis=0)

                new_col_p = crot * col_p - srot * col_q
                new_col_q = srot * col_p + crot * col_q
                new_row_p = crot * row_p - srot * row_q
                new_row_q = srot * row_p + crot * row_q
                g = tl.where(cc == p, new_col_p[:, None], g)
                g = tl.where(cc == q, new_col_q[:, None], g)
                g = tl.where(rr == p, new_row_p[None, :], g)
                g = tl.where(rr == q, new_row_q[None, :], g)

                # The 2 x 2 block at (p, q) is touched by both the row and the
                # column update, so it is overwritten with closed-form values.
                new_pp = (
                    crot * crot * diag_p
                    - 2.0 * crot * srot * off
                    + srot * srot * diag_q
                )
                new_qq = (
                    srot * srot * diag_p
                    + 2.0 * crot * srot * off
                    + crot * crot * diag_q
                )
                g = tl.where((rr == p) & (cc == p), new_pp, g)
                g = tl.where((rr == q) & (cc == q), new_qq, g)
                g = tl.where(((rr == p) & (cc == q)) | ((rr == q) & (cc == p)), 0.0, g)

                # Accumulate the rotation into columns p and q of v.
                vec_p = tl.sum(tl.where(cc == p, v, 0.0), axis=1)
                vec_q = tl.sum(tl.where(cc == q, v, 0.0), axis=1)
                new_vec_p = crot * vec_p - srot * vec_q
                new_vec_q = srot * vec_p + crot * vec_q
                v = tl.where(cc == p, new_vec_p[:, None], v)
                v = tl.where(cc == q, new_vec_q[:, None], v)
                q += 1
            p += 1
        sweep += 1

    diag = tl.sum(tl.where(rr == cc, g, 0.0), axis=1)
    tl.store(EVALS + batch * K + r, diag, mask=r < K)
    tl.store(EVECS + batch * K * K + rr * K + cc, v, mask=mask)

856 

857 

# Sorts eigenpairs in descending eigenvalue order; one program per
# (batch, column) pair.
#
# Each program clamps its eigenvalue to be non-negative, counts how many
# eigenvalues must come before it (strictly larger, or equal with a lower
# index), then writes sqrt(eigenvalue) to S[rank] and copies column `col` of
# EVECS into column `rank` of BASIS.  EVECS and BASIS are addressed as
# row-major K x K per batch entry.
@libentry()
@triton.jit
def _gram_sort_basis_kernel(
    EVALS,
    EVECS,
    BASIS,
    S,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_K)
    row_mask = rows < K
    eval_col = tl.maximum(tl.load(EVALS + batch * K + col), 0.0)
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        eval_other = tl.maximum(tl.load(EVALS + batch * K + other), 0.0)
        # Ties are broken by original index so that every rank is unique.
        rank += (
            (eval_other > eval_col) | ((eval_other == eval_col) & (other < col))
        ).to(tl.int32)

    vec = tl.load(
        EVECS + batch * K * K + rows * K + col,
        mask=row_mask,
        other=0.0,
    )
    tl.store(S + batch * K + rank, tl.sqrt(eval_col))
    tl.store(
        BASIS + batch * K * K + rows * K + rank,
        vec,
        mask=row_mask,
    )

891 

892 

# Divides one column of Q by its singular value, in place; one program per
# (batch, column) pair.
#
# Q is addressed as row-major (ROWS, K) per batch entry and S as (K,) per batch
# entry.  The divisor is clamped from below by 1e-20 to avoid division by zero.
# NOTE(review): rows at index >= BLOCK_R are never visited, so BLOCK_R >= ROWS
# is presumably required - confirm against the launcher.
@libentry()
@triton.jit
def _normalize_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-20
    sval = tl.load(S + batch * K + col)
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals = vals / tl.maximum(sval, eps)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

911 

912 

# Rescales one column of Q to unit Euclidean norm in place and writes the
# column's original norm to S; one program per (batch, column) pair.
#
# Q is addressed as row-major (ROWS, K) per batch entry and S as (K,) per batch
# entry.  A column whose norm is <= 1e-20 is replaced by the unit vector that
# has a 1 at row index `col`.
@libentry()
@triton.jit
def _renorm_projection_update_s_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals_f32 = vals.to(tl.float32)
    norm = tl.sqrt(tl.sum(vals_f32 * vals_f32, axis=0))
    # The squared norm is clamped before rsqrt; the result for a tiny norm is
    # discarded by the tl.where below.
    inv_norm = tl.rsqrt(tl.maximum(norm * norm, 1.0e-40))
    basis = tl.where(rows == col, 1.0, 0.0)
    vals = tl.where(norm <= 1.0e-20, basis, vals * inv_norm)
    tl.store(S + batch * K + col, norm)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

934 

935 

# Replaces one column of Q by a unit vector when its singular value is
# (near) zero; one program per (batch, column) pair.
#
# Q is addressed as row-major (ROWS, K) per batch entry and S as (K,) per batch
# entry.  When S[col] <= 1e-12 the column becomes the vector with a 1 at row
# index `col`; otherwise the column is written back unchanged.
@libentry()
@triton.jit
def _complete_zero_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-12
    sval = tl.load(S + batch * K + col)
    basis = tl.where(rows == col, 1.0, 0.0)
    old = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals = tl.where(sval <= eps, basis, old)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

955 

956 

def _gram_jacobi_svd(input):
    """Thin SVD through a Jacobi eigen-decomposition of the k x k Gram matrix.

    Launch order:
      1. ``_gram_build_kernel`` builds the Gram matrix (k = min(m, n));
      2. ``_gram_jacobi_sym_kernel`` diagonalises it;
      3. ``_gram_sort_basis_kernel`` orders the eigenpairs and writes the
         square roots of the clamped eigenvalues to ``s``;
      4. the input (or its transpose when m < n) is multiplied by the sorted
         basis, then renormalised column by column, which overwrites ``s``
         with the column norms;
      5. columns with a (near) zero singular value are replaced by unit
         vectors, and for ``k <= _GRAM_TALL_WIDE_MAX_K`` the projected side is
         passed through ``_thin_reorthogonalize_kernel``.

    Returns:
        A plain ``(u, s, v)`` tuple reshaped to the leading dims of ``input``
        followed by ``(m, k)``, ``(k,)`` and ``(n, k)``.
    """
    batch, m, n = _svd_shape(input)
    is_tall = m >= n
    k, long_side = min(m, n), max(m, n)
    dev = input.device
    flat = input.contiguous().reshape(batch, m, n)

    def _fp32_buffer(*shape):
        return torch.empty(shape, dtype=torch.float32, device=dev)

    gram = _fp32_buffer(batch, k, k)
    eigvecs = _fp32_buffer(batch, k, k)
    evals = _fp32_buffer(batch, k)
    basis = _fp32_buffer(batch, k, k)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_k = triton.next_power_of_2(k)
    row_chunk_cap = 64 if k > 32 else 128
    block_r = min(triton.next_power_of_2(long_side), row_chunk_cap)
    sweeps = 12 if k <= 17 else 10

    with torch_device_fn.device(dev):
        _gram_build_kernel[(batch,)](
            flat,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=long_side,
            TALL=is_tall,
            BLOCK_K=block_k,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _gram_jacobi_sym_kernel[(batch,)](
            gram,
            eigvecs,
            evals,
            k,
            sweeps,
            BLOCK_K=block_k,
            num_warps=4,
        )
    with torch_device_fn.device(dev):
        _gram_sort_basis_kernel[(batch, k)](
            evals,
            eigvecs,
            basis,
            s,
            K=k,
            BLOCK_K=block_k,
            num_warps=1,
        )

    # The short side's factor is the sorted basis itself; the long side's
    # factor is the input projected onto that basis.
    if is_tall:
        projected = _triton_bmm(flat, basis, (batch, m, k))
        u, v = projected, basis
        proj_rows = m
    else:
        flat_t = flat.transpose(1, 2).contiguous()
        projected = _triton_bmm(flat_t, basis, (batch, n, k))
        u, v = basis, projected
        proj_rows = n

    # All three post-processing kernels share the same launch parameters.
    proj_launch = dict(
        ROWS=proj_rows,
        K=k,
        BLOCK_R=triton.next_power_of_2(proj_rows),
        num_warps=1 if proj_rows <= 64 else 4,
    )
    with torch_device_fn.device(dev):
        _renorm_projection_update_s_kernel[(batch, k)](projected, s, **proj_launch)
        _complete_zero_projection_kernel[(batch, k)](projected, s, **proj_launch)
        if k <= _GRAM_TALL_WIDE_MAX_K:
            _thin_reorthogonalize_kernel[(batch,)](projected, **proj_launch)

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

1046 

1047 

# Applies one Jacobi rotation per row to a pair of work vectors (ap, aq) and
# the matching pair of basis vectors (vp, vq).
#
# The reductions run along axis 1 and the rotation coefficients are broadcast
# back with [:, None], so every row gets its own rotation.  The caller visible
# in this file passes (BLOCK_B, 4) blocks for ap / aq.
# Returns the rotated (ap, aq, vp, vq); rows whose pair is already orthogonal
# within the threshold are returned unchanged (c = 1, s = 0).
@triton.jit
def _rotate_pair_4(ap, aq, vp, vq):
    eps = 1.0e-20
    alpha = tl.sum(ap * ap, axis=1)
    beta = tl.sum(aq * aq, axis=1)
    gamma = tl.sum(ap * aq, axis=1)
    abs_gamma = tl.abs(gamma)
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    # safe_gamma keeps the division finite on the inactive path.
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)
    new_ap = c[:, None] * ap - s_rot[:, None] * aq
    new_aq = s_rot[:, None] * ap + c[:, None] * aq
    new_vp = c[:, None] * vp - s_rot[:, None] * vq
    new_vq = s_rot[:, None] * vp + c[:, None] * vq
    return new_ap, new_aq, new_vp, new_vq

1070 

1071 

@libentry()
@triton.jit
def _small4_square_svd_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    BLOCK_B: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    """SVD of 4x4 matrices, BLOCK_B matrices per program.

    Matrices are addressed with 16 elements per matrix and 4 per row.  The
    four columns are orthogonalized with ``SWEEPS`` passes over all six
    column pairs, then the column norms are written to ``S`` in descending
    order together with the matching columns of ``U`` (normalized columns)
    and ``V`` (accumulated rotations).
    """
    pid = tl.program_id(0)
    mat = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    row = tl.arange(0, 4)
    mat_2d = mat[:, None]
    row_2d = row[None, :]
    mat_ok = mat < BATCH
    # `row_2d < 4` is always true; it only broadcasts the mask to 2-D.
    tile_ok = (mat_2d < BATCH) & (row_2d < 4)
    a_ptrs = A + mat_2d * 16 + row_2d * 4

    # c<j>[b, r] holds element (r, j) of matrix b.
    c0 = tl.load(a_ptrs, mask=tile_ok, other=0.0).to(tl.float32)
    c1 = tl.load(a_ptrs + 1, mask=tile_ok, other=0.0).to(tl.float32)
    c2 = tl.load(a_ptrs + 2, mask=tile_ok, other=0.0).to(tl.float32)
    c3 = tl.load(a_ptrs + 3, mask=tile_ok, other=0.0).to(tl.float32)

    # Columns of the identity; they accumulate the applied rotations.
    v0 = tl.where(row_2d == 0, 1.0, 0.0)
    v1 = tl.where(row_2d == 1, 1.0, 0.0)
    v2 = tl.where(row_2d == 2, 1.0, 0.0)
    v3 = tl.where(row_2d == 3, 1.0, 0.0)

    for _ in tl.static_range(0, SWEEPS):
        c0, c1, v0, v1 = _rotate_pair_4(c0, c1, v0, v1)
        c0, c2, v0, v2 = _rotate_pair_4(c0, c2, v0, v2)
        c0, c3, v0, v3 = _rotate_pair_4(c0, c3, v0, v3)
        c1, c2, v1, v2 = _rotate_pair_4(c1, c2, v1, v2)
        c1, c3, v1, v3 = _rotate_pair_4(c1, c3, v1, v3)
        c2, c3, v2, v3 = _rotate_pair_4(c2, c3, v2, v3)

    s0 = tl.sqrt(tl.sum(c0 * c0, axis=1))
    s1 = tl.sqrt(tl.sum(c1 * c1, axis=1))
    s2 = tl.sqrt(tl.sum(c2 * c2, axis=1))
    s3 = tl.sqrt(tl.sum(c3 * c3, axis=1))

    # Output slot of each column = number of columns ordered before it.
    # Ties go to the lower original index (`>=` against earlier columns,
    # `>` against later ones), so the four slots form a permutation.
    r0 = (
        (s1 > s0).to(tl.int32)
        + (s2 > s0).to(tl.int32)
        + (s3 > s0).to(tl.int32)
    )
    r1 = (
        (s0 >= s1).to(tl.int32)
        + (s2 > s1).to(tl.int32)
        + (s3 > s1).to(tl.int32)
    )
    r2 = (
        (s0 >= s2).to(tl.int32)
        + (s1 >= s2).to(tl.int32)
        + (s3 > s2).to(tl.int32)
    )
    r3 = (
        (s0 >= s3).to(tl.int32)
        + (s1 >= s3).to(tl.int32)
        + (s2 >= s3).to(tl.int32)
    )
    eps = 1.0e-20

    s_ptrs = S + mat * 4
    tl.store(s_ptrs + r0, s0, mask=mat_ok)
    tl.store(s_ptrs + r1, s1, mask=mat_ok)
    tl.store(s_ptrs + r2, s2, mask=mat_ok)
    tl.store(s_ptrs + r3, s3, mask=mat_ok)

    # Left vectors are the normalized columns; the norm is clamped so a
    # zero column is written as zeros instead of dividing by zero.
    u_ptrs = U + mat_2d * 16 + row_2d * 4
    tl.store(u_ptrs + r0[:, None], c0 / tl.maximum(s0[:, None], eps), mask=tile_ok)
    tl.store(u_ptrs + r1[:, None], c1 / tl.maximum(s1[:, None], eps), mask=tile_ok)
    tl.store(u_ptrs + r2[:, None], c2 / tl.maximum(s2[:, None], eps), mask=tile_ok)
    tl.store(u_ptrs + r3[:, None], c3 / tl.maximum(s3[:, None], eps), mask=tile_ok)

    v_ptrs = V + mat_2d * 16 + row_2d * 4
    tl.store(v_ptrs + r0[:, None], v0, mask=tile_ok)
    tl.store(v_ptrs + r1[:, None], v1, mask=tile_ok)
    tl.store(v_ptrs + r2[:, None], v2, mask=tile_ok)
    tl.store(v_ptrs + r3[:, None], v3, mask=tile_ok)

1154 

1155 

@libentry()
@triton.jit
def _rank2_svd_tiny_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Closed-form SVD of two-vector matrices, BLOCK_B matrices per program.

    With ``TALL`` the two vectors ``x``/``y`` are the first two columns of
    each ``M x N`` matrix (presumably ``N == 2``); otherwise they are its
    first two rows (presumably ``M == 2``).  The singular values come from
    the eigenvalues of the 2x2 Gram matrix ``[[x.x, x.y], [x.y, y.y]]``.
    Its eigenvectors are written to the 2x2 factor, and the projections of
    ``x``/``y`` onto them, divided by the singular values, form the long
    factor.

    NOTE(review): unlike ``_rank2_svd_kernel`` this kernel has no fallback
    vector for (near-)zero singular values; the matching column of the long
    factor is written as zeros.
    """
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, BLOCK_R)
    bb = b[:, None]
    rr = r[None, :]
    bmask = b < BATCH
    eps = 1.0e-20

    if TALL:
        mask = (bb < BATCH) & (rr < M)
        base = A + bb * M * N + rr * N
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = (bb < BATCH) & (rr < N)
        base = A + bb * M * N + rr
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N, mask=mask, other=0.0).to(tl.float32)

    # Entries of the 2x2 Gram matrix.
    aa = tl.sum(x * x, axis=1)
    bbv = tl.sum(y * y, axis=1)
    ab = tl.sum(x * y, axis=1)
    diff = aa - bbv
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    # Larger eigenvalue directly; smaller one via det / l0.
    l0 = tl.maximum(0.0, 0.5 * (aa + bbv + root))
    det = tl.maximum(0.0, aa * bbv - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0 is (ab, l0 - aa); when the off-diagonal term is
    # negligible fall back to the axis of the larger diagonal entry.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bbv
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    # Rescale by the larger component before normalizing.  An absolute
    # epsilon under the square root would dominate whenever the Gram
    # entries are small and leave the vector far shorter than unit length.
    # `scale` is > 0: it is 1 for the axis fallback and >= |ab| > eps
    # otherwise.
    scale = tl.maximum(tl.abs(vx0), tl.abs(vy0))
    vx0 = vx0 / scale
    vy0 = vy0 / scale
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    # Second eigenvector: the first one rotated by 90 degrees.
    vx1 = -vy0
    vy1 = vx0

    tl.store(S + b * 2, s0, mask=bmask)
    tl.store(S + b * 2 + 1, s1, mask=bmask)
    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        u0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
        u1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]
        ubase = U + bb * M * 2 + rr * 2
        tl.store(ubase, u0, mask=mask)
        tl.store(ubase + 1, u1, mask=mask)
        # 2x2 factor, row-major: [[vx0, vx1], [vy0, vy1]].
        vbase = V + b * 4
        tl.store(vbase, vx0, mask=bmask)
        tl.store(vbase + 1, vx1, mask=bmask)
        tl.store(vbase + 2, vy0, mask=bmask)
        tl.store(vbase + 3, vy1, mask=bmask)
    else:
        ubase = U + b * 4
        tl.store(ubase, vx0, mask=bmask)
        tl.store(ubase + 1, vx1, mask=bmask)
        tl.store(ubase + 2, vy0, mask=bmask)
        tl.store(ubase + 3, vy1, mask=bmask)
        v0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
        v1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]
        vbase = V + bb * N * 2 + rr * 2
        tl.store(vbase, v0, mask=mask)
        tl.store(vbase + 1, v1, mask=mask)

1237 

1238 

@libentry()
@triton.jit
def _rank2_svals_tiny_kernel(
    A,
    S,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Singular values of two-vector matrices, BLOCK_B matrices per program.

    With ``TALL`` the two vectors are the first two columns of each
    ``M x N`` matrix, otherwise its first two rows.  The two singular
    values are the square roots of the eigenvalues of the 2x2 Gram matrix
    and are stored largest first.
    """
    pid = tl.program_id(0)
    mat = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    lane = tl.arange(0, BLOCK_R)
    mat_2d = mat[:, None]
    lane_2d = lane[None, :]
    mat_ok = mat < BATCH

    if TALL:
        tile_ok = (mat_2d < BATCH) & (lane_2d < M)
        first = A + mat_2d * M * N + lane_2d * N
        x = tl.load(first, mask=tile_ok, other=0.0).to(tl.float32)
        y = tl.load(first + 1, mask=tile_ok, other=0.0).to(tl.float32)
    else:
        tile_ok = (mat_2d < BATCH) & (lane_2d < N)
        first = A + mat_2d * M * N + lane_2d
        x = tl.load(first, mask=tile_ok, other=0.0).to(tl.float32)
        y = tl.load(first + N, mask=tile_ok, other=0.0).to(tl.float32)

    # Entries of the 2x2 Gram matrix.
    g_xx = tl.sum(x * x, axis=1)
    g_yy = tl.sum(y * y, axis=1)
    g_xy = tl.sum(x * y, axis=1)
    gap = g_xx - g_yy
    disc = tl.sqrt(gap * gap + 4.0 * g_xy * g_xy)
    # Larger eigenvalue directly; smaller one via det / larger.
    lam_hi = tl.maximum(0.0, 0.5 * (g_xx + g_yy + disc))
    gram_det = tl.maximum(0.0, g_xx * g_yy - g_xy * g_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    tl.store(S + mat * 2, tl.sqrt(lam_hi), mask=mat_ok)
    tl.store(S + mat * 2 + 1, tl.sqrt(lam_lo), mask=mat_ok)

1279 

1280 

@libentry()
@triton.jit
def _rank2_svals_kernel(
    A,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Singular values of one two-vector matrix per program.

    With ``TALL`` the two vectors are the first two columns of the
    ``M x N`` matrix, otherwise its first two rows.  The two singular
    values are the square roots of the eigenvalues of the 2x2 Gram matrix
    and are stored largest first.
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK_R)
    mat_ptr = A + pid * M * N

    if TALL:
        lane_ok = lane < M
        x = tl.load(mat_ptr + lane * N, mask=lane_ok, other=0.0).to(tl.float32)
        y = tl.load(mat_ptr + lane * N + 1, mask=lane_ok, other=0.0).to(tl.float32)
    else:
        lane_ok = lane < N
        x = tl.load(mat_ptr + lane, mask=lane_ok, other=0.0).to(tl.float32)
        y = tl.load(mat_ptr + N + lane, mask=lane_ok, other=0.0).to(tl.float32)

    # Entries of the 2x2 Gram matrix.
    g_xx = tl.sum(x * x)
    g_yy = tl.sum(y * y)
    g_xy = tl.sum(x * y)
    gap = g_xx - g_yy
    disc = tl.sqrt(gap * gap + 4.0 * g_xy * g_xy)
    # Larger eigenvalue directly; smaller one via det / larger.
    lam_hi = tl.maximum(0.0, 0.5 * (g_xx + g_yy + disc))
    gram_det = tl.maximum(0.0, g_xx * g_yy - g_xy * g_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    out = S + pid * 2
    tl.store(out, tl.sqrt(lam_hi))
    tl.store(out + 1, tl.sqrt(lam_lo))

1317 

1318 

@libentry()
@triton.jit
def _rank2_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Closed-form SVD of one two-vector matrix per program.

    With ``TALL`` the two vectors ``x``/``y`` are the first two columns of
    the ``M x N`` matrix (presumably ``N == 2``); otherwise they are its
    first two rows (presumably ``M == 2``).  Singular values come from the
    eigenvalues of the 2x2 Gram matrix; its eigenvectors form the 2x2
    factor and the projections of ``x``/``y`` onto them, divided by the
    singular values, form the long factor.  When a singular value is too
    small to divide by, the corresponding long-factor column is replaced
    by a unit vector built from a coordinate axis.
    """
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_R)
    eps = 1.0e-20

    if TALL:
        mask = offs < M
        base = A + pid * M * N
        x = tl.load(base + offs * N, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + offs * N + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = offs < N
        base = A + pid * M * N
        x = tl.load(base + offs, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N + offs, mask=mask, other=0.0).to(tl.float32)

    # Entries of the 2x2 Gram matrix.
    aa = tl.sum(x * x)
    bb = tl.sum(y * y)
    ab = tl.sum(x * y)
    diff = aa - bb
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    # Larger eigenvalue directly; smaller one via det / l0.
    l0 = tl.maximum(0.0, 0.5 * (aa + bb + root))
    det = tl.maximum(0.0, aa * bb - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0 is (ab, l0 - aa); when the off-diagonal term is
    # negligible fall back to the axis of the larger diagonal entry.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bb
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    # Rescale by the larger component before normalizing.  An absolute
    # epsilon under the square root would dominate whenever the Gram
    # entries are small and leave the vector far shorter than unit length.
    # `scale` is > 0: it is 1 for the axis fallback and >= |ab| > eps
    # otherwise.
    scale = tl.maximum(tl.abs(vx0), tl.abs(vy0))
    vx0 = vx0 / scale
    vy0 = vy0 / scale
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    # Second eigenvector: the first one rotated by 90 degrees.
    vx1 = -vy0
    vy1 = vx0

    sbase = S + pid * 2
    tl.store(sbase, s0)
    tl.store(sbase + 1, s1)

    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        ubase = U + pid * M * 2
        u0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        # Zero matrix: use the first coordinate axis as leading vector.
        u0 = tl.where(s0 > eps, u0, basis0)

        u1 = (x * vx1 + y * vy1) * inv_s1
        # Fallback for a tiny second singular value: take the axis less
        # aligned with u0 and remove its component along u0.
        u0_first = tl.sum(tl.where(offs == 0, u0, 0.0))
        anchor = tl.where(tl.abs(u0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * u0)
        fallback_u1 = anchor - dot * u0
        fallback_norm = tl.sum(fallback_u1 * fallback_u1)
        fallback_u1 = fallback_u1 * tl.rsqrt(fallback_norm + eps)
        u1 = tl.where(s1 > s0 * 5.0e-4, u1, fallback_u1)
        tl.store(ubase + offs * 2, u0, mask=mask)
        tl.store(ubase + offs * 2 + 1, u1, mask=mask)

        # 2x2 factor, row-major: [[vx0, vx1], [vy0, vy1]].
        vbase = V + pid * 4
        tl.store(vbase, vx0)
        tl.store(vbase + 1, vx1)
        tl.store(vbase + 2, vy0)
        tl.store(vbase + 3, vy1)
    else:
        ubase = U + pid * 4
        tl.store(ubase, vx0)
        tl.store(ubase + 1, vx1)
        tl.store(ubase + 2, vy0)
        tl.store(ubase + 3, vy1)

        vbase = V + pid * N * 2
        v0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        # Zero matrix: use the first coordinate axis as leading vector.
        v0 = tl.where(s0 > eps, v0, basis0)

        v1 = (x * vx1 + y * vy1) * inv_s1
        # Fallback for a tiny second singular value: take the axis less
        # aligned with v0 and remove its component along v0.
        v0_first = tl.sum(tl.where(offs == 0, v0, 0.0))
        anchor = tl.where(tl.abs(v0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * v0)
        fallback_v1 = anchor - dot * v0
        fallback_norm = tl.sum(fallback_v1 * fallback_v1)
        fallback_v1 = fallback_v1 * tl.rsqrt(fallback_norm + eps)
        v1 = tl.where(s1 > s0 * 5.0e-4, v1, fallback_v1)
        tl.store(vbase + offs * 2, v0, mask=mask)
        tl.store(vbase + offs * 2 + 1, v1, mask=mask)

1420 

1421 

def _rank2_svd(input):
    """Launch the two-vector SVD kernels and return ``(U, S, V)``.

    The input is flattened to ``(batch, m, n)``.  Small matrices in large
    batches (longest side <= 16 and batch >= 16) go to the kernel that
    handles several matrices per program; everything else uses one
    program per matrix.  Outputs are reshaped back to the input's leading
    dimensions with trailing shapes ``(m, 2)``, ``(2,)`` and ``(n, 2)``.
    """
    batch, m, n = _svd_shape(input)
    lead_shape = input.shape[:-2]
    mats = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 2), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, 2), dtype=input.dtype, device=input.device)

    tall = m >= n
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    batched_launch = longest <= 16 and batch >= 16

    with torch_device_fn.device(input.device):
        if not batched_launch:
            _rank2_svd_kernel[(batch,)](
                mats,
                u,
                s,
                v,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
        else:
            # Matrices per program, chosen by matrix size.
            if longest <= 2:
                block_b = 8
            elif longest == 16:
                block_b = 2 if tall else 8
            else:
                block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _rank2_svd_tiny_kernel[grid](
                mats,
                u,
                s,
                v,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    return (
        u.reshape(*lead_shape, m, 2),
        s.reshape(*lead_shape, 2),
        v.reshape(*lead_shape, n, 2),
    )

1468 

1469 

def _rank2_singular_values(input):
    """Launch the two-vector singular-value kernels and return ``S``.

    The input is flattened to ``(batch, m, n)``.  Small matrices in large
    batches (longest side <= 16 and batch >= 16) go to the kernel that
    handles several matrices per program; everything else uses one
    program per matrix.  The result has the input's leading dimensions
    followed by a trailing dimension of 2.
    """
    batch, m, n = _svd_shape(input)
    mats = input.contiguous().reshape(batch, m, n)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)

    tall = m >= n
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    batched_launch = longest <= 16 and batch >= 16

    with torch_device_fn.device(input.device):
        if not batched_launch:
            _rank2_svals_kernel[(batch,)](
                mats,
                s,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
        else:
            # Matrices per program, chosen by matrix size.
            if longest <= 2:
                block_b = 8
            elif longest == 16:
                block_b = 2 if tall else 8
            else:
                block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _rank2_svals_tiny_kernel[grid](
                mats,
                s,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    return s.reshape(*input.shape[:-2], 2)

1506 

1507 

def _small_jacobi_singular_values(input):
    """Launch the small Jacobi singular-value kernel, one program per matrix.

    Allocates a float32 ``(batch, k, rows)`` workspace with
    ``k = min(m, n)`` and ``rows = max(m, n)``, runs 3 sweeps when
    ``k <= 4`` and 5 otherwise, and returns the singular values shaped as
    the input's leading dimensions followed by ``k``.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    long_side = max(m, n)
    mats = input.contiguous().reshape(batch, m, n)
    workspace = torch.empty(
        (batch, short_side, long_side), dtype=torch.float32, device=input.device
    )
    svals = torch.empty((batch, short_side), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(long_side)
    block_k = triton.next_power_of_2(short_side)
    if short_side <= 4:
        sweeps = 3
    else:
        sweeps = 5
    warps = 1 if block_r <= 64 else 4

    with torch_device_fn.device(input.device):
        _small_jacobi_svals_kernel[(batch,)](
            mats,
            workspace,
            svals,
            M=m,
            N=n,
            K=short_side,
            ROWS=long_side,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=sweeps,
            num_warps=warps,
        )
    return svals.reshape(*input.shape[:-2], short_side)

1534 

1535 

def _small_jacobi_svd(input):
    """Launch the small Jacobi SVD kernel, one program per matrix.

    Allocates float32 workspaces of shape ``(batch, k, rows)`` and
    ``(batch, k, k)`` with ``k = min(m, n)`` and ``rows = max(m, n)``, runs
    3 sweeps when ``k <= 4`` and 5 otherwise, and returns ``(U, S, V)``
    reshaped to the input's leading dimensions with trailing shapes
    ``(m, k)``, ``(k,)`` and ``(n, k)``.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    long_side = max(m, n)
    lead_shape = input.shape[:-2]
    mats = input.contiguous().reshape(batch, m, n)
    col_work = torch.empty(
        (batch, short_side, long_side), dtype=torch.float32, device=input.device
    )
    rot_work = torch.empty(
        (batch, short_side, short_side), dtype=torch.float32, device=input.device
    )
    u = torch.empty((batch, m, short_side), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, short_side), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, short_side), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(long_side)
    block_k = triton.next_power_of_2(short_side)
    if short_side <= 4:
        sweeps = 3
    else:
        sweeps = 5
    warps = 1 if block_r <= 64 else 4

    with torch_device_fn.device(input.device):
        _small_jacobi_svd_kernel[(batch,)](
            mats,
            col_work,
            rot_work,
            u,
            s,
            v,
            M=m,
            N=n,
            K=short_side,
            ROWS=long_side,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=sweeps,
            num_warps=warps,
        )

    return (
        u.reshape(*lead_shape, m, short_side),
        s.reshape(*lead_shape, short_side),
        v.reshape(*lead_shape, n, short_side),
    )

1572 

1573 

@libentry()
@triton.jit
def _cyclic_jacobi_init_a_kernel(
    A,
    A_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Copy one working vector of ``A`` into the float32 workspace.

    Grid axis 0 selects the matrix and axis 1 the vector.  With ``TALL``
    the vector is a column of the ``M x N`` matrix, otherwise a row.  It is
    stored contiguously at ``A_WORK[batch, col, :ROWS]``.
    """
    mat = tl.program_id(0)
    vec = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    lane_ok = lane < ROWS
    src_mat = A + mat * M * N
    dst_mat = A_WORK + mat * K * ROWS

    if TALL:
        src = src_mat + lane * N + vec
    else:
        src = src_mat + vec * N + lane
    data = tl.load(src, mask=lane_ok, other=0.0).to(tl.float32)
    tl.store(dst_mat + vec * ROWS + lane, data, mask=lane_ok)

1598 

1599 

@libentry()
@triton.jit
def _cyclic_jacobi_init_kernel(
    A,
    A_WORK,
    V_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Initialize both Jacobi workspaces for one working vector.

    Grid axis 0 selects the matrix and axis 1 the vector.  The vector (a
    column of ``A`` with ``TALL``, otherwise a row) is copied as float32 to
    ``A_WORK[batch, col, :ROWS]`` and ``V_WORK[batch, col, :K]`` is set to
    the matching row of the identity.
    """
    mat = tl.program_id(0)
    vec = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    slot = tl.arange(0, BLOCK_K)
    lane_ok = lane < ROWS
    slot_ok = slot < K
    src_mat = A + mat * M * N
    dst_mat = A_WORK + mat * K * ROWS
    rot_mat = V_WORK + mat * K * K

    if TALL:
        src = src_mat + lane * N + vec
    else:
        src = src_mat + vec * N + lane
    data = tl.load(src, mask=lane_ok, other=0.0).to(tl.float32)
    tl.store(dst_mat + vec * ROWS + lane, data, mask=lane_ok)

    # One-hot entry on the diagonal.
    one_hot = tl.where(slot == vec, 1.0, 0.0)
    tl.store(rot_mat + vec * K + slot, one_hot, mask=slot_ok)

1632 

1633 

@libentry()
@triton.jit
def _cyclic_jacobi_pair_kernel(
    A_WORK,
    V_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Rotate one vector pair of one matrix for schedule step ``STEP``.

    Grid axis 0 selects the matrix and axis 1 the pair slot.  Slot ``i`` is
    paired with slot ``ROUND - 1 - i``; slot 0 always maps to vector 0 and
    the other slots are shifted cyclically by ``STEP``.  Pairs that touch a
    vector index ``>= K`` are fully masked out.  The rotation that
    orthogonalizes the two ``A_WORK`` vectors is also applied to the
    matching ``V_WORK`` vectors.
    """
    mat = tl.program_id(0)
    slot = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    comp = tl.arange(0, BLOCK_K)
    ring = ROUND - 1

    # Map both slots of the pair to vector indices.
    pos_p = slot
    pos_q = ROUND - 1 - slot
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    in_range = (p < K) & (q < K)
    # Work with the pair in ascending index order.
    lo = tl.minimum(p, q)
    hi = tl.maximum(p, q)
    lane_ok = (lane < ROWS) & in_range
    comp_ok = (comp < K) & in_range

    a_mat = A_WORK + mat * K * ROWS
    v_mat = V_WORK + mat * K * K
    a_lo_ptrs = a_mat + lo * ROWS + lane
    a_hi_ptrs = a_mat + hi * ROWS + lane
    ap = tl.load(a_lo_ptrs, mask=lane_ok, other=0.0)
    aq = tl.load(a_hi_ptrs, mask=lane_ok, other=0.0)

    norm_p = tl.sum(ap * ap)
    norm_q = tl.sum(aq * aq)
    cross = tl.sum(ap * aq)
    eps = 1.0e-20
    # Skip pairs that are already orthogonal relative to their norms.
    rotate = tl.abs(cross) > 1.0e-7 * tl.sqrt(norm_p * norm_q + eps)
    denom = tl.where(rotate, cross, 1.0)
    tau = (norm_q - norm_p) / (2.0 * denom)
    sign = tl.where(tau >= 0.0, 1.0, -1.0)
    tan_t = sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_t = tl.rsqrt(1.0 + tan_t * tan_t)
    sin_t = tan_t * cos_t
    # Identity rotation for skipped pairs.
    cos_t = tl.where(rotate, cos_t, 1.0)
    sin_t = tl.where(rotate, sin_t, 0.0)

    tl.store(a_lo_ptrs, cos_t * ap - sin_t * aq, mask=lane_ok)
    tl.store(a_hi_ptrs, sin_t * ap + cos_t * aq, mask=lane_ok)

    v_lo_ptrs = v_mat + lo * K + comp
    v_hi_ptrs = v_mat + hi * K + comp
    vp = tl.load(v_lo_ptrs, mask=comp_ok, other=0.0)
    vq = tl.load(v_hi_ptrs, mask=comp_ok, other=0.0)
    tl.store(v_lo_ptrs, cos_t * vp - sin_t * vq, mask=comp_ok)
    tl.store(v_hi_ptrs, sin_t * vp + cos_t * vq, mask=comp_ok)

1694 

1695 

@triton.jit
def _serial_cyclic_jacobi_pair_step(
    aw_base,
    vw_base,
    step,
    pair,
    K,
    ROUND,
    ring,
    ROWS: tl.constexpr,
    rows,
    cols,
    row_base_mask,
    col_base_mask,
):
    """Rotate the vector pair in slot ``pair`` of schedule step ``step``.

    Slot ``pair`` is matched with slot ``ROUND - 1 - pair``; slot 0 always
    maps to vector 0 and the other slots are shifted cyclically by
    ``step``.  Pairs touching a vector index ``>= K`` are fully masked out.
    The rotation that orthogonalizes the two ``A_WORK`` vectors is applied
    in place to them and to the matching ``V_WORK`` vectors.
    """
    eps = 1.0e-20
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    # Work with the pair in ascending index order.
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = row_base_mask & valid_pair
    col_mask = col_base_mask & valid_pair

    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    abs_gamma = tl.abs(gamma)
    # Skip pairs that are already orthogonal relative to their norms.
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    # Identity rotation for skipped pairs.
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
    vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
    new_vp = c * vp - s_rot * vq
    new_vq = s_rot * vp + c * vq
    tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
    tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)


@libentry()
@triton.jit
def _serial_cyclic_jacobi_kernel(
    A_WORK,
    V_WORK,
    K,
    ROUND,
    ROWS: tl.constexpr,
    SWEEPS,
    TAIL_STEPS,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Run the whole cyclic Jacobi schedule for one matrix in one program.

    Performs ``SWEEPS`` full sweeps (``ROUND - 1`` steps of ``ROUND // 2``
    pair rotations each) followed by the first ``TAIL_STEPS`` steps of one
    more sweep.  Rotations run one after another and update ``A_WORK`` and
    ``V_WORK`` in place.
    """
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_base_mask = rows < ROWS
    col_base_mask = cols < K
    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    ring = ROUND - 1
    half_round = ROUND // 2

    # Full sweeps.
    sweep = 0
    while sweep < SWEEPS:
        step = 0
        while step < ROUND - 1:
            pair = 0
            while pair < half_round:
                _serial_cyclic_jacobi_pair_step(
                    aw_base,
                    vw_base,
                    step,
                    pair,
                    K,
                    ROUND,
                    ring,
                    ROWS,
                    rows,
                    cols,
                    row_base_mask,
                    col_base_mask,
                )
                pair += 1
            step += 1
        sweep += 1

    # Partial trailing sweep.
    step = 0
    while step < TAIL_STEPS:
        pair = 0
        while pair < half_round:
            _serial_cyclic_jacobi_pair_step(
                aw_base,
                vw_base,
                step,
                pair,
                K,
                ROUND,
                ring,
                ROWS,
                rows,
                cols,
                row_base_mask,
                col_base_mask,
            )
            pair += 1
        step += 1

1812 pair += 1 

1813 step += 1 

1814 

1815 

@libentry()
@triton.jit
def _cyclic_jacobi_norm_kernel(
    A_WORK,
    S_WORK,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Write the Euclidean norm of one workspace vector to ``S_WORK``.

    Grid axis 0 selects the matrix and axis 1 the vector; the result goes
    to ``S_WORK[batch, col]``.
    """
    mat = tl.program_id(0)
    vec = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    lane_ok = lane < ROWS
    vec_ptrs = A_WORK + mat * K * ROWS + vec * ROWS + lane
    data = tl.load(vec_ptrs, mask=lane_ok, other=0.0)
    length = tl.sqrt(tl.sum(data * data))
    tl.store(S_WORK + mat * K + vec, length)

1833 

1834 

@libentry()
@triton.jit
def _cyclic_jacobi_finalize_kernel(
    A_WORK,
    V_WORK,
    S_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Scatter one workspace vector into the sorted ``U``/``S``/``V`` outputs.

    Grid axis 0 selects the matrix and axis 1 the workspace vector.  The
    vector's output column is the number of norms in ``S_WORK`` that are
    larger, with ties going to the lower index, so ``S`` ends up in
    descending order.  The normalized ``A_WORK`` vector goes to ``U`` and
    the ``V_WORK`` vector to ``V`` when ``TALL``; the roles are swapped
    otherwise.
    """
    mat = tl.program_id(0)
    vec = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    comp = tl.arange(0, BLOCK_K)
    lane_ok = lane < ROWS
    comp_ok = comp < K
    eps = 1.0e-20

    norms = S_WORK + mat * K
    s_col = tl.load(norms + vec)
    # Count the vectors that sort ahead of this one.
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        s_other = tl.load(norms + other)
        ahead = (s_other > s_col) | ((s_other == s_col) & (other < vec))
        rank += ahead.to(tl.int32)

    long_vec = tl.load(
        A_WORK + mat * K * ROWS + vec * ROWS + lane, mask=lane_ok, other=0.0
    )
    # A (near-)zero norm yields a zero output vector instead of a division.
    inv_norm = tl.where(s_col > eps, 1.0 / s_col, 0.0)
    short_vec = tl.load(V_WORK + mat * K * K + vec * K + comp, mask=comp_ok, other=0.0)
    tl.store(S + mat * K + rank, s_col)

    u_mat = U + mat * M * K
    v_mat = V + mat * N * K
    if TALL:
        tl.store(u_mat + lane * K + rank, long_vec * inv_norm, mask=lane_ok)
        tl.store(v_mat + comp * K + rank, short_vec, mask=comp_ok)
    else:
        tl.store(u_mat + comp * K + rank, short_vec, mask=comp_ok)
        tl.store(v_mat + lane * K + rank, long_vec * inv_norm, mask=lane_ok)

1895 

1896 

@libentry()
@triton.jit
def _blocked_jacobi_pair_svals_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Rotate one vector pair of ``A_WORK`` for schedule step ``STEP``.

    Grid axis 0 selects the matrix and axis 1 the pair slot.  Slot ``i`` is
    paired with slot ``ROUND - 1 - i``; slot 0 always maps to vector 0 and
    the other slots are shifted cyclically by ``STEP``.  Pairs touching a
    vector index ``>= K`` are fully masked out.  Only the working vectors
    are updated; no rotation is recorded.
    """
    mat = tl.program_id(0)
    slot = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Map both slots of the pair to vector indices.
    pos_p = slot
    pos_q = ROUND - 1 - slot
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    in_range = (p < K) & (q < K)
    # Work with the pair in ascending index order.
    lo = tl.minimum(p, q)
    hi = tl.maximum(p, q)
    lane_ok = (lane < ROWS) & in_range

    a_mat = A_WORK + mat * K * ROWS
    lo_ptrs = a_mat + lo * ROWS + lane
    hi_ptrs = a_mat + hi * ROWS + lane
    ap = tl.load(lo_ptrs, mask=lane_ok, other=0.0)
    aq = tl.load(hi_ptrs, mask=lane_ok, other=0.0)

    norm_p = tl.sum(ap * ap)
    norm_q = tl.sum(aq * aq)
    cross = tl.sum(ap * aq)
    eps = 1.0e-20
    # Skip pairs that are already orthogonal relative to their norms.
    rotate = tl.abs(cross) > 1.0e-7 * tl.sqrt(norm_p * norm_q + eps)
    denom = tl.where(rotate, cross, 1.0)
    tau = (norm_q - norm_p) / (2.0 * denom)
    sign = tl.where(tau >= 0.0, 1.0, -1.0)
    tan_t = sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_t = tl.rsqrt(1.0 + tan_t * tan_t)
    sin_t = tan_t * cos_t
    # Identity rotation for skipped or out-of-range pairs.
    apply_rot = rotate & in_range
    cos_t = tl.where(apply_rot, cos_t, 1.0)
    sin_t = tl.where(apply_rot, sin_t, 0.0)

    tl.store(lo_ptrs, cos_t * ap - sin_t * aq, mask=lane_ok)
    tl.store(hi_ptrs, sin_t * ap + cos_t * aq, mask=lane_ok)

1945 

1946 

@libentry()
@triton.jit
def _blocked_jacobi_pair_a_kernel(
    A_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Rotate one vector pair of ``A_WORK`` and record the rotation.

    Grid axis 0 selects the matrix and axis 1 the pair slot.  Slot ``i`` is
    paired with slot ``ROUND - 1 - i``; slot 0 always maps to vector 0 and
    the other slots are shifted cyclically by ``STEP``.  The cosine and
    sine are written to ``ROT_C``/``ROT_S`` at
    ``batch * (ROUND // 2) + pair``; skipped or out-of-range pairs record
    the identity rotation (1, 0).
    """
    mat = tl.program_id(0)
    slot = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Map both slots of the pair to vector indices.
    pos_p = slot
    pos_q = ROUND - 1 - slot
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    in_range = (p < K) & (q < K)
    # Work with the pair in ascending index order.
    lo = tl.minimum(p, q)
    hi = tl.maximum(p, q)
    lane_ok = (lane < ROWS) & in_range

    a_mat = A_WORK + mat * K * ROWS
    lo_ptrs = a_mat + lo * ROWS + lane
    hi_ptrs = a_mat + hi * ROWS + lane
    ap = tl.load(lo_ptrs, mask=lane_ok, other=0.0)
    aq = tl.load(hi_ptrs, mask=lane_ok, other=0.0)

    norm_p = tl.sum(ap * ap)
    norm_q = tl.sum(aq * aq)
    cross = tl.sum(ap * aq)
    eps = 1.0e-20
    # Skip pairs that are already orthogonal relative to their norms.
    rotate = tl.abs(cross) > 1.0e-7 * tl.sqrt(norm_p * norm_q + eps)
    denom = tl.where(rotate, cross, 1.0)
    tau = (norm_q - norm_p) / (2.0 * denom)
    sign = tl.where(tau >= 0.0, 1.0, -1.0)
    tan_t = sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_t = tl.rsqrt(1.0 + tan_t * tan_t)
    sin_t = tan_t * cos_t
    # Identity rotation for skipped or out-of-range pairs.
    apply_rot = rotate & in_range
    cos_t = tl.where(apply_rot, cos_t, 1.0)
    sin_t = tl.where(apply_rot, sin_t, 0.0)

    tl.store(lo_ptrs, cos_t * ap - sin_t * aq, mask=lane_ok)
    tl.store(hi_ptrs, sin_t * ap + cos_t * aq, mask=lane_ok)

    # Record the rotation, one entry per (matrix, pair slot).
    rot_index = mat * (ROUND // 2) + slot
    tl.store(ROT_C + rot_index, cos_t)
    tl.store(ROT_S + rot_index, sin_t)

2001 

2002 

@libentry()
@triton.jit
def _hier_block_jacobi_pair_a_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    K_BLOCKS: tl.constexpr,
    ROUND_BLOCKS: tl.constexpr,
    ROWS: tl.constexpr,
    TILE_B: tl.constexpr,
    TILE_COLS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    LOCAL_SWEEPS: tl.constexpr,
):
    """Orthogonalize the vectors of one block pair inside a single program.

    Grid axis 0 selects the matrix and axis 1 the block-pair slot.  Block
    slots are paired like the vector slots of the other pair kernels (slot
    0 fixed, the rest shifted cyclically by ``STEP``).  The first
    ``TILE_B`` local vectors come from the lower-indexed block and the
    remaining ones from the higher-indexed block.  The tile is loaded
    once, ``LOCAL_SWEEPS`` passes over every local vector pair are applied
    to it, and it is written back.
    """
    mat = tl.program_id(0)
    slot = tl.program_id(1)
    lane = tl.arange(0, BLOCK_R)
    local = tl.arange(0, TILE_COLS)
    ring = ROUND_BLOCKS - 1

    # Map both slots of the pair to block indices.
    pos_p = slot
    pos_q = ROUND_BLOCKS - 1 - slot
    p_block = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q_block = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    in_range = (p_block < K_BLOCKS) & (q_block < K_BLOCKS)
    lo_block = tl.minimum(p_block, q_block)
    hi_block = tl.maximum(p_block, q_block)

    # Global vector index of every local vector in the tile.
    vec_ids = tl.where(
        local < TILE_B,
        lo_block * TILE_B + local,
        hi_block * TILE_B + local - TILE_B,
    )
    lane_ok = lane < ROWS
    vec_ok = (vec_ids < K) & in_range
    tile_ok = vec_ok[:, None] & lane_ok[None, :]
    tile_ptrs = A_WORK + mat * K * ROWS + vec_ids[:, None] * ROWS + lane[None, :]
    tile = tl.load(tile_ptrs, mask=tile_ok, other=0.0).to(tl.float32)
    local_2d = local[:, None]
    eps = 1.0e-20

    for _ in tl.static_range(0, LOCAL_SWEEPS):
        for first in tl.static_range(0, TILE_COLS):
            for second in tl.static_range(first + 1, TILE_COLS):
                # Extract the two vectors from the tile by masked reduction.
                ap = tl.sum(tl.where(local_2d == first, tile, 0.0), axis=0)
                aq = tl.sum(tl.where(local_2d == second, tile, 0.0), axis=0)
                norm_p = tl.sum(ap * ap)
                norm_q = tl.sum(aq * aq)
                cross = tl.sum(ap * aq)
                # Skip pairs already orthogonal relative to their norms.
                rotate = tl.abs(cross) > 1.0e-7 * tl.sqrt(norm_p * norm_q + eps)
                denom = tl.where(rotate, cross, 1.0)
                tau = (norm_q - norm_p) / (2.0 * denom)
                sign = tl.where(tau >= 0.0, 1.0, -1.0)
                tan_t = sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                cos_t = tl.rsqrt(1.0 + tan_t * tan_t)
                sin_t = tan_t * cos_t
                # Identity rotation for skipped or out-of-range pairs.
                apply_rot = rotate & in_range
                cos_t = tl.where(apply_rot, cos_t, 1.0)
                sin_t = tl.where(apply_rot, sin_t, 0.0)

                rot_p = cos_t * ap - sin_t * aq
                rot_q = sin_t * ap + cos_t * aq
                # Write the rotated vectors back into the tile.
                tile = tl.where(local_2d == first, rot_p[None, :], tile)
                tile = tl.where(local_2d == second, rot_q[None, :], tile)

    tl.store(tile_ptrs, tile, mask=tile_ok)

2077 

2078 

@libentry()
@triton.jit
def _blocked_jacobi_apply_v_kernel(
    V_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Apply the stored rotation of one scheduled pair to two rows of V_WORK.

    Program (batch, pair, block) reads the rotation ``(c, s)`` saved in
    ``ROT_C``/``ROT_S`` for this batch and pair, and rotates a ``BLOCK_V``-wide
    slice of the two V_WORK rows selected by the round-robin schedule at
    ``STEP``.
    """
    pid_batch = tl.program_id(0)
    pid_pair = tl.program_id(1)
    pid_block = tl.program_id(2)
    col_offs = pid_block * BLOCK_V + tl.arange(0, BLOCK_V)
    ring = ROUND - 1

    # Round-robin schedule: slot 0 is pinned to index 0, the other slots are
    # shifted around the ring by STEP.
    slot_a = pid_pair
    slot_b = ROUND - 1 - pid_pair
    idx_a = tl.where(slot_a == 0, 0, ((slot_a + ring - STEP - 1) % ring) + 1)
    idx_b = tl.where(slot_b == 0, 0, ((slot_b + ring - STEP - 1) % ring) + 1)
    # An index >= K is the padding slot of an odd-sized problem.
    pair_ok = (idx_a < K) & (idx_b < K)
    lo = tl.minimum(idx_a, idx_b)
    hi = tl.maximum(idx_a, idx_b)
    live = (col_offs < K) & pair_ok

    rot_idx = pid_batch * (ROUND // 2) + pid_pair
    c = tl.load(ROT_C + rot_idx)
    s_rot = tl.load(ROT_S + rot_idx)

    batch_ptr = V_WORK + pid_batch * K * K
    lo_ptrs = batch_ptr + lo * K + col_offs
    hi_ptrs = batch_ptr + hi * K + col_offs
    v_lo = tl.load(lo_ptrs, mask=live, other=0.0)
    v_hi = tl.load(hi_ptrs, mask=live, other=0.0)
    tl.store(lo_ptrs, c * v_lo - s_rot * v_hi, mask=live)
    tl.store(hi_ptrs, s_rot * v_lo + c * v_hi, mask=live)

2116 

2117 

@libentry()
@triton.jit
def _blocked_jacobi_rank_kernel(
    S_WORK,
    RANKS,
    S,
    K,
):
    """Rank one value of S_WORK in descending order and scatter it into S.

    Program (batch, col) counts how many entries of the batch row are larger
    than ``S_WORK[batch, col]`` (ties are broken by the lower index), stores
    that count in ``RANKS[batch, col]`` and writes the value to
    ``S[batch, rank]``.
    """
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    row_ptr = S_WORK + pid_batch * K
    mine = tl.load(row_ptr + pid_col)
    rank = tl.full((), 0, dtype=tl.int32)
    other = 0
    while other < K:
        theirs = tl.load(row_ptr + other)
        # Equal values keep their original relative order.
        ahead = (theirs > mine) | ((theirs == mine) & (other < pid_col))
        rank += ahead.to(tl.int32)
        other += 1
    tl.store(RANKS + pid_batch * K + pid_col, rank)
    tl.store(S + pid_batch * K + rank, mine)

2137 

2138 

@libentry()
@triton.jit
def _blocked_jacobi_store_projected_kernel(
    A_WORK,
    S_WORK,
    RANKS,
    PROJECTED,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    OUT_ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Normalise one A_WORK column and write it to its ranked output column.

    Program (batch, col, block) divides a ``BLOCK_R`` slice of
    ``A_WORK[batch, col, :]`` by ``S_WORK[batch, col]`` and stores it at
    ``PROJECTED[batch, rows, RANKS[batch, col]]``. When the scale is at most
    ``1e-20`` the unit vector with a one at row ``rank`` is stored instead.
    """
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    pid_block = tl.program_id(2)
    row_offs = pid_block * BLOCK_R + tl.arange(0, BLOCK_R)
    in_range = row_offs < OUT_ROWS

    rank = tl.load(RANKS + pid_batch * K + pid_col)
    sigma = tl.load(S_WORK + pid_batch * K + pid_col)
    eps = 1.0e-20

    column = tl.load(
        A_WORK + pid_batch * K * ROWS + pid_col * ROWS + row_offs,
        mask=in_range,
        other=0.0,
    )
    column = column / tl.maximum(sigma, eps)
    # A (numerically) zero column has no direction; substitute e_rank.
    unit = tl.where(row_offs == rank, 1.0, 0.0)
    column = tl.where(sigma <= eps, unit, column)

    out_ptrs = PROJECTED + pid_batch * OUT_ROWS * K + row_offs * K + rank
    tl.store(out_ptrs, column, mask=in_range)

2172 

2173 

@libentry()
@triton.jit
def _blocked_jacobi_store_basis_kernel(
    V_WORK,
    RANKS,
    BASIS,
    K: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Copy one row of V_WORK into its ranked column of BASIS.

    Program (batch, col, block) reads a ``BLOCK_V`` slice of
    ``V_WORK[batch, col, :]`` and writes it to
    ``BASIS[batch, :, RANKS[batch, col]]``.
    """
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    pid_block = tl.program_id(2)
    offs = pid_block * BLOCK_V + tl.arange(0, BLOCK_V)
    in_range = offs < K
    matrix_offset = pid_batch * K * K

    rank = tl.load(RANKS + pid_batch * K + pid_col)
    source = tl.load(
        V_WORK + matrix_offset + pid_col * K + offs,
        mask=in_range,
        other=0.0,
    )
    # Transposed, permuted write: source row -> destination column `rank`.
    tl.store(BASIS + matrix_offset + offs * K + rank, source, mask=in_range)

2199 

2200 

@libentry()
@triton.jit
def _thin_reorthogonalize_kernel(
    Q,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Re-orthonormalise the K columns of one (ROWS, K) matrix in place.

    One program per batch entry. Each column is projected against all earlier
    (already rewritten) columns twice, replaced by the unit vector e_j if its
    norm collapsed to at most ``1e-20``, projected once more, normalised and
    stored back.
    """
    pid_batch = tl.program_id(0)
    row_offs = tl.arange(0, BLOCK_R)
    in_range = row_offs < ROWS
    col_ptrs = Q + pid_batch * ROWS * K + row_offs * K
    eps = 1.0e-20

    for j in tl.static_range(0, K):
        vec = tl.load(col_ptrs + j, mask=in_range, other=0.0).to(tl.float32)

        # First Gram-Schmidt pass against columns 0..j-1.
        for prev in tl.static_range(0, j):
            earlier = tl.load(col_ptrs + prev, mask=in_range, other=0.0).to(
                tl.float32
            )
            coeff = tl.sum(vec * earlier)
            vec = vec - coeff * earlier

        # Second pass removes what rounding left behind.
        for prev in tl.static_range(0, j):
            earlier = tl.load(col_ptrs + prev, mask=in_range, other=0.0).to(
                tl.float32
            )
            coeff = tl.sum(vec * earlier)
            vec = vec - coeff * earlier

        # A vanished column is replaced by e_j before the final pass.
        norm = tl.sqrt(tl.sum(vec * vec))
        fallback = tl.where(row_offs == j, 1.0, 0.0)
        vec = tl.where(norm > eps, vec, fallback)

        for prev in tl.static_range(0, j):
            earlier = tl.load(col_ptrs + prev, mask=in_range, other=0.0).to(
                tl.float32
            )
            coeff = tl.sum(vec * earlier)
            vec = vec - coeff * earlier

        norm = tl.sqrt(tl.sum(vec * vec))
        vec = vec / tl.maximum(norm, eps)
        tl.store(col_ptrs + j, vec, mask=in_range)

2249 

2250 

def _cyclic_jacobi_svd(input):
    """Compute a thin SVD with the cyclic one-sided Jacobi kernels.

    Returns a ``(U, S, V)`` tuple shaped ``(*lead, m, k)``, ``(*lead, k)`` and
    ``(*lead, n, k)`` where ``lead`` are the leading batch dimensions of
    ``input`` and ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    k = min(m, n)
    rows = max(m, n)
    lead = input.shape[:-2]
    dev = input.device

    a = input.contiguous().reshape(batch, m, n)
    # float32 scratch buffers used by the Jacobi kernels.
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    row_warps = 1 if block_r <= 64 else 4
    if k == 32:
        sweeps, tail_steps = 6, 20
    elif k < 32:
        sweeps, tail_steps = 8, 0
    else:
        sweeps, tail_steps = 12, 0
    # The pair schedule needs an even number of slots.
    round_size = k + 1 if k % 2 else k
    serial_medium = 16 <= k <= 32 and rows <= 64 and batch <= 32

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=row_warps,
        )
        if serial_medium:
            # One program per matrix runs every sweep inside a single launch.
            _serial_cyclic_jacobi_kernel[(batch,)](
                a_work,
                v_work,
                K=k,
                ROUND=round_size,
                ROWS=rows,
                SWEEPS=sweeps,
                TAIL_STEPS=tail_steps,
                BLOCK_R=block_r,
                BLOCK_K=block_k,
                num_warps=1,
            )
        else:
            pair_grid = (batch, round_size // 2)
            for _ in range(sweeps):
                for step in range(round_size - 1):
                    _cyclic_jacobi_pair_kernel[pair_grid](
                        a_work,
                        v_work,
                        step,
                        K=k,
                        ROUND=round_size,
                        ROWS=rows,
                        BLOCK_R=block_r,
                        BLOCK_K=block_k,
                        num_warps=row_warps,
                    )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _cyclic_jacobi_finalize_kernel[(batch, k)](
            a_work,
            v_work,
            s_work,
            u,
            s,
            v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=row_warps,
        )

    # NOTE(review): the whole refinement stage below is assumed to sit under
    # this small-shape guard -- confirm against the original indentation.
    if k <= 17 and rows <= 64:
        # Re-orthonormalise the k-by-k factor first.
        small_factor, small_rows = (v, n) if tall else (u, m)
        with torch_device_fn.device(dev):
            _thin_reorthogonalize_kernel[(batch,)](
                small_factor,
                ROWS=small_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(small_rows),
                num_warps=1,
            )

        # Rebuild the other factor by projecting the input onto it.
        if tall:
            u = _triton_bmm(a, v, (batch, m, k))
            projected = u
            projected_rows = m
        else:
            a_t = a.transpose(1, 2).contiguous()
            v = _triton_bmm(a_t, u, (batch, n, k))
            projected = v
            projected_rows = n

        projected_block = triton.next_power_of_2(projected_rows)
        with torch_device_fn.device(dev):
            _normalize_projection_kernel[(batch, k)](
                projected,
                s,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=projected_block,
                num_warps=1,
            )
            _complete_zero_projection_kernel[(batch, k)](
                projected,
                s,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=projected_block,
                num_warps=1,
            )
            if batch > 1 and k <= 16:
                _thin_reorthogonalize_kernel[(batch,)](
                    projected,
                    ROWS=projected_rows,
                    K=k,
                    BLOCK_R=projected_block,
                    num_warps=1,
                )

    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

2391 

2392 

def _projected_jacobi_svd(input):
    """Compute a thin SVD by Jacobi on A only, then project for the other factor.

    The Jacobi sweeps produce the singular values and the factor with
    ``max(m, n)`` rows; the remaining factor is obtained with a batched matmul
    against the input and normalised by ``_normalize_projection_kernel``.
    Returns ``(U, S, V)`` reshaped to the leading batch dimensions of ``input``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    k = min(m, n)
    rows = max(m, n)
    lead = input.shape[:-2]
    dev = input.device

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    projected_rows = m if tall else n
    projected = torch.empty((batch, projected_rows, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 10 if k >= 128 else 8
    # The pair schedule needs an even number of slots.
    round_size = k + 1 if k % 2 else k
    half_round = round_size // 2
    # The pair kernel writes each rotation here; nothing in this path reads it.
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=dev)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        store_grid = (batch, k, triton.cdiv(projected_rows, block_r))
        _blocked_jacobi_store_projected_kernel[store_grid](
            a_work,
            s_work,
            ranks,
            projected,
            K=k,
            ROWS=rows,
            OUT_ROWS=projected_rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )

    # Recover the remaining factor from the input and the projected one.
    if tall:
        u = projected
        v = _triton_bmm(a.transpose(1, 2).contiguous(), u, (batch, n, k))
        normalized, normalized_rows = v, n
    else:
        v = projected
        u = _triton_bmm(a, v, (batch, m, k))
        normalized, normalized_rows = u, m

    with torch_device_fn.device(dev):
        _normalize_projection_kernel[(batch, k)](
            normalized,
            s,
            ROWS=normalized_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(normalized_rows),
            num_warps=1 if normalized_rows <= 64 else 4,
        )

    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

2495 

2496 

def _blocked_jacobi_svd(input):
    """Compute a thin SVD with per-step pair kernels for A and for V.

    Each scheduled step first rotates column pairs of the A work buffer
    (recording the rotations), then applies the same rotations to the V work
    buffer. Singular values are ranked afterwards and both factors are
    written in ranked column order.
    Returns ``(U, S, V)`` reshaped to the leading batch dimensions of ``input``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    k = min(m, n)
    rows = max(m, n)
    lead = input.shape[:-2]
    dev = input.device

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    block_v = 64
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 14 if k > 256 else 10
    # The pair schedule needs an even number of slots.
    round_size = k + 1 if k % 2 else k
    half_round = round_size // 2
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=dev)

    # The factor with max(m, n) rows comes from the A work buffer, the
    # k-by-k factor from the V work buffer.
    if tall:
        projected_out, projected_rows = u, m
        basis_out, basis_rows = v, n
    else:
        projected_out, projected_rows = v, n
        basis_out, basis_rows = u, m

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=row_warps,
        )
        apply_grid = (batch, half_round, triton.cdiv(k, block_v))
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=row_warps,
                )
                _blocked_jacobi_apply_v_kernel[apply_grid](
                    v_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    BLOCK_V=block_v,
                    num_warps=1,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        projected_grid = (batch, k, triton.cdiv(projected_rows, block_r))
        _blocked_jacobi_store_projected_kernel[projected_grid](
            a_work,
            s_work,
            ranks,
            projected_out,
            K=k,
            ROWS=rows,
            OUT_ROWS=projected_rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        basis_grid = (batch, k, triton.cdiv(basis_rows, block_v))
        _blocked_jacobi_store_basis_kernel[basis_grid](
            v_work,
            ranks,
            basis_out,
            K=k,
            BLOCK_V=block_v,
            num_warps=1,
        )

    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

2618 

2619 

def _blocked_jacobi_square_project_svd(input):
    """Compute a thin SVD: Jacobi on A for U, then V from a projection.

    The A work buffer is always initialised with ``TALL=True``. After the
    sweeps the ranked, normalised columns are stored into ``U``; ``V`` is the
    batched product of the transposed input with ``U``, post-processed by
    ``_renorm_projection_update_s_kernel`` together with ``S``.
    Returns ``(U, S, V)`` reshaped to the leading batch dimensions of ``input``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    lead = input.shape[:-2]
    dev = input.device

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 16 if k > 256 else 12
    # The pair schedule needs an even number of slots.
    round_size = k + 1 if k % 2 else k
    half_round = round_size // 2
    # The pair kernel writes each rotation here; nothing in this path reads it.
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=dev)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        store_grid = (batch, k, triton.cdiv(m, block_r))
        _blocked_jacobi_store_projected_kernel[store_grid](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )

    # V is recovered by projecting the transposed input onto U.
    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(dev):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

2706 

2707 

def _hier_block_jacobi_square_project_svd(input):
    """Compute a thin SVD of square matrices with block-pair Jacobi sweeps.

    Columns are grouped into blocks of ``tile_b`` (4 when ``k == 512``,
    otherwise 2); each kernel launch rotates all column pairs inside one pair
    of blocks. ``U`` is taken from the ranked, normalised work columns and
    ``V`` from the batched product of the transposed input with ``U``,
    post-processed by ``_renorm_projection_update_s_kernel`` together with
    ``S``.

    Returns ``(U, S, V)`` reshaped to the leading batch dimensions of
    ``input``. Non-square inputs, or ``k`` not divisible by the tile size,
    are passed to ``_unsupported_svd``, which raises ``NotImplementedError``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)

    # Reject unsupported shapes before copying the input or allocating any
    # device buffers.
    tile_b = 4 if k == 512 else 2
    if m != n or k % tile_b != 0:
        return _unsupported_svd(
            input,
            True,
            True,
            "Hierarchical block Jacobi supports square matrices with "
            "k divisible by two.",
        )

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    block_count = k // tile_b
    # The block-pair schedule needs an even number of slots.
    round_blocks = block_count if block_count % 2 == 0 else block_count + 1
    half_round_blocks = round_blocks // 2
    sweep_count = 10 if k <= 256 else 12
    tile_cols = tile_b * 2
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        for _ in range(sweep_count):
            for step in range(round_blocks - 1):
                _hier_block_jacobi_pair_a_kernel[(batch, half_round_blocks)](
                    a_work,
                    step,
                    K=k,
                    K_BLOCKS=block_count,
                    ROUND_BLOCKS=round_blocks,
                    ROWS=rows,
                    TILE_B=tile_b,
                    TILE_COLS=tile_cols,
                    BLOCK_R=block_r,
                    LOCAL_SWEEPS=1,
                    num_warps=4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )

    # V is recovered by projecting the transposed input onto U.
    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

2806 

2807 

def _blocked_jacobi_singular_values(input):
    """Compute only the singular values with one-sided Jacobi sweeps.

    Returns a tensor shaped ``(*lead, k)`` with ``k = min(m, n)``; each value
    is written at the position of its descending rank by
    ``_blocked_jacobi_rank_kernel``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    dev = input.device

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 14 if k > 256 else 10
    # The pair schedule needs an even number of slots.
    round_size = k + 1 if k % 2 else k
    pair_grid = (batch, round_size // 2)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_svals_kernel[pair_grid](
                    a_work,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )

    return s.reshape(*input.shape[:-2], k)

2862 

2863 

def _small4_square_svd(input):
    """Run the dedicated 4x4 SVD kernel over a batch of matrices.

    Sixteen matrices are handled per program with four sweeps. Returns
    ``(U, S, V)`` shaped ``(*lead, 4, 4)``, ``(*lead, 4)`` and ``(*lead, 4, 4)``.
    """
    batch, m, n = _svd_shape(input)
    lead = input.shape[:-2]
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, 4, 4), dtype=input.dtype, device=dev)
    s = torch.empty((batch, 4), dtype=input.dtype, device=dev)
    v = torch.empty((batch, 4, 4), dtype=input.dtype, device=dev)

    block_b = 16
    grid = (triton.cdiv(batch, block_b),)
    with torch_device_fn.device(dev):
        _small4_square_svd_kernel[grid](
            a, u, s, v, BATCH=batch, BLOCK_B=block_b, SWEEPS=4, num_warps=1
        )

    return (
        u.reshape(*lead, 4, 4),
        s.reshape(*lead, 4),
        v.reshape(*lead, 4, 4),
    )

2880 

2881 

@libentry()
@triton.jit
def _rank1_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """SVD of a single-column (TALL) or single-row matrix, one per program.

    The singular value is the Euclidean norm of the vector, the long factor is
    the vector divided by that norm, and the 1x1 factor is ``1.0``.

    The vector is divided by its actual norm, so the long factor has unit
    length and ``U * S * V^T`` reproduces the input even for very small
    norms. When the norm is at most ``1e-20`` the first unit vector is
    written instead of a zero vector, so the factor stays normalised.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_R)
    # Norms at or below this are treated as zero.
    tiny = 1.0e-20
    a_base = A + pid * M * N
    norm_sq = tl.full((), 0.0, dtype=tl.float32)

    if TALL:
        # Column vector: consecutive elements are N apart in memory.
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        degenerate = norm <= tiny
        denom = tl.where(degenerate, 1.0, norm)
        tl.store(S + pid, norm)
        tl.store(V + pid, 1.0)

        u_base = U + pid * M
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(rows == 0, 1.0, 0.0)
            out = tl.where(degenerate, unit, vals / denom)
            tl.store(u_base + rows, out, mask=mask)
    else:
        # Row vector: elements are contiguous.
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        degenerate = norm <= tiny
        denom = tl.where(degenerate, 1.0, norm)
        tl.store(S + pid, norm)
        tl.store(U + pid, 1.0)

        v_base = V + pid * N
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(cols == 0, 1.0, 0.0)
            out = tl.where(degenerate, unit, vals / denom)
            tl.store(v_base + cols, out, mask=mask)

2936 

2937 

def _rank1_svd(input):
    """Launch the rank-1 SVD kernel for single-column or single-row matrices.

    Returns ``(U, S, V)`` shaped ``(*lead, m, 1)``, ``(*lead, 1)`` and
    ``(*lead, n, 1)``. With an empty batch no kernel is launched and the
    (empty) outputs are returned as allocated.
    """
    batch, m, n = _svd_shape(input)
    lead = input.shape[:-2]
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 1), dtype=input.dtype, device=dev)
    s = torch.empty((batch, 1), dtype=input.dtype, device=dev)
    v = torch.empty((batch, n, 1), dtype=input.dtype, device=dev)

    if batch != 0:
        longest = max(m, n)
        # Vectors longer than the cap are processed in chunks by the kernel.
        if longest <= _RANK1_BLOCK_R_MAX:
            block_r = triton.next_power_of_2(longest)
        else:
            block_r = _RANK1_BLOCK_R_MAX
        with torch_device_fn.device(dev):
            _rank1_svd_kernel[(batch,)](
                a,
                u,
                s,
                v,
                m,
                n,
                TALL=n == 1,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )

    return (
        u.reshape(*lead, m, 1),
        s.reshape(*lead, 1),
        v.reshape(*lead, n, 1),
    )

2966 

2967 

@libentry()
@triton.jit
def _complex_to_real_embedding_kernel(
    A_RI,
    R,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Expand an (M, N) two-component matrix into a (2M, 2N) real matrix.

    With ``c0``/``c1`` denoting the two interleaved components of ``A_RI``
    (presumably real and imaginary parts -- confirm against the caller), the
    output written to ``R`` has the block layout ``[[c0, -c1], [c1, c0]]``.
    One program handles one batch entry.
    """
    pid_batch = tl.program_id(0)
    flat = tl.arange(0, BLOCK_SIZE)
    in_range = flat < 4 * M * N

    # Position inside the (2M, 2N) output.
    out_row = flat // (2 * N)
    out_col = flat - out_row * (2 * N)
    # Matching element of the (M, N) source.
    src_row = tl.where(out_row < M, out_row, out_row - M)
    src_col = tl.where(out_col < N, out_col, out_col - N)

    top_right = (out_row < M) & (out_col >= N)
    bottom_left = (out_row >= M) & (out_col < N)
    # Off-diagonal quadrants read the second component, the others the first.
    comp = tl.where(top_right | bottom_left, 1, 0)
    src_ptrs = A_RI + pid_batch * M * N * 2 + (src_row * N + src_col) * 2 + comp
    vals = tl.load(src_ptrs, mask=in_range, other=0.0)

    # Only the top-right quadrant is negated.
    sign = tl.where(top_right, -1.0, 1.0)
    tl.store(R + pid_batch * 4 * M * N + flat, vals * sign, mask=in_range)

2994 

2995 

@libentry()
@triton.jit
def _complex_svd_pick_factor_kernel(
    REAL_FACTOR,
    OUT_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather a (ROWS, K) two-component factor from a (2*ROWS, REAL_K) one.

    For output element ``(row, col)`` the first component is read from
    ``REAL_FACTOR[row, 2 * col]`` and the second from
    ``REAL_FACTOR[ROWS + row, 2 * col]``; both are written interleaved to
    ``OUT_RI``. One program handles one batch entry.
    """
    pid_batch = tl.program_id(0)
    flat = tl.arange(0, BLOCK_SIZE)
    in_range = flat < ROWS * K
    out_row = flat // K
    out_col = flat % K

    # Every second column of the real factor is used.
    picked_col = out_col * 2
    src_base = REAL_FACTOR + pid_batch * (2 * ROWS) * REAL_K
    upper = tl.load(
        src_base + out_row * REAL_K + picked_col,
        mask=in_range,
        other=0.0,
    )
    lower = tl.load(
        src_base + (ROWS + out_row) * REAL_K + picked_col,
        mask=in_range,
        other=0.0,
    )

    dst = OUT_RI + pid_batch * ROWS * K * 2 + flat * 2
    tl.store(dst, upper, mask=in_range)
    tl.store(dst + 1, lower, mask=in_range)

3025 

3026 

@libentry()
@triton.jit
def _complex_svd_pick_s_kernel(
    S_REAL,
    S,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Collapse REAL_K values per batch entry into K by averaging pairs.

    ``S[batch, c]`` receives the mean of ``S_REAL[batch, 2c]`` and
    ``S_REAL[batch, 2c + 1]``.
    """
    pid_batch = tl.program_id(0)
    out_idx = tl.arange(0, BLOCK_K)
    in_range = out_idx < K
    pair_ptrs = S_REAL + pid_batch * REAL_K + out_idx * 2
    first = tl.load(pair_ptrs, mask=in_range, other=0.0)
    second = tl.load(pair_ptrs + 1, mask=in_range, other=0.0)
    tl.store(S + pid_batch * K + out_idx, 0.5 * (first + second), mask=in_range)

3043 

3044 

@libentry()
@triton.jit
def _complex_svd_pick_orthonormal_v_kernel(
    V_REAL,
    V_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Pick K two-component columns from a real factor and orthonormalise them.

    Component ``r`` of output column ``c`` is read from
    ``V_REAL[row, 2c]`` and component ``i`` from ``V_REAL[ROWS + row, 2c]``.
    The columns are then orthonormalised left to right with Gram-Schmidt
    using the inner product ``sum(conj(prev) * cur)``, and written interleaved
    to ``V_RI``. One program handles one batch entry.

    The orthonormalisation runs over all ``K`` columns (``K`` is a
    compile-time constant), so no column is left unprocessed for ``K > 16``
    and no unused iterations are unrolled for small ``K``.
    """
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    tile_mask = row_mask[:, None] & col_mask[None, :]
    # Every second column of the real factor is used.
    src_cols = cols * 2
    base = V_REAL + batch * (2 * ROWS) * REAL_K
    vr = tl.load(
        base + rows[:, None] * REAL_K + src_cols[None, :],
        mask=tile_mask,
        other=0.0,
    )
    vi = tl.load(
        base + (ROWS + rows[:, None]) * REAL_K + src_cols[None, :],
        mask=tile_mask,
        other=0.0,
    )

    for c in tl.static_range(0, K):
        # Extract column c of both components.
        cur_r = tl.sum(tl.where(cols[None, :] == c, vr, 0.0), axis=1)
        cur_i = tl.sum(tl.where(cols[None, :] == c, vi, 0.0), axis=1)
        for p in tl.static_range(0, c):
            prev_r = tl.sum(tl.where(cols[None, :] == p, vr, 0.0), axis=1)
            prev_i = tl.sum(tl.where(cols[None, :] == p, vi, 0.0), axis=1)
            # coeff = <prev, cur> = sum(conj(prev) * cur)
            coeff_r = tl.sum(
                tl.where(row_mask, prev_r * cur_r + prev_i * cur_i, 0.0), axis=0
            )
            coeff_i = tl.sum(
                tl.where(row_mask, prev_r * cur_i - prev_i * cur_r, 0.0), axis=0
            )
            # cur -= prev * coeff (complex multiply).
            cur_r -= prev_r * coeff_r - prev_i * coeff_i
            cur_i -= prev_r * coeff_i + prev_i * coeff_r
        norm_sq = tl.sum(tl.where(row_mask, cur_r * cur_r + cur_i * cur_i, 0.0), axis=0)
        # The floor keeps the reciprocal finite for a vanished column.
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur_r *= inv_norm
        cur_i *= inv_norm
        vr = tl.where(cols[None, :] == c, cur_r[:, None], vr)
        vi = tl.where(cols[None, :] == c, cur_i[:, None], vi)

    out_base = V_RI + batch * ROWS * K * 2
    offsets = rows[:, None] * K + cols[None, :]
    tl.store(out_base + offsets * 2, vr, mask=tile_mask)
    tl.store(out_base + offsets * 2 + 1, vi, mask=tile_mask)

3101 

3102 

@libentry()
@triton.jit
def _complex_svd_project_u_kernel(
    A_RI,
    V_RI,
    S,
    U_RI,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``U = (A @ V) / S`` on interleaved two-component buffers.

    Each lane owns one ``(row, col)`` element of the ``(M, K)`` output and
    accumulates the complex product over ``N`` terms (no conjugation). The
    result is scaled by ``1 / S[col]``, or by zero when ``S[col]`` is not
    above ``1e-20``. One program handles one batch entry.
    """
    pid_batch = tl.program_id(0)
    flat = tl.arange(0, BLOCK_SIZE)
    in_range = flat < M * K
    out_row = flat // K
    out_col = flat % K

    a_batch = A_RI + pid_batch * M * N * 2
    v_batch = V_RI + pid_batch * N * K * 2
    sum_r = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    sum_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for j in tl.static_range(0, N):
        a_ptrs = a_batch + (out_row * N + j) * 2
        v_ptrs = v_batch + (j * K + out_col) * 2
        a_r = tl.load(a_ptrs, mask=in_range, other=0.0)
        a_i = tl.load(a_ptrs + 1, mask=in_range, other=0.0)
        v_r = tl.load(v_ptrs, mask=in_range, other=0.0)
        v_i = tl.load(v_ptrs + 1, mask=in_range, other=0.0)
        # (a_r + i a_i) * (v_r + i v_i)
        sum_r += a_r * v_r - a_i * v_i
        sum_i += a_r * v_i + a_i * v_r

    sigma = tl.load(S + pid_batch * K + out_col, mask=in_range, other=1.0)
    # Columns with a vanishing singular value are zeroed rather than divided.
    scale = tl.where(sigma > 1.0e-20, 1.0 / sigma, 0.0)
    dst = U_RI + pid_batch * M * K * 2 + flat * 2
    tl.store(dst, sum_r * scale, mask=in_range)
    tl.store(dst + 1, sum_i * scale, mask=in_range)

3138 

3139 

def _complex_svd_via_real_embedding(input):
    """Compute a thin complex SVD by decomposing a real-valued embedding.

    The contiguous input is reinterpreted as its real dtype and expanded by
    ``_complex_to_real_embedding_kernel`` into a float32 matrix of shape
    ``(batch, 2 * m, 2 * n)``. That matrix is decomposed with ``svd`` and the
    real factors are folded back into complex ones:
    ``_complex_svd_pick_s_kernel`` fills the ``k``-wide singular values from
    the ``2 * k`` real ones, ``_complex_svd_pick_orthonormal_v_kernel`` fills
    V, and ``_complex_svd_project_u_kernel`` derives U from the input, V and S.

    Returns:
        A tuple ``(u, s, v)`` where ``u`` is ``(*batch_dims, m, k)`` and ``v``
        is ``(*batch_dims, n, k)`` in the input dtype, and ``s`` is float32
        with shape ``(*batch_dims, k)``; ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    batch_dims = input.shape[:-2]
    dev = input.device

    contiguous_input = input.contiguous()
    a_ri = tensor_wrapper.TypedPtr.reinterpret_tensor(
        contiguous_input, contiguous_input.dtype.to_real()
    )
    embedded = torch.empty((batch, 2 * m, 2 * n), dtype=torch.float32, device=dev)
    embed_block = triton.next_power_of_2(4 * m * n)
    with torch_device_fn.device(dev):
        _complex_to_real_embedding_kernel[(batch,)](
            a_ri,
            embedded,
            M=m,
            N=n,
            BLOCK_SIZE=embed_block,
            num_warps=1,
        )

    # Only S and V of the real decomposition are consumed; U is rebuilt below.
    real_factors = svd(embedded, some=True, compute_uv=True)
    s_real = real_factors.S
    v_real = real_factors.V

    s = torch.empty((batch, k), dtype=torch.float32, device=dev)
    u = torch.empty((*batch_dims, m, k), dtype=input.dtype, device=dev)
    v = torch.empty((*batch_dims, n, k), dtype=input.dtype, device=dev)
    u_ri = tensor_wrapper.TypedPtr.reinterpret_tensor(u, u.dtype.to_real())
    v_ri = tensor_wrapper.TypedPtr.reinterpret_tensor(v, v.dtype.to_real())

    block_k = triton.next_power_of_2(k)
    with torch_device_fn.device(dev):
        _complex_svd_pick_s_kernel[(batch,)](
            s_real,
            s,
            K=k,
            REAL_K=2 * k,
            BLOCK_K=block_k,
            num_warps=1,
        )
        _complex_svd_pick_orthonormal_v_kernel[(batch,)](
            v_real,
            v_ri,
            ROWS=n,
            K=k,
            REAL_K=2 * k,
            BLOCK_ROWS=triton.next_power_of_2(n),
            BLOCK_K=block_k,
            num_warps=1,
        )
        _complex_svd_project_u_kernel[(batch,)](
            a_ri,
            v_ri,
            s,
            u_ri,
            M=m,
            N=n,
            K=k,
            BLOCK_SIZE=triton.next_power_of_2(m * k),
            num_warps=1,
        )
    return u, s.reshape(*batch_dims, k), v

3199 

3200 

def _complex_svd_cpu_fallback(input, some=True, compute_uv=True):
    """Run ``torch.svd`` on a CPU copy of ``input`` and move the factors back.

    Returns:
        A tuple ``(u, s, v)`` of tensors placed on ``input.device``.
    """
    target_device = input.device
    host_factors = torch.svd(input.cpu(), some=some, compute_uv=compute_uv)
    return tuple(factor.to(target_device) for factor in host_factors)

3208 

3209 

def _gram_svd(input):
    """Placeholder for the Gram-matrix path; always reports unsupported input.

    Delegates to ``_unsupported_svd`` with ``some=True`` and
    ``compute_uv=True``.
    """
    return _unsupported_svd(input, some=True, compute_uv=True)

3212 

3213 

# Finalizes a 16-column Gram-based SVD from a precomputed eigen-decomposition.
# Grid: axis 0 is the batch entry, axis 1 is a block of BLOCK_R rows.
#   EVALS  - 16 eigenvalues per batch entry (read at batch * 16 + column).
#   EVECS  - 16 x 16 eigenvectors per batch entry, addressed through the
#            explicit batch/row/column strides passed as constexprs.
#   TALL   - when true the projected factor is written to U and the
#            eigenvector basis to V; otherwise the roles are swapped.
# Columns are read in reversed order (15 - col), presumably because the
# eigen-solver returns ascending eigenvalues while singular values are
# reported in descending order -- TODO confirm against the caller.
@libentry()
@triton.jit
def _gram16_finalize_kernel(
    A,
    EVALS,
    EVECS,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    EVECS_BATCH_STRIDE: tl.constexpr,
    EVECS_ROW_STRIDE: tl.constexpr,
    EVECS_COL_STRIDE: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    row_block = tl.program_id(1)
    rows = row_block * BLOCK_R + tl.arange(0, BLOCK_R)
    cols = tl.arange(0, 16)
    # Reversed column index used for every eigenvalue/eigenvector read.
    src_cols = 15 - cols
    row_mask = rows < ROWS
    eps = 1.0e-20

    # Singular values are the square roots of the eigenvalues clamped at zero;
    # the reciprocal is zeroed for values at or below eps.
    vals = tl.load(EVALS + batch * 16 + src_cols)
    s_vals = tl.sqrt(tl.maximum(vals, 0.0))
    inv_s = tl.where(s_vals > eps, 1.0 / s_vals, 0.0)

    acc = tl.zeros((BLOCK_R, 16), dtype=tl.float32)
    a_base = A + batch * M * N
    e_base = EVECS + batch * EVECS_BATCH_STRIDE
    # Rank-1 accumulation over the 16-long shared dimension.
    for k in tl.static_range(0, 16):
        eig = tl.load(e_base + k * EVECS_ROW_STRIDE + src_cols * EVECS_COL_STRIDE)
        if TALL:
            # Element (rows, k) of A with row stride N.
            a_vals = tl.load(
                a_base + rows * N + k,
                mask=row_mask,
                other=0.0,
            )
        else:
            # Element (k, rows) of A with row stride N, i.e. a transposed read.
            a_vals = tl.load(
                a_base + k * N + rows,
                mask=row_mask,
                other=0.0,
            )
        acc += a_vals[:, None] * eig[None, :]

    # Scale each projected column by the reciprocal singular value.
    projected = acc * inv_s[None, :]
    if TALL:
        tl.store(
            U + batch * M * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )
    else:
        tl.store(
            V + batch * N * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )

    # S and the 16 x 16 basis are per-batch outputs, so only the first row
    # block of each batch entry stores them.
    head_mask = row_block == 0
    tl.store(S + batch * 16 + cols, s_vals, mask=head_mask)

    basis_rows = tl.arange(0, 16)
    basis_cols = tl.arange(0, 16)
    basis_src_cols = 15 - basis_cols
    basis = tl.load(
        e_base
        + basis_rows[:, None] * EVECS_ROW_STRIDE
        + basis_src_cols[None, :] * EVECS_COL_STRIDE
    )
    if TALL:
        tl.store(
            V + batch * N * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )
    else:
        tl.store(
            U + batch * M * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )

3300 

3301 

def _gram16_svd(input):
    """Placeholder for the 16-column Gram path; always reports unsupported.

    Delegates to ``_unsupported_svd`` with ``some=True`` and
    ``compute_uv=True``.
    """
    return _unsupported_svd(input, some=True, compute_uv=True)

3304 

3305 

def _large_native_svd(input):
    """Dispatch a large matrix to the first eligible blocked Jacobi solver.

    Eligibility is probed with ``some=True`` and ``compute_uv=True`` in the
    order: hierarchical block square-project, blocked square-project, blocked
    Jacobi. When none applies the call goes to ``_bidiagonal_qr_dqds_svd``.
    """
    if _can_use_hier_block_square_project_kernel(input, True, True):
        solver = _hier_block_jacobi_square_project_svd
    elif _can_use_blocked_square_project_kernel(input, True, True):
        solver = _blocked_jacobi_square_project_svd
    elif _can_use_blocked_jacobi_kernel(input, True, True):
        solver = _blocked_jacobi_svd
    else:
        solver = _bidiagonal_qr_dqds_svd
    return solver(input)

3314 

3315 

def _bidiagonal_qr_dqds_svd(input):
    """Report that the bidiagonalization plus QR/DQDS path is not available.

    Delegates to ``_unsupported_svd`` with ``some=True``,
    ``compute_uv=True`` and an explanatory reason.
    """
    reason = (
        "The k > 512 blocked-bidiagonalization plus QR/DQDS path is reserved "
        "for the next native large-matrix solver stage."
    )
    return _unsupported_svd(input, True, True, reason)

3324 

3325 

def _empty_svd_result(input, some=True, compute_uv=True):
    """Build the SVD factors for an input with a zero-sized matrix dimension.

    Args:
        input: Tensor whose trailing two dimensions describe the matrix.
        some: When true (and ``compute_uv`` is true) U and V have
            ``k = min(m, n)`` columns; otherwise they are square.
        compute_uv: When false, U and V are returned as square zero matrices.

    Returns:
        A tuple ``(u, s, v)`` with the dtype and device of ``input``.

    Note:
        Square factors can be non-empty even though the input is empty (for
        example an ``(m, 0)`` input yields an ``(m, m)`` U). They are therefore
        initialised rather than left as uninitialised memory: zeros when
        ``compute_uv`` is false, and the identity when full matrices are
        requested, matching what ``torch.svd`` returns in those cases.
    """
    _, m, n = _svd_shape(input)
    k = min(m, n)
    thin = compute_uv and some
    u_cols = k if thin else m
    v_cols = k if thin else n
    batch_dims = input.shape[:-2]
    options = {"dtype": input.dtype, "device": input.device}

    u = torch.zeros((*batch_dims, m, u_cols), **options)
    s = torch.empty((*batch_dims, k), **options)
    v = torch.zeros((*batch_dims, n, v_cols), **options)
    if compute_uv and not some:
        # Any orthogonal matrix is a valid factor of an empty matrix; use the
        # identity. Skip zero-sized factors to avoid launching empty device ops.
        for factor in (u, v):
            if factor.numel() != 0:
                factor.diagonal(dim1=-2, dim2=-1).fill_(1)
    return u, s, v

3339 

3340 

# Extends a thin factor (ROWS x THIN_COLS) to ROWS x FULL_COLS for one batch
# entry per program. Columns beyond THIN_COLS are seeded with the matching
# identity columns, then every column is orthogonalised against all earlier
# columns and normalised (Gram-Schmidt). Both THIN and FULL are addressed
# row-major with row strides THIN_COLS and FULL_COLS respectively.
@libentry()
@triton.jit
def _complete_svd_factor_kernel(
    THIN,
    FULL,
    ROWS: tl.constexpr,
    THIN_COLS: tl.constexpr,
    FULL_COLS: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_COLS)
    row_mask = rows < ROWS
    col_mask = cols < FULL_COLS
    vals = tl.load(
        THIN + batch * ROWS * THIN_COLS + rows[:, None] * THIN_COLS + cols[None, :],
        mask=row_mask[:, None] & (cols[None, :] < THIN_COLS),
        other=0.0,
    )
    # Seed the columns that the thin factor does not provide with unit vectors.
    identity = tl.where(rows[:, None] == cols[None, :], 1.0, 0.0)
    vals = tl.where(cols[None, :] < THIN_COLS, vals, identity)

    # The loop bound is the literal 64, so at most 64 columns are processed;
    # iterations with c >= FULL_COLS leave vals untouched through cur_mask.
    for c in tl.static_range(0, 64):
        cur_mask = c < FULL_COLS
        # Extract column c as a vector by masking and summing along columns.
        cur = tl.sum(tl.where(cols[None, :] == c, vals, 0.0), axis=1)
        for p in tl.static_range(0, c):
            # Remove the component of column c along the already-final column p.
            prev = tl.sum(tl.where(cols[None, :] == p, vals, 0.0), axis=1)
            coeff = tl.sum(tl.where(row_mask, prev * cur, 0.0), axis=0)
            cur -= prev * coeff
        norm_sq = tl.sum(tl.where(row_mask, cur * cur, 0.0), axis=0)
        # The squared norm is clamped at 1e-20 so rsqrt never sees zero.
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur *= inv_norm
        vals = tl.where((cols[None, :] == c) & cur_mask, cur[:, None], vals)

    out_base = FULL + batch * ROWS * FULL_COLS
    offsets = rows[:, None] * FULL_COLS + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets, vals, mask=mask)

3381 

3382 

def _low_precision_svd_via_float32(input, some=True, compute_uv=True):
    """Decompose ``input`` in float32 and cast the factors back.

    Returns:
        A tuple ``(u, s, v)`` converted to the dtype of ``input``.
    """
    original_dtype = input.dtype
    promoted = input.to(torch.float32)
    factors = svd(promoted, some=some, compute_uv=compute_uv)
    return tuple(factor.to(original_dtype) for factor in factors)

3386 

3387 

def _some_false_svd_via_thin(input):
    """Produce full (square) U and V by completing the thin SVD factors.

    The thin decomposition comes from ``svd(input, some=True,
    compute_uv=True)``; ``_complete_svd_factor_kernel`` then expands U to
    ``m x m`` and V to ``n x n``, one program per batch entry.

    Returns:
        A tuple ``(u, s, v)`` where ``s`` is the thin result's singular values.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    thin = svd(input, some=True, compute_uv=True)

    batch_dims = input.shape[:-2]
    full_u = torch.empty((*batch_dims, m, m), dtype=input.dtype, device=input.device)
    full_v = torch.empty((*batch_dims, n, n), dtype=input.dtype, device=input.device)

    # U is completed first, then V; both launches share the same parameters
    # apart from the factor and its side length.
    jobs = ((thin.U, full_u, m), (thin.V, full_v, n))
    with torch_device_fn.device(input.device):
        for thin_factor, full_factor, side in jobs:
            block = triton.next_power_of_2(side)
            _complete_svd_factor_kernel[(batch,)](
                thin_factor,
                full_factor,
                ROWS=side,
                THIN_COLS=k,
                FULL_COLS=side,
                BLOCK_ROWS=block,
                BLOCK_COLS=block,
                num_warps=4,
            )
    return full_u, thin.S, full_v

3416 

3417 

def _compute_uv_false_result(input, s):
    """Package singular values with placeholder U and V for compute_uv=False.

    Args:
        input: Tensor whose trailing two dimensions are ``(m, n)``.
        s: Singular values to return unchanged.

    Returns:
        A tuple ``(u, s, v)`` where ``u`` is ``(*batch_dims, m, m)`` and ``v``
        is ``(*batch_dims, n, n)``, both zero-filled with the dtype and device
        of ``input``.

    Note:
        ``torch.svd`` documents U and V as zero-filled when ``compute_uv`` is
        false, so the placeholders are zeroed instead of being left as
        uninitialised memory.
    """
    _, m, n = _svd_shape(input)
    batch_dims = input.shape[:-2]
    u = torch.zeros((*batch_dims, m, m), dtype=input.dtype, device=input.device)
    v = torch.zeros((*batch_dims, n, n), dtype=input.dtype, device=input.device)
    return u, s, v

3423 

3424 

def _singular_values_only(input):
    """Compute singular values only, choosing a solver from the matrix shape.

    With ``k = min(m, n)`` and ``largest = max(m, n)``:
      * ``k == 2`` and ``largest <= _RANK2_BLOCK_R_MAX`` uses the rank-2 path;
      * ``largest <= 1024`` uses the small Jacobi path for ``k <= 16`` and the
        blocked Jacobi path for ``k <= 512``;
      * anything else is reported through ``_unsupported_svd``.
    """
    _, m, n = _svd_shape(input)
    k = min(m, n)
    largest = max(m, n)
    if k == 2 and largest <= _RANK2_BLOCK_R_MAX:
        return _rank2_singular_values(input)
    if largest <= 1024:
        if k <= 16:
            return _small_jacobi_singular_values(input)
        if k <= 512:
            return _blocked_jacobi_singular_values(input)
    return _unsupported_svd(input, True, False)

3436 

3437 

def _should_use_gram16(batch, m, n):
    """Return whether the 16-column Gram path applies.

    True when ``batch >= 16``, the smaller matrix dimension is exactly 16 and
    the larger one does not exceed 1024.
    """
    if batch < 16:
        return False
    smaller, larger = sorted((m, n))
    return smaller == 16 and larger <= 1024

3440 

3441 

def _should_use_gram(batch, m, n):
    """Return whether the Gram-matrix path is selected for this shape.

    The path is chosen when any of the following holds:
      * the smaller matrix dimension is at most 32;
      * the matrix is square with side at most 256 and ``batch <= 4``;
      * the matrix is exactly 1024 x 1024.

    Every other shape returns False. The former explicit
    ``batch >= 128`` / ``k <= 64`` check also returned False and was
    immediately followed by an unconditional ``return False``, so it had no
    effect and is omitted.
    """
    if min(m, n) <= 32:
        return True
    if batch <= 4 and m == n and m <= 256:
        return True
    return (m, n) == (1024, 1024)

3454 

3455 

def svd(input, some=True, compute_uv=True):
    """Singular value decomposition entry point for the Sunrise backend.

    Dispatch order:
      1. Small complex64 "ptpu" matrices (both sides <= 16, thin, with U/V)
         are decomposed on the CPU and moved back.
      2. Low-precision inputs are promoted to float32 and re-dispatched.
      3. Float32 inputs with a zero-sized matrix dimension get an empty result.
      4. Singular-values-only requests use the dedicated value solvers.
      5. Full matrices (``some=False``) with both sides <= 64 are completed
         from the thin decomposition.
      6. Remaining thin float32 inputs walk the native kernel chain
         (rank-1, rank-2, 4x4, Gram/Jacobi variants, blocked solvers).

    Args:
        input: Tensor whose trailing two dimensions form the matrices.
        some: Request thin factors when true.
        compute_uv: Request U and V in addition to the singular values.

    Returns:
        ``SVDResult(U, S, V)``.

    Raises:
        NotImplementedError: Raised through ``_unsupported_svd`` when no
            native path covers the request, or when a native path fails with
            a ``RuntimeError`` (the original error text is included).
    """
    logger.debug("GEMS SVD")
    if (
        input.device.type == "ptpu"
        and input.dtype == torch.complex64
        and some
        and compute_uv
        and input.dim() >= 2
        and 0 not in input.shape[-2:]
        and max(input.shape[-2:]) <= 16
    ):
        return SVDResult(*_complex_svd_cpu_fallback(input, some, compute_uv))
    if _is_low_precision_cuda_matrix(input):
        return SVDResult(*_low_precision_svd_via_float32(input, some, compute_uv))
    if _is_float32_cuda_matrix(input) and 0 in input.shape[-2:]:
        return SVDResult(*_empty_svd_result(input, some, compute_uv))
    if _can_use_singular_values_only(input, some, compute_uv):
        return SVDResult(*_compute_uv_false_result(input, _singular_values_only(input)))
    if (
        _is_float32_cuda_matrix(input)
        and not some
        and compute_uv
        and max(input.shape[-2:]) <= 64
    ):
        return SVDResult(*_some_false_svd_via_thin(input))
    if not _is_float32_cuda_matrix(input) or not some:
        return SVDResult(*_unsupported_svd(input, some, compute_uv))
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    try:
        if k == 1:
            return SVDResult(*_rank1_svd(input))
        if k == 2 and max(m, n) <= _RANK2_BLOCK_R_MAX:
            return SVDResult(*_rank2_svd(input))
        if k == 4 and m == 4 and n == 4 and batch >= 16:
            return SVDResult(*_small4_square_svd(input))
        if _can_use_tall_wide_gram_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_gram_jacobi_svd(input))
        # Batched 16-column problems skip the small Jacobi kernel so a later
        # solver in the chain can take them.
        use_batched_cyclic16 = k == 16 and batch >= 8 and max(m, n) <= 64
        if (
            _can_use_small_jacobi_kernel(input, some, compute_uv)
            and not use_batched_cyclic16
        ):
            return SVDResult(*_small_jacobi_svd(input))
        if _can_use_tsqr_cholesky_kernel(input, some, compute_uv):
            return SVDResult(*_tsqr_cholesky_svd(input))
        if _can_use_projected_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_projected_jacobi_svd(input))
        if _can_use_cyclic_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_cyclic_jacobi_svd(input))
        if _can_use_hier_block_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_hier_block_jacobi_square_project_svd(input))
        if _can_use_blocked_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_square_project_svd(input))
        if _can_use_blocked_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_svd(input))
        if _should_use_gram16(batch, m, n):
            return SVDResult(*_gram16_svd(input))
        if _should_use_gram(batch, m, n):
            return SVDResult(*_gram_svd(input))
        return SVDResult(*_large_native_svd(input))
    except RuntimeError as exc:
        # Keep the NotImplementedError contract, but carry the real failure in
        # the message so a kernel error is not mistaken for a shape limitation.
        return SVDResult(
            *_unsupported_svd(
                input,
                some,
                compute_uv,
                reason=(
                    "The native kernel path failed with "
                    f"{type(exc).__name__}: {exc}"
                ),
            )
        )