Coverage for src/flag_gems/ops/svd.py: 21%

1838 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from collections import namedtuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import libentry 

10 

logger = logging.getLogger(__name__)

# Named (U, S, V) triple; presumably the result type of the public `svd`
# entry point, which is not visible in this chunk -- TODO confirm.
SVDResult = namedtuple("SVDResult", "U S V")

# Upper limits (batch, k) used by `_should_guard_gram_spectrum`.
_GRAM_CONDITION_GUARD_MAX_BATCH = 16
_GRAM_CONDITION_GUARD_MAX_K = 32
# NOTE(review): not referenced in the visible part of the file; presumably an
# eigenvalue-ratio threshold for the Gram conditioning guard -- confirm.
_GRAM_CONDITION_EIGEN_RATIO = 1.0e-8

# Shape limits for `_can_use_tall_wide_gram_jacobi_kernel`. The K limit also
# gates the final re-orthogonalization launch in `_gram_jacobi_svd`.
_GRAM_TALL_WIDE_MAX_K = 32
_GRAM_TALL_WIDE_MAX_ROWS = 1024

# Largest long-side length for the rank-1 / rank-2 paths. Only the rank-2
# value is consulted by a visible predicate (`_can_use_rank2_kernel`).
_RANK1_BLOCK_R_MAX = 1024
_RANK2_BLOCK_R_MAX = 2048

# TSQR/Cholesky limits. `_can_use_tsqr_cholesky_kernel` currently returns
# False unconditionally, so these are not consulted in the visible code.
_TSQR_CHOLESKY_MAX_BATCH = 32
_TSQR_CHOLESKY_MAX_K = 128
_TSQR_CHOLESKY_MAX_ROWS = 1024

25 

26 

def _unsupported_svd(input, some=True, compute_uv=True, reason=None):
    """Raise ``NotImplementedError`` describing an input the native SVD rejects.

    Args:
        input: the tensor that could not be dispatched.
        some: the caller's ``some`` flag, echoed in the message.
        compute_uv: the caller's ``compute_uv`` flag, echoed in the message.
        reason: optional extra sentence appended to the message.

    Raises:
        NotImplementedError: always.
    """
    batch, m, n = _svd_shape(input)
    detail = f" {reason}" if reason is not None else ""
    message = (
        "FlagGems native SVD currently supports float32 CUDA matrices with "
        "some=True, compute_uv=True, non-empty inputs, and native Triton "
        f"rank/Jacobi shape coverage; got batch={batch}, m={m}, n={n}, "
        f"dtype={input.dtype}, device={input.device}, some={some}, "
        f"compute_uv={compute_uv}.{detail}"
    )
    raise NotImplementedError(message)

37 

38 

def _is_iluvatar_backend():
    """Return True when the FlagGems runtime reports the ``iluvatar`` vendor."""
    vendor = device.vendor_name
    return vendor == "iluvatar"

41 

42 

def _svd_shape(input):
    """Fold a tensor's shape into ``(batch, m, n)``.

    The last two dimensions are the matrix rows ``m`` and columns ``n``. All
    leading dimensions are multiplied into ``batch`` (``1`` when there are
    none). Tensors with fewer than two dimensions give ``(0, 0, 0)``.
    """
    if input.dim() >= 2:
        *leading, m, n = input.shape
        batch = 1
        for extent in leading:
            batch *= extent
        return batch, m, n
    return 0, 0, 0

52 

53 

def _should_guard_gram_spectrum(batch, k):
    """Return True when both ``batch`` and ``k`` are within the Gram-guard limits."""
    if not batch <= _GRAM_CONDITION_GUARD_MAX_BATCH:
        return False
    return k <= _GRAM_CONDITION_GUARD_MAX_K

56 

57 

def _is_float32_cuda_matrix(input):
    """Return True for a CUDA ``float32`` tensor with at least two dimensions."""
    if not input.is_cuda:
        return False
    if input.dtype != torch.float32:
        return False
    return input.dim() >= 2

60 

61 

def _is_low_precision_cuda_matrix(input):
    """Return True for a CUDA fp16/bf16 tensor with at least two dimensions."""
    if not input.is_cuda:
        return False
    low_precision = (torch.float16, torch.bfloat16)
    if input.dtype not in low_precision:
        return False
    return input.dim() >= 2

68 

69 

def _can_use_rank1_kernel(input, some=True, compute_uv=True):
    """Report whether the rank-1 native path applies (``min(m, n) == 1``).

    Also requires a float32 CUDA matrix and truthy ``some``/``compute_uv``.
    NOTE(review): unlike the rank-2 predicate, no upper bound on
    ``max(m, n)`` is checked (``_RANK1_BLOCK_R_MAX`` is not consulted here).
    Confirm that the rank-1 kernel handles arbitrary lengths.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return min(m, n) == 1

73 

74 

def _can_use_rank2_kernel(input, some=True, compute_uv=True):
    """Report whether the rank-2 native path applies.

    Requires a float32 CUDA matrix, truthy ``some``/``compute_uv``,
    ``min(m, n) == 2`` and ``max(m, n) <= _RANK2_BLOCK_R_MAX``.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    return short_side == 2 and long_side <= _RANK2_BLOCK_R_MAX

84 

85 

def _can_use_2x2_kernel(input):
    """Report whether ``input`` is a rank-2-eligible stack of exactly 2x2 matrices."""
    eligible = _can_use_rank2_kernel(input, True, True)
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return m == 2 and n == 2

89 

90 

def _can_use_4x4_kernel(input, some=True, compute_uv=True):
    """Report whether ``input`` is a float32 CUDA stack of 4x4 matrices for full SVD."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return m == 4 and n == 4

94 

95 

def _can_use_small_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the small one-sided Jacobi kernel applies.

    Requires a float32 CUDA matrix, truthy ``some``/``compute_uv``, a
    non-Iluvatar backend, ``min(m, n) <= 16`` and ``max(m, n) <= 1024``.
    """
    eligible = (
        _is_float32_cuda_matrix(input)
        and some
        and compute_uv
        and not _is_iluvatar_backend()
    )
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return min(m, n) <= 16 and max(m, n) <= 1024

106 

107 

def _can_use_cyclic_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the cyclic Jacobi path applies (``16 <= k <= 64``, long side <= 1024)."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    k = min(m, n)
    return 16 <= k <= 64 and max(m, n) <= 1024

118 

119 

def _can_use_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the Gram/Jacobi path applies (``16 <= k <= 32``, long side <= 64)."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    k = min(m, n)
    return 16 <= k <= 32 and max(m, n) <= 64

130 

131 

def _can_use_tall_wide_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the large-batch tall/wide Gram/Jacobi path applies.

    Requires a float32 CUDA matrix, truthy ``some``/``compute_uv``,
    ``batch >= 128``, ``16 <= k <= _GRAM_TALL_WIDE_MAX_K``, and a long side
    no greater than ``_GRAM_TALL_WIDE_MAX_ROWS`` that is at least ``2 * k``.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    return (
        batch >= 128
        and 16 <= k <= _GRAM_TALL_WIDE_MAX_K
        and rows <= _GRAM_TALL_WIDE_MAX_ROWS
        and rows >= 2 * k
    )

145 

146 

def _can_use_tsqr_cholesky_kernel(input, some=True, compute_uv=True):
    """Report whether the TSQR/Cholesky path may be used -- currently never.

    The path stays disabled until input-dependent TSQR safety can be checked
    by a native device-side guard before dispatch.
    """
    tsqr_enabled = False
    return tsqr_enabled

150 

151 

def _can_use_blocked_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the blocked Jacobi path applies (``64 < k <= 512``, long side <= 1024)."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    k = min(m, n)
    return 64 < k <= 512 and max(m, n) <= 1024

162 

163 

def _can_use_blocked_square_project_kernel(input, some=True, compute_uv=True):
    """Report whether the blocked square projection path applies.

    Requires a single (``batch == 1``) square float32 CUDA matrix with
    ``128 <= k <= 512`` and truthy ``some``/``compute_uv``.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    return batch == 1 and m == n and 128 <= k <= 512

175 

176 

def _can_use_hier_block_square_project_kernel(input, some=True, compute_uv=True):
    """Report whether the hierarchical block square projection path applies.

    Requires at most two square float32 CUDA matrices of size 256 or 512 and
    truthy ``some``/``compute_uv``.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    return batch <= 2 and m == n and k in (256, 512)

188 

189 

def _can_use_projected_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the projected Jacobi path applies.

    Requires ``4 <= batch <= 32``, ``k == 64``, a long side of at most 128, a
    float32 CUDA matrix and truthy ``some``/``compute_uv``.
    """
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    return 4 <= batch <= 32 and k == 64 and max(m, n) <= 128

201 

202 

def _can_use_singular_values_only(input, some=True, compute_uv=False):
    """Report whether the singular-values-only path applies.

    Requires a float32 CUDA matrix with ``compute_uv`` falsy, ``k <= 512`` and
    a long side of at most 1024. ``some`` is accepted but not consulted.
    """
    eligible = _is_float32_cuda_matrix(input) and not compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return min(m, n) <= 512 and max(m, n) <= 1024

212 

213 

# One-sided (Hestenes-style) Jacobi SVD, one matrix per program (pid = batch
# index). Working buffers: A_WORK holds K vectors of length ROWS
# (aw_base + j * ROWS); V_WORK holds K accumulated rotation vectors of length K
# (vw_base + j * K), initialised to the identity. Outputs use row-major
# U[batch, M, K], S[batch, K], V[batch, N, K] with singular values in
# descending order (ties broken by original column index).
# Assumes the caller passes K = min(M, N), ROWS = max(M, N), TALL = (M >= N)
# -- TODO confirm at the call site (not visible in this chunk).
@libentry()
@triton.jit
def _small_jacobi_svd_kernel(
    A,
    A_WORK,
    V_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS
    vw_base = V_WORK + pid * K * K

    # Copy the K short-side vectors into A_WORK: columns of A when TALL,
    # rows of A otherwise. V_WORK starts as the identity.
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)
        ident_col = tl.where(cols == j, 1.0, 0.0)
        tl.store(vw_base + j * K + cols, ident_col, mask=col_mask)

    # Fixed number of cyclic sweeps over all (p, q) pairs; fully unrolled
    # because K and SWEEPS are constexpr.
    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip the rotation (c = 1, s = 0) when the pair is already
                # orthogonal to relative tolerance 1e-7.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                # safe_gamma avoids dividing by a (near) zero gamma on the
                # inactive branch; that result is discarded by the where().
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

                # Apply the same rotation to the accumulated basis.
                vp = tl.load(vw_base + p * K + cols, mask=col_mask, other=0.0)
                vq = tl.load(vw_base + q * K + cols, mask=col_mask, other=0.0)
                new_vp = c * vp - s_rot * vq
                new_vq = s_rot * vp + c * vq
                tl.store(vw_base + p * K + cols, new_vp, mask=col_mask)
                tl.store(vw_base + q * K + cols, new_vq, mask=col_mask)

    # Singular values are the norms of the rotated vectors.
    s_idx = tl.arange(0, BLOCK_K)
    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # ranks[j] = number of entries that sort before j in descending order
    # (larger value, or equal value with smaller index).
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    # Scatter each vector to column `rank` of the outputs. The normalised
    # rotated vector becomes U (TALL) or V (wide); the accumulated basis
    # becomes the other factor.
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        # NOTE(review): for norm <= eps the stored column is all zeros, so the
        # corresponding U/V column is not unit-length in rank-deficient cases.
        inv_norm = tl.where(norm > eps, 1.0 / norm, 0.0)
        tl.store(S + pid * K + rank, norm)

        basis = tl.load(vw_base + j * K + cols, mask=col_mask, other=0.0)
        if TALL:
            tl.store(U + pid * M * K + rows * K + rank, col * inv_norm, mask=row_mask)
            tl.store(V + pid * N * K + cols * K + rank, basis, mask=col_mask)
        else:
            tl.store(U + pid * M * K + cols * K + rank, basis, mask=col_mask)
            tl.store(V + pid * N * K + rows * K + rank, col * inv_norm, mask=row_mask)

316 

317 

# Singular-values-only variant of `_small_jacobi_svd_kernel`: the same
# one-sided Jacobi sweeps over A_WORK, but no rotation basis is accumulated.
# Only S[batch, K] is written, in descending order (ties broken by index).
# Assumes K = min(M, N), ROWS = max(M, N), TALL = (M >= N) from the caller
# -- TODO confirm.
@libentry()
@triton.jit
def _small_jacobi_svals_kernel(
    A,
    A_WORK,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    s_idx = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS

    # Copy the K short-side vectors (columns when TALL, rows otherwise).
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Relative orthogonality test; inactive pairs get the identity.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # Descending rank of each value, with ties broken by index.
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    for j in tl.static_range(0, K):
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        tl.store(S + pid * K + rank, norm)

396 

def _can_use_streaming_jacobi_kernel(input, some=True, compute_uv=True):
    """Report whether the streaming Jacobi path applies (``16 < k <= 64``, long side <= 1024)."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return 16 < min(m, n) <= 64 and max(m, n) <= 1024

406 

407 

def _can_use_gram_kernel(input, some=True, compute_uv=True):
    """Report whether the generic Gram path applies (short side at most 1024)."""
    eligible = _is_float32_cuda_matrix(input) and some and compute_uv
    if not eligible:
        return eligible
    _, m, n = _svd_shape(input)
    return min(m, n) <= 1024

411 

412 

# Batched matmul C[b] = A[b] @ B[b] with float32 accumulation and TF32
# disabled. Grid: axis 0 enumerates (tile_m, tile_n) output tiles in row-major
# tile order; axis 1 is the batch index. A and B are addressed through explicit
# strides, but C is written with dense (batch, M, N) row-major indexing, so the
# output buffer must be contiguous.
@libentry()
@triton.jit
def _triton_bmm_kernel(
    A,
    B,
    C,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    tile = tl.program_id(0)
    batch = tl.program_id(1)
    # Decode the flat tile id into (tile_m, tile_n).
    tiles_n = tl.cdiv(N, BLOCK_N)
    tile_m = tile // tiles_n
    tile_n = tile - tile_m * tiles_n

    offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    a_base = A + batch * stride_ab
    b_base = B + batch * stride_bb
    for k_start in range(0, K, BLOCK_K):
        k = k_start + offs_k
        # Out-of-range rows/cols/k are masked to 0 so they add nothing to acc.
        a = tl.load(
            a_base + offs_m[:, None] * stride_am + k[None, :] * stride_ak,
            mask=(offs_m[:, None] < M) & (k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            b_base + k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=(k[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        C + batch * M * N + offs_m[:, None] * N + offs_n[None, :],
        acc,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

465 

def _triton_bmm(left, right, out_shape):
    """Multiply two stacks of matrices with ``_triton_bmm_kernel``.

    Computes ``left[b] @ right[b]`` for every batch index ``b``. Accumulation
    is float32 and the kernel disables TF32. The ``(batch, m, n)`` product is
    reshaped to ``out_shape``.

    Args:
        left: 3-D tensor of shape ``(batch, m, k)``. Its strides are passed to
            the kernel, so it does not have to be contiguous.
        right: 3-D tensor of shape ``(batch, k, n)``; its strides are passed
            through as well.
        out_shape: shape given to ``reshape`` on the ``(batch, m, n)`` result.
            It must describe the same number of elements.

    Returns:
        A new tensor with ``left``'s dtype and device.

    Raises:
        ValueError: if the batch or inner (``k``) dimensions disagree.
    """
    batch, m, k = left.shape
    right_batch, right_k, n = right.shape
    # Raise explicitly instead of using ``assert``. Assertions are stripped under
    # ``python -O``, and a mismatch would then make the kernel read the wrong
    # (or out-of-bounds) memory from ``right``.
    if batch != right_batch:
        raise ValueError(f"Batch dim mismatch: left={batch}, right={right_batch}")
    if k != right_k:
        raise ValueError(f"K dim mismatch: left={k}, right={right_k}")

    # The kernel writes C with dense (batch, m, n) indexing, so the output must
    # be a freshly allocated contiguous tensor.
    out = torch.empty((batch, m, n), dtype=left.dtype, device=left.device)
    block_m = 16 if m <= 16 else 32
    block_n = 16 if n <= 16 else 32
    block_k = 32
    grid = (triton.cdiv(m, block_m) * triton.cdiv(n, block_n), batch)
    with torch_device_fn.device(left.device):
        _triton_bmm_kernel[grid](
            left,
            right,
            out,
            left.stride(0),
            left.stride(1),
            left.stride(2),
            right.stride(0),
            right.stride(1),
            right.stride(2),
            M=m,
            N=n,
            K=k,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            num_warps=1 if block_m == 16 and block_n == 16 else 4,
        )
    return out.reshape(out_shape)

496 

497 

# Tiled Gram-matrix build. GRAM[batch] (K x K, row-major) = A^T A when TALL,
# A A^T otherwise, where A[batch] is a contiguous row-major M x N matrix.
# Grid: (cdiv(K, BLOCK_I), cdiv(K, BLOCK_J), batch). Each program produces one
# BLOCK_I x BLOCK_J output tile, streaming over the long dimension (ROWS) in
# chunks of BLOCK_R with float32 accumulation and TF32 disabled.
@libentry()
@triton.jit
def _gram_build_tiled_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_I: tl.constexpr,
    BLOCK_J: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    tile_i = tl.program_id(0)
    tile_j = tl.program_id(1)
    batch = tl.program_id(2)
    offs_i = tile_i * BLOCK_I + tl.arange(0, BLOCK_I)
    offs_j = tile_j * BLOCK_J + tl.arange(0, BLOCK_J)
    rows = tl.arange(0, BLOCK_R)
    i_mask = offs_i < K
    j_mask = offs_j < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS
        if TALL:
            # lhs[i, r] = A[r, i] (a slice of A^T), rhs[r, j] = A[r, j].
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + offs_i[:, None],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + offs_j[None, :],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r], rhs[r, j] = A[j, r] (a slice of A^T).
            lhs = tl.load(
                a_base + offs_i[:, None] * N + chunk_rows[None, :],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + offs_j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + offs_i[:, None] * K + offs_j[None, :],
        acc,
        mask=i_mask[:, None] & j_mask[None, :],
    )

555 

556 

# Row-by-row upper Cholesky factorization, one program per batch element:
# writes R (row-major K x K, entries left of the diagonal stored as 0) with
# R^T R ~= GRAM[batch]. Pivots are clamped to sqrt(tol), with
# tol = max(1e-8 * max|diag(GRAM)|, 1e-20), so the factorization always
# completes. STATUS[batch] is 0 on success and 1 if any updated diagonal was
# NaN, non-finite or <= tol, or if any produced R entry was NaN/non-finite.
@libentry()
@triton.jit
def _cholesky_upper_kernel(
    GRAM,
    R,
    STATUS,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    cols = tl.arange(0, BLOCK_K)
    col_mask = cols < K
    base_g = GRAM + batch * K * K
    base_r = R + batch * K * K

    diag_vals = tl.load(
        base_g + cols * K + cols,
        mask=col_mask,
        other=0.0,
    ).to(tl.float32)
    max_diag = tl.max(tl.abs(diag_vals), axis=0)
    # Relative pivot tolerance, floored at 1e-20.
    tol = tl.maximum(max_diag * 1.0e-8, 1.0e-20)
    status = tl.full((), 0, dtype=tl.int32)
    # Largest finite float32 value; |x| >= this is treated as non-finite.
    finite_limit = 3.4028234663852886e38

    j = 0
    while j < K:
        # Only columns >= j of row j are computed (upper triangle).
        row_mask = col_mask & (cols >= j)
        gram_row = tl.load(
            base_g + j * K + cols,
            mask=row_mask,
            other=0.0,
        ).to(tl.float32)
        diag = tl.load(base_g + j * K + j).to(tl.float32)

        # Subtract the contributions of the rows already factored:
        # G[j, c] - sum_{p<j} R[p, j] * R[p, c].
        p = 0
        while p < j:
            r_pj = tl.load(base_r + p * K + j).to(tl.float32)
            r_pcols = tl.load(
                base_r + p * K + cols,
                mask=row_mask,
                other=0.0,
            ).to(tl.float32)
            gram_row -= r_pj * r_pcols
            diag -= r_pj * r_pj
            p += 1

        # (diag == diag) is False only for NaN.
        good_diag = (diag == diag) & (tl.abs(diag) < finite_limit) & (diag > tol)
        pivot = tl.sqrt(tl.maximum(diag, tol))
        r_vals = gram_row / pivot
        r_vals = tl.where(cols == j, pivot, r_vals)
        r_vals = tl.where(row_mask, r_vals, 0.0)
        bad_vals = tl.sum(
            (((r_vals != r_vals) | (tl.abs(r_vals) >= finite_limit)) & row_mask).to(
                tl.int32
            ),
            axis=0,
        )
        # Once set to 1, status stays 1.
        status = tl.where(good_diag & (bad_vals == 0), status, 1)
        tl.store(base_r + j * K + cols, r_vals, mask=col_mask)
        j += 1

    tl.store(STATUS + batch, status)

620 

621 

def _tsqr_guard_fallback_svd(input):
    """Route an input rejected by the TSQR/Cholesky guard to a native Jacobi path.

    A long side of at most 1024 is required. ``16 <= k <= 64`` goes to
    ``_cyclic_jacobi_svd`` and ``64 < k <= 512`` to ``_blocked_jacobi_svd``.
    Anything else raises via ``_unsupported_svd``.
    """
    _, m, n = _svd_shape(input)
    k = min(m, n)
    if max(m, n) <= 1024:
        if 16 <= k <= 64:
            return _cyclic_jacobi_svd(input)
        if 64 < k <= 512:
            return _blocked_jacobi_svd(input)
    return _unsupported_svd(
        input,
        True,
        True,
        "TSQR/Cholesky guard could not find a native Jacobi fallback.",
    )

635 

636 

def _tsqr_cholesky_svd(input):
    """SVD through a Cholesky factor of the Gram matrix (TSQR-style).

    Steps:
      1. Build G = A^T A (tall, m >= n) or A A^T (wide) with
         ``_gram_build_tiled_kernel``.
      2. Factor G = R^T R with ``_cholesky_upper_kernel``.
      3. Take the SVD of the small k x k factor R by calling ``svd`` (defined
         elsewhere in the module, not visible here). This assumes it returns
         (U, S, V) with V, not V^H -- TODO confirm.
      4. Tall: V = basis and U = A @ basis with columns divided by S.
         Wide: U = basis and V = A^T @ basis with columns divided by S.

    Returns:
        ``(u, s, v)`` shaped ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.

    NOTE(review): the per-batch ``status`` flags written by
    ``_cholesky_upper_kernel`` are never read here, so an ill-conditioned or
    non-finite factorization is not detected on this path. The visible
    ``_can_use_tsqr_cholesky_kernel`` always returns False pending a
    device-side guard; ``_tsqr_guard_fallback_svd`` appears to be the intended
    fallback.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    r = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    status = torch.empty((batch,), dtype=torch.int32, device=input.device)
    block_k = triton.next_power_of_2(k)
    block_tile = 32
    block_r = 64

    with torch_device_fn.device(input.device):
        _gram_build_tiled_kernel[
            (
                triton.cdiv(k, block_tile),
                triton.cdiv(k, block_tile),
                batch,
            )
        ](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_I=block_tile,
            BLOCK_J=block_tile,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _cholesky_upper_kernel[(batch,)](
            gram,
            r,
            status,
            K=k,
            BLOCK_K=block_k,
            num_warps=4,
        )

    # R = U_r diag(s) V_r^T; only s and V_r (the right basis) are used.
    _, s, basis = svd(r, some=True, compute_uv=True)
    basis = basis.reshape(batch, k, k)
    s = s.reshape(batch, k)

    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        projected = u
        projected_rows = m
    else:
        u = basis
        v = _triton_bmm(a.transpose(1, 2).contiguous(), basis, (batch, n, k))
        projected = v
        projected_rows = n

    # Divide each projected column by its singular value (floored at 1e-20).
    with torch_device_fn.device(input.device):
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1 if projected_rows <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

709 

710 

# Single-tile Gram build: one program per batch element computes the whole
# K x K matrix GRAM[batch] = A^T A (TALL) or A A^T (wide) in one
# BLOCK_K x BLOCK_K accumulator. It streams over the long dimension in chunks
# of BLOCK_R with float32 accumulation and TF32 disabled.
@libentry()
@triton.jit
def _gram_build_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    i = tl.arange(0, BLOCK_K)
    j = tl.arange(0, BLOCK_K)
    rows = tl.arange(0, BLOCK_R)
    k_mask = i < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_K, BLOCK_K), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS

        if TALL:
            # lhs = chunk of A^T, rhs = chunk of A.
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + i[:, None],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + j[None, :],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs = chunk of A, rhs = chunk of A^T.
            lhs = tl.load(
                a_base + i[:, None] * N + chunk_rows[None, :],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)

        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + i[:, None] * K + j[None, :],
        acc,
        mask=k_mask[:, None] & (j[None, :] < K),
    )

766 

767 

# Cyclic two-sided Jacobi eigen-solver for one symmetric K x K matrix per
# program. The matrix (g) and the accumulated rotations (v, initialised to the
# identity) are held as BLOCK_K x BLOCK_K tiles inside the kernel. Each (p, q)
# step applies a Givens rotation that zeroes g[p, q], skipping pairs where
# |g[p, q]| <= 1e-7 * sqrt(|g[p, p] * g[q, q]| + eps).
# Outputs: EVALS[batch, :K] = diag(g) (unsorted) and EVECS[batch] = v
# (row-major), whose columns are the eigenvector estimates.
# K and SWEEPS are runtime arguments (not constexpr), hence the while loops.
# Each rotation does several full-tile masked reductions, so cost grows
# quickly with BLOCK_K.
@libentry()
@triton.jit
def _gram_jacobi_sym_kernel(
    GRAM,
    EVECS,
    EVALS,
    K,
    SWEEPS,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    r = tl.arange(0, BLOCK_K)
    cidx = tl.arange(0, BLOCK_K)
    rr = r[:, None]
    cc = cidx[None, :]
    mask = (rr < K) & (cc < K)
    base = GRAM + batch * K * K

    g = tl.load(base + rr * K + cc, mask=mask, other=0.0).to(tl.float32)
    v = tl.where((rr == cc) & mask, 1.0, 0.0)
    eps = 1.0e-20

    sweep = 0
    while sweep < SWEEPS:
        p = 0
        while p < K - 1:
            q = p + 1
            while q < K:
                # Extract scalars g[p, p], g[q, q], g[p, q] via masked sums.
                diag_p = tl.sum(tl.where((rr == p) & (cc == p), g, 0.0))
                diag_q = tl.sum(tl.where((rr == q) & (cc == q), g, 0.0))
                off = tl.sum(tl.where((rr == p) & (cc == q), g, 0.0))
                abs_off = tl.abs(off)
                threshold = 1.0e-7 * tl.sqrt(tl.abs(diag_p * diag_q) + eps)
                active = abs_off > threshold
                safe_off = tl.where(active, off, 1.0)
                tau = (diag_q - diag_p) / (2.0 * safe_off)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                crot = tl.rsqrt(1.0 + t * t)
                srot = t * crot
                crot = tl.where(active, crot, 1.0)
                srot = tl.where(active, srot, 0.0)

                # All four slices are read from the pre-rotation g.
                col_p = tl.sum(tl.where(cc == p, g, 0.0), axis=1)
                col_q = tl.sum(tl.where(cc == q, g, 0.0), axis=1)
                row_p = tl.sum(tl.where(rr == p, g, 0.0), axis=0)
                row_q = tl.sum(tl.where(rr == q, g, 0.0), axis=0)

                new_col_p = crot * col_p - srot * col_q
                new_col_q = srot * col_p + crot * col_q
                new_row_p = crot * row_p - srot * row_q
                new_row_q = srot * row_p + crot * row_q
                g = tl.where(cc == p, new_col_p[:, None], g)
                g = tl.where(cc == q, new_col_q[:, None], g)
                g = tl.where(rr == p, new_row_p[None, :], g)
                g = tl.where(rr == q, new_row_q[None, :], g)

                # The 2x2 (p, q) block is not correct after the separate
                # row/column updates, so set it explicitly: rotated diagonal
                # entries and an exactly zero off-diagonal.
                new_pp = (
                    crot * crot * diag_p
                    - 2.0 * crot * srot * off
                    + srot * srot * diag_q
                )
                new_qq = (
                    srot * srot * diag_p
                    + 2.0 * crot * srot * off
                    + crot * crot * diag_q
                )
                g = tl.where((rr == p) & (cc == p), new_pp, g)
                g = tl.where((rr == q) & (cc == q), new_qq, g)
                g = tl.where(((rr == p) & (cc == q)) | ((rr == q) & (cc == p)), 0.0, g)

                # Accumulate the rotation into the eigenvector columns.
                vec_p = tl.sum(tl.where(cc == p, v, 0.0), axis=1)
                vec_q = tl.sum(tl.where(cc == q, v, 0.0), axis=1)
                new_vec_p = crot * vec_p - srot * vec_q
                new_vec_q = srot * vec_p + crot * vec_q
                v = tl.where(cc == p, new_vec_p[:, None], v)
                v = tl.where(cc == q, new_vec_q[:, None], v)
                q += 1
            p += 1
        sweep += 1

    diag = tl.sum(tl.where(rr == cc, g, 0.0), axis=1)
    tl.store(EVALS + batch * K + r, diag, mask=r < K)
    tl.store(EVECS + batch * K * K + rr * K + cc, v, mask=mask)

852 

853 

# Sort eigenpairs in descending order. Grid: (batch, K), one program per
# eigenpair. Eigenvalues are clamped at 0 before comparison; ties go to the
# lower original index. Writes S[batch, rank] = sqrt(max(eval, 0)) and copies
# EVECS column `col` into BASIS column `rank` (both row-major K x K).
@libentry()
@triton.jit
def _gram_sort_basis_kernel(
    EVALS,
    EVECS,
    BASIS,
    S,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_K)
    row_mask = rows < K
    # Negative eigenvalues (round-off on a PSD Gram matrix) are treated as 0.
    eval_col = tl.maximum(tl.load(EVALS + batch * K + col), 0.0)
    # rank = number of eigenvalues that sort before this one.
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        eval_other = tl.maximum(tl.load(EVALS + batch * K + other), 0.0)
        rank += (
            (eval_other > eval_col) | ((eval_other == eval_col) & (other < col))
        ).to(tl.int32)

    vec = tl.load(
        EVECS + batch * K * K + rows * K + col,
        mask=row_mask,
        other=0.0,
    )
    tl.store(S + batch * K + rank, tl.sqrt(eval_col))
    tl.store(
        BASIS + batch * K * K + rows * K + rank,
        vec,
        mask=row_mask,
    )

887 

888 

# In-place column scaling: Q[batch, :, col] /= max(S[batch, col], 1e-20).
# Q is addressed as a contiguous (batch, ROWS, K) tensor. Grid: (batch, K).
@libentry()
@triton.jit
def _normalize_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-20
    sval = tl.load(S + batch * K + col)
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    # The floor avoids division by zero for zero singular values.
    vals = vals / tl.maximum(sval, eps)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

907 

908 

# In-place column renormalization: S[batch, col] is overwritten with the
# Euclidean norm of Q[batch, :, col], and the column is divided by that norm.
# Columns with norm <= 1e-20 are replaced by the unit vector e_col (only
# nonzero when col < ROWS). Q is addressed as a contiguous (batch, ROWS, K)
# tensor. Grid: (batch, K).
@libentry()
@triton.jit
def _renorm_projection_update_s_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals_f32 = vals.to(tl.float32)
    norm = tl.sqrt(tl.sum(vals_f32 * vals_f32, axis=0))
    # NOTE(review): 1e-40 is below the float32 normal range; if subnormals are
    # flushed to zero, norms in roughly (1e-20, 1.1e-19] could give inf here.
    # In `_gram_jacobi_svd` those columns are overwritten afterwards by
    # `_complete_zero_projection_kernel` (threshold 1e-12) -- confirm for
    # other callers.
    inv_norm = tl.rsqrt(tl.maximum(norm * norm, 1.0e-40))
    basis = tl.where(rows == col, 1.0, 0.0)
    vals = tl.where(norm <= 1.0e-20, basis, vals * inv_norm)
    tl.store(S + batch * K + col, norm)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

930 

931 

# Replace Q[batch, :, col] with the unit vector e_col whenever
# S[batch, col] <= 1e-12; otherwise leave the column unchanged. S is only read.
# NOTE(review): e_col is not orthogonalized against the other columns here.
# Presumably a later re-orthogonalization step handles that -- confirm.
# Grid: (batch, K).
@libentry()
@triton.jit
def _complete_zero_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-12
    sval = tl.load(S + batch * K + col)
    basis = tl.where(rows == col, 1.0, 0.0)
    old = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals = tl.where(sval <= eps, basis, old)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)

951 

952 

def _gram_jacobi_svd(input):
    """SVD via a Jacobi eigendecomposition of the small Gram matrix.

    Steps:
      1. G = A^T A (tall, m >= n) or A A^T (wide), k x k with k = min(m, n).
      2. Jacobi eigen-solve G with a fixed sweep count (12 if k <= 17, else 10).
      3. Sort eigenpairs descending; S = sqrt(max(eigenvalue, 0)).
      4. Tall: V = sorted eigenvectors, U = A @ V. Wide: U = sorted
         eigenvectors, V = A^T @ U.
      5. Renormalize the projected factor, overwriting S with its column
         norms. Columns whose norm is <= 1e-12 become unit vectors. When
         k <= _GRAM_TALL_WIDE_MAX_K, ``_thin_reorthogonalize_kernel`` (defined
         elsewhere; presumably re-orthogonalizes the columns -- confirm) is
         launched on the projected factor.

    Returns:
        ``(u, s, v)`` shaped ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    eigvecs = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    evals = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    basis = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    block_k = triton.next_power_of_2(k)
    # Smaller row chunks for larger k keep the BLOCK_K x BLOCK_R load tiles bounded.
    block_r = min(triton.next_power_of_2(rows), 64 if k > 32 else 128)
    sweeps = 12 if k <= 17 else 10

    with torch_device_fn.device(input.device):
        _gram_build_kernel[(batch,)](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_K=block_k,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _gram_jacobi_sym_kernel[(batch,)](
            gram,
            eigvecs,
            evals,
            k,
            sweeps,
            BLOCK_K=block_k,
            num_warps=4,
        )
    with torch_device_fn.device(input.device):
        _gram_sort_basis_kernel[(batch, k)](
            evals,
            eigvecs,
            basis,
            s,
            K=k,
            BLOCK_K=block_k,
            num_warps=1,
        )

    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        proj_rows = m
    else:
        a_t = a.transpose(1, 2).contiguous()
        v = _triton_bmm(a_t, basis, (batch, n, k))
        u = basis
        proj_rows = n

    # The projected factor (u when tall, v when wide) is normalized in place.
    # The sqrt-eigenvalue S from the sort step is replaced by the projected
    # column norms.
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            u if tall else v,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(proj_rows),
            num_warps=1 if proj_rows <= 64 else 4,
        )
        _complete_zero_projection_kernel[(batch, k)](
            u if tall else v,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(proj_rows),
            num_warps=1 if proj_rows <= 64 else 4,
        )
        if k <= _GRAM_TALL_WIDE_MAX_K:
            _thin_reorthogonalize_kernel[(batch,)](
                u if tall else v,
                ROWS=proj_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(proj_rows),
                num_warps=1 if proj_rows <= 64 else 4,
            )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

1042 

1043 

# One one-sided Jacobi rotation applied independently to every row of a batch
# tile. The column pair (ap, aq) and the matching basis pair (vp, vq) are
# rotated together. Reductions run along axis 1, so each row of the operands is
# one matrix column, as in `_small4_square_svd_kernel`, which passes
# (BLOCK_B, 4) tiles. Rows whose pair is already orthogonal to relative
# tolerance 1e-7 get the identity rotation.
# Returns the rotated (ap, aq, vp, vq).
@triton.jit
def _rotate_pair_4(ap, aq, vp, vq):
    eps = 1.0e-20
    alpha = tl.sum(ap * ap, axis=1)
    beta = tl.sum(aq * aq, axis=1)
    gamma = tl.sum(ap * aq, axis=1)
    abs_gamma = tl.abs(gamma)
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    # Guard the division on inactive rows; their result is discarded below.
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)
    new_ap = c[:, None] * ap - s_rot[:, None] * aq
    new_aq = s_rot[:, None] * ap + c[:, None] * aq
    new_vp = c[:, None] * vp - s_rot[:, None] * vq
    new_vq = s_rot[:, None] * vp + c[:, None] * vq
    return new_ap, new_aq, new_vp, new_vq

1066 

1067 

@libentry()
@triton.jit
def _small4_square_svd_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    BLOCK_B: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    # One-sided Jacobi SVD of BATCH contiguous 4x4 matrices, with BLOCK_B
    # matrices per program.
    # - Element (row, col) of matrix b is read at A + b * 16 + row * 4 + col.
    # - S[b, :4], U[b, :4, :4] and V[b, :4, :4] use the same per-matrix layout.
    # - Singular values are written in descending order, with U/V columns
    #   permuted to match.
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, 4)
    bb = b[:, None]
    rr = r[None, :]
    mask = b < BATCH
    full_mask = (bb < BATCH) & (rr < 4)
    # base[b, row] -> A[b, row, 0]. The loads below fetch the four columns as
    # (BLOCK_B, 4) tiles indexed by row.
    base = A + bb * 16 + rr * 4

    c0 = tl.load(base, mask=full_mask, other=0.0).to(tl.float32)
    c1 = tl.load(base + 1, mask=full_mask, other=0.0).to(tl.float32)
    c2 = tl.load(base + 2, mask=full_mask, other=0.0).to(tl.float32)
    c3 = tl.load(base + 3, mask=full_mask, other=0.0).to(tl.float32)

    # Accumulated right rotation starts as the identity (column i = e_i).
    v0 = tl.where(rr == 0, 1.0, 0.0)
    v1 = tl.where(rr == 1, 1.0, 0.0)
    v2 = tl.where(rr == 2, 1.0, 0.0)
    v3 = tl.where(rr == 3, 1.0, 0.0)

    # Each sweep rotates all six column pairs once, in cyclic order.
    for _ in tl.static_range(0, SWEEPS):
        c0, c1, v0, v1 = _rotate_pair_4(c0, c1, v0, v1)
        c0, c2, v0, v2 = _rotate_pair_4(c0, c2, v0, v2)
        c0, c3, v0, v3 = _rotate_pair_4(c0, c3, v0, v3)
        c1, c2, v1, v2 = _rotate_pair_4(c1, c2, v1, v2)
        c1, c3, v1, v3 = _rotate_pair_4(c1, c3, v1, v3)
        c2, c3, v2, v3 = _rotate_pair_4(c2, c3, v2, v3)

    # The columns are now (approximately) orthogonal, so their norms are the
    # singular values.
    s0 = tl.sqrt(tl.sum(c0 * c0, axis=1))
    s1 = tl.sqrt(tl.sum(c1 * c1, axis=1))
    s2 = tl.sqrt(tl.sum(c2 * c2, axis=1))
    s3 = tl.sqrt(tl.sum(c3 * c3, axis=1))
    # r_i is the output slot of column i. It counts:
    #   - earlier columns whose norm is >= s_i, plus
    #   - later columns whose norm is > s_i.
    # Ties therefore keep the original column order, and r0..r3 is a
    # permutation of 0..3.
    r0 = (s1 > s0).to(tl.int32) + (s2 > s0).to(tl.int32) + (s3 > s0).to(tl.int32)
    r1 = ((s0 >= s1).to(tl.int32)) + (s2 > s1).to(tl.int32) + (s3 > s1).to(tl.int32)
    r2 = ((s0 >= s2).to(tl.int32)) + ((s1 >= s2).to(tl.int32)) + (s3 > s2).to(tl.int32)
    r3 = (
        ((s0 >= s3).to(tl.int32))
        + ((s1 >= s3).to(tl.int32))
        + ((s2 >= s3).to(tl.int32))
    )
    eps = 1.0e-20

    tl.store(S + b * 4 + r0, s0, mask=mask)
    tl.store(S + b * 4 + r1, s1, mask=mask)
    tl.store(S + b * 4 + r2, s2, mask=mask)
    tl.store(S + b * 4 + r3, s3, mask=mask)

    # U[b, :, r_i] = column_i / s_i. The eps floor only prevents division by
    # zero: a vanishing singular value leaves a (near-)zero U column.
    tl.store(
        U + bb * 16 + rr * 4 + r0[:, None],
        c0 / tl.maximum(s0[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r1[:, None],
        c1 / tl.maximum(s1[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r2[:, None],
        c2 / tl.maximum(s2[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r3[:, None],
        c3 / tl.maximum(s3[:, None], eps),
        mask=full_mask,
    )

    # V[b, :, r_i] = accumulated rotation column i.
    tl.store(V + bb * 16 + rr * 4 + r0[:, None], v0, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r1[:, None], v1, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r2[:, None], v2, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r3[:, None], v3, mask=full_mask)

1150 

1151 

@libentry()
@triton.jit
def _rank2_svd_tiny_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Closed-form thin SVD of BATCH small matrices whose short side is 2.
    # Each program handles BLOCK_B matrices; tile rows index the batch and
    # tile columns index the long side.
    #
    # Orientation:
    # - TALL (m >= n): x, y are columns 0 and 1 of A[b]. The 2x2 rotation is
    #   written to V[b] and the long singular vectors A v / s to U[b].
    # - Otherwise: x, y are rows 0 and 1 of A[b], and the roles of U and V swap.
    #
    # The long singular vectors get the same rank-deficiency fallbacks as
    # _rank2_svd_kernel, so the output no longer depends on which kernel the
    # host dispatcher picked for a given batch size:
    #   * s0 ~ 0 (all-zero matrix): the first vector becomes e0.
    #   * s1 <= 5e-4 * s0: the second vector is a coordinate axis
    #     Gram-Schmidt-orthogonalised against the first, instead of the
    #     ill-conditioned (x, y) v1 / s1.
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, BLOCK_R)
    bb = b[:, None]
    rr = r[None, :]
    bmask = b < BATCH
    eps = 1.0e-20

    if TALL:
        mask = (bb < BATCH) & (rr < M)
        base = A + bb * M * N + rr * N
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = (bb < BATCH) & (rr < N)
        base = A + bb * M * N + rr
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N, mask=mask, other=0.0).to(tl.float32)

    # Eigen-decomposition of the 2x2 Gram matrix [[aa, ab], [ab, bbv]] per matrix.
    aa = tl.sum(x * x, axis=1)
    bbv = tl.sum(y * y, axis=1)
    ab = tl.sum(x * y, axis=1)
    diff = aa - bbv
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    l0 = tl.maximum(0.0, 0.5 * (aa + bbv + root))
    det = tl.maximum(0.0, aa * bbv - ab * ab)
    # l0 * l1 = det. This avoids the cancellation in (aa + bbv - root) / 2.
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Unit eigenvector for l0, with an axis fallback when ab ~ 0.
    # The second vector is its 90-degree rotation.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bbv
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0 + eps)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    vx1 = -vy0
    vy1 = vx0

    tl.store(S + b * 2, s0, mask=bmask)
    tl.store(S + b * 2 + 1, s1, mask=bmask)
    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    # Long singular vectors: columns of U when TALL, columns of V otherwise.
    w0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
    w1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]

    basis0 = tl.where(rr == 0, 1.0, 0.0)
    basis1 = tl.where(rr == 1, 1.0, 0.0)
    s0_ok = s0 > eps
    w0 = tl.where(s0_ok[:, None], w0, basis0)

    # Fallback second vector. Pick the axis least aligned with w0
    # (e0 unless |w0[0]| >= 1/sqrt(2)), then orthogonalise it against w0
    # and normalise.
    w0_first = tl.sum(tl.where(rr == 0, w0, 0.0), axis=1)
    use_e0 = tl.abs(w0_first) < 0.70710678
    anchor = tl.where(use_e0[:, None], basis0, basis1)
    dot = tl.sum(anchor * w0, axis=1)
    fallback_w1 = anchor - dot[:, None] * w0
    fallback_norm = tl.sum(fallback_w1 * fallback_w1, axis=1)
    inv_fallback = tl.rsqrt(fallback_norm + eps)
    fallback_w1 = fallback_w1 * inv_fallback[:, None]
    s1_ok = s1 > s0 * 5.0e-4
    w1 = tl.where(s1_ok[:, None], w1, fallback_w1)

    if TALL:
        ubase = U + bb * M * 2 + rr * 2
        tl.store(ubase, w0, mask=mask)
        tl.store(ubase + 1, w1, mask=mask)
        # V[b] is the row-major 2x2 [[vx0, vx1], [vy0, vy1]].
        vbase = V + b * 4
        tl.store(vbase, vx0, mask=bmask)
        tl.store(vbase + 1, vx1, mask=bmask)
        tl.store(vbase + 2, vy0, mask=bmask)
        tl.store(vbase + 3, vy1, mask=bmask)
    else:
        # U[b] is the row-major 2x2 [[vx0, vx1], [vy0, vy1]].
        ubase = U + b * 4
        tl.store(ubase, vx0, mask=bmask)
        tl.store(ubase + 1, vx1, mask=bmask)
        tl.store(ubase + 2, vy0, mask=bmask)
        tl.store(ubase + 3, vy1, mask=bmask)
        vbase = V + bb * N * 2 + rr * 2
        tl.store(vbase, w0, mask=mask)
        tl.store(vbase + 1, w1, mask=mask)

1233 

1234 

@libentry()
@triton.jit
def _rank2_svals_tiny_kernel(
    A,
    S,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Singular values only, for BATCH small matrices whose short side is 2.
    # Each program handles BLOCK_B matrices (tile rows = batch).
    # - TALL: x, y are the two columns of A[b].
    # - Otherwise: x, y are the two rows of A[b].
    # S[b, 0] >= S[b, 1] are written contiguously, two per matrix.
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, BLOCK_R)
    bb = b[:, None]
    rr = r[None, :]
    bmask = b < BATCH

    if TALL:
        mask = (bb < BATCH) & (rr < M)
        base = A + bb * M * N + rr * N
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = (bb < BATCH) & (rr < N)
        base = A + bb * M * N + rr
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N, mask=mask, other=0.0).to(tl.float32)

    # Eigenvalues l0 >= l1 of the Gram matrix [[aa, ab], [ab, bbv]].
    aa = tl.sum(x * x, axis=1)
    bbv = tl.sum(y * y, axis=1)
    ab = tl.sum(x * y, axis=1)
    diff = aa - bbv
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    # Clamp at 0 to absorb rounding.
    l0 = tl.maximum(0.0, 0.5 * (aa + bbv + root))
    det = tl.maximum(0.0, aa * bbv - ab * ab)
    # l1 = det / l0 avoids the cancellation in (aa + bbv - root) / 2.
    l1 = tl.where(l0 > 1.0e-20, det / l0, 0.0)
    tl.store(S + b * 2, tl.sqrt(l0), mask=bmask)
    tl.store(S + b * 2 + 1, tl.sqrt(l1), mask=bmask)

1275 

1276 

@libentry()
@triton.jit
def _rank2_svals_kernel(
    A,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Singular values only, for one matrix per program (pid = batch index);
    # the short side of the matrix is 2.
    # - TALL: x, y are the two columns of A[pid] (stride N).
    # - Otherwise: x, y are its two contiguous rows.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_R)

    if TALL:
        mask = offs < M
        base = A + pid * M * N
        x = tl.load(base + offs * N, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + offs * N + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = offs < N
        base = A + pid * M * N
        x = tl.load(base + offs, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N + offs, mask=mask, other=0.0).to(tl.float32)

    # Eigenvalues l0 >= l1 of the Gram matrix [[aa, ab], [ab, bb]].
    aa = tl.sum(x * x)
    bb = tl.sum(y * y)
    ab = tl.sum(x * y)
    diff = aa - bb
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    l0 = tl.maximum(0.0, 0.5 * (aa + bb + root))
    det = tl.maximum(0.0, aa * bb - ab * ab)
    # l1 = det / l0 avoids the cancellation in (aa + bb - root) / 2.
    l1 = tl.where(l0 > 1.0e-20, det / l0, 0.0)

    sbase = S + pid * 2
    tl.store(sbase, tl.sqrt(l0))
    tl.store(sbase + 1, tl.sqrt(l1))

1313 

1314 

@libentry()
@triton.jit
def _rank2_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Closed-form thin SVD of one matrix per program (pid = batch index) whose
    # short side is 2.
    # - TALL: x, y are the two columns of A[pid]. The 2x2 rotation goes to V
    #   and the long singular vectors go to U.
    # - Otherwise: x, y are the two rows, and the roles of U and V swap.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_R)
    eps = 1.0e-20

    if TALL:
        mask = offs < M
        base = A + pid * M * N
        x = tl.load(base + offs * N, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + offs * N + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = offs < N
        base = A + pid * M * N
        x = tl.load(base + offs, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N + offs, mask=mask, other=0.0).to(tl.float32)

    # Eigenvalues l0 >= l1 of the Gram matrix [[aa, ab], [ab, bb]].
    aa = tl.sum(x * x)
    bb = tl.sum(y * y)
    ab = tl.sum(x * y)
    diff = aa - bb
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    l0 = tl.maximum(0.0, 0.5 * (aa + bb + root))
    det = tl.maximum(0.0, aa * bb - ab * ab)
    # l1 = det / l0 avoids the cancellation in (aa + bb - root) / 2.
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0 solves (aa - l0) * vx + ab * vy = 0, i.e. (ab, l0 - aa).
    # When ab ~ 0 the Gram matrix is already diagonal, so use the axis of the
    # larger diagonal entry.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bb
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0 + eps)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    # Second eigenvector: the first rotated by 90 degrees.
    vx1 = -vy0
    vy1 = vx0

    sbase = S + pid * 2
    tl.store(sbase, s0)
    tl.store(sbase + 1, s1)

    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        ubase = U + pid * M * 2
        u0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        # All-zero matrix: use e0 as the first left singular vector.
        u0 = tl.where(s0 > eps, u0, basis0)

        u1 = (x * vx1 + y * vy1) * inv_s1
        # Fallback u1 for a (near) rank-deficient input:
        #   1. take e0, or e1 when u0 is closely aligned with e0;
        #   2. orthogonalise it against u0;
        #   3. normalise.
        u0_first = tl.sum(tl.where(offs == 0, u0, 0.0))
        anchor = tl.where(tl.abs(u0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * u0)
        fallback_u1 = anchor - dot * u0
        fallback_norm = tl.sum(fallback_u1 * fallback_u1)
        fallback_u1 = fallback_u1 * tl.rsqrt(fallback_norm + eps)
        # Keep the computed u1 only when s1 is not tiny relative to s0.
        u1 = tl.where(s1 > s0 * 5.0e-4, u1, fallback_u1)
        tl.store(ubase + offs * 2, u0, mask=mask)
        tl.store(ubase + offs * 2 + 1, u1, mask=mask)

        # V[pid] is the row-major 2x2 [[vx0, vx1], [vy0, vy1]].
        vbase = V + pid * 4
        tl.store(vbase, vx0)
        tl.store(vbase + 1, vx1)
        tl.store(vbase + 2, vy0)
        tl.store(vbase + 3, vy1)
    else:
        # U[pid] is the row-major 2x2 [[vx0, vx1], [vy0, vy1]].
        ubase = U + pid * 4
        tl.store(ubase, vx0)
        tl.store(ubase + 1, vx1)
        tl.store(ubase + 2, vy0)
        tl.store(ubase + 3, vy1)

        # Same construction and fallbacks as the TALL branch, applied to V.
        vbase = V + pid * N * 2
        v0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        v0 = tl.where(s0 > eps, v0, basis0)

        v1 = (x * vx1 + y * vy1) * inv_s1
        v0_first = tl.sum(tl.where(offs == 0, v0, 0.0))
        anchor = tl.where(tl.abs(v0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * v0)
        fallback_v1 = anchor - dot * v0
        fallback_norm = tl.sum(fallback_v1 * fallback_v1)
        fallback_v1 = fallback_v1 * tl.rsqrt(fallback_norm + eps)
        v1 = tl.where(s1 > s0 * 5.0e-4, v1, fallback_v1)
        tl.store(vbase + offs * 2, v0, mask=mask)
        tl.store(vbase + offs * 2 + 1, v1, mask=mask)

1416 

1417 

def _rank2_svd(input):
    """Compute the thin SVD of inputs whose short matrix side is 2.

    Kernel choice:
    - Small, heavily batched problems (``max(m, n) <= 16`` and at least 16
      matrices) use ``_rank2_svd_tiny_kernel``, which packs several matrices
      into each program.
    - Everything else launches ``_rank2_svd_kernel`` with one program per matrix.

    Args:
        input: tensor of shape (..., m, n). The 2-column output buffers assume
            min(m, n) == 2; presumably the caller guarantees this.

    Returns:
        ``(U, S, V)`` with shapes (..., m, 2), (..., 2) and (..., n, 2). All
        three are allocated with ``input.dtype`` on ``input.device``.
    """
    batch, m, n = _svd_shape(input)
    lead = input.shape[:-2]
    tall = m >= n
    widest = max(m, n)
    block_r = triton.next_power_of_2(widest)

    flat = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 2), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, 2), dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if widest > 16 or batch < 16:
            # One program per matrix; use more warps once the long side exceeds 64.
            _rank2_svd_kernel[(batch,)](
                flat,
                u,
                s,
                v,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=4 if block_r > 64 else 1,
            )
        else:
            # Several matrices per program; the packing factor depends on size.
            if widest <= 2:
                block_b = 8
            elif widest == 16:
                block_b = 2 if tall else 8
            else:
                block_b = 16
            _rank2_svd_tiny_kernel[(triton.cdiv(batch, block_b),)](
                flat,
                u,
                s,
                v,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )
    return (
        u.reshape(*lead, m, 2),
        s.reshape(*lead, 2),
        v.reshape(*lead, n, 2),
    )

1464 

1465 

def _rank2_singular_values(input):
    """Return only the two singular values of each (..., m, n) matrix.

    Uses the same dispatch as the full rank-2 SVD:
    - ``_rank2_svals_tiny_kernel`` packs several matrices per program when
      ``max(m, n) <= 16`` and there are at least 16 matrices;
    - otherwise ``_rank2_svals_kernel`` runs one program per matrix.

    Returns:
        Tensor of shape (..., 2) in ``input.dtype`` on ``input.device``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    widest = max(m, n)
    block_r = triton.next_power_of_2(widest)

    flat = input.contiguous().reshape(batch, m, n)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if widest > 16 or batch < 16:
            _rank2_svals_kernel[(batch,)](
                flat,
                s,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=4 if block_r > 64 else 1,
            )
        else:
            if widest <= 2:
                block_b = 8
            elif widest == 16:
                block_b = 2 if tall else 8
            else:
                block_b = 16
            _rank2_svals_tiny_kernel[(triton.cdiv(batch, block_b),)](
                flat,
                s,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )
    return s.reshape(*input.shape[:-2], 2)

1502 

1503 

def _small_jacobi_singular_values(input):
    """Return the singular values of a (..., m, n) input via small Jacobi.

    Launches ``_small_jacobi_svals_kernel`` with one program per matrix.
    - The kernel gets a float32 scratch buffer of shape (batch, k, max(m, n)),
      where k = min(m, n).
    - It runs 3 sweeps when k <= 4 and 5 otherwise.

    Returns:
        Tensor of shape (..., k) in ``input.dtype`` on ``input.device``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)

    flat = input.contiguous().reshape(batch, m, n)
    scratch = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    n_sweeps = 5 if k > 4 else 3
    warps = 4 if block_r > 64 else 1

    with torch_device_fn.device(input.device):
        _small_jacobi_svals_kernel[(batch,)](
            flat,
            scratch,
            s,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=n_sweeps,
            num_warps=warps,
        )
    return s.reshape(*input.shape[:-2], k)

1530 

1531 

def _small_jacobi_svd(input):
    """Compute the thin SVD of a (..., m, n) input via the small Jacobi kernel.

    Launches ``_small_jacobi_svd_kernel`` with one program per matrix.
    - It gets two float32 scratch buffers, shaped (batch, k, max(m, n)) and
      (batch, k, k), where k = min(m, n).
    - It runs 3 sweeps when k <= 4 and 5 otherwise.

    Returns:
        ``(U, S, V)`` with shapes (..., m, k), (..., k) and (..., n, k), in
        ``input.dtype`` on ``input.device``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    lead = input.shape[:-2]
    dev = input.device
    out_dtype = input.dtype

    flat = input.contiguous().reshape(batch, m, n)
    a_scratch = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    v_scratch = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    u = torch.empty((batch, m, k), dtype=out_dtype, device=dev)
    s = torch.empty((batch, k), dtype=out_dtype, device=dev)
    v = torch.empty((batch, n, k), dtype=out_dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    n_sweeps = 5 if k > 4 else 3
    warps = 4 if block_r > 64 else 1

    with torch_device_fn.device(dev):
        _small_jacobi_svd_kernel[(batch,)](
            flat,
            a_scratch,
            v_scratch,
            u,
            s,
            v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=n_sweeps,
            num_warps=warps,
        )
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )

1568 

1569 

@libentry()
@triton.jit
def _cyclic_jacobi_init_a_kernel(
    A,
    A_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Copy one vector of A[batch] into the Jacobi workspace A_WORK[batch, col, :ROWS].
    # Grid is (batch, col).
    # - TALL: copies column `col` of A (stride N).
    # - Otherwise: copies row `col`.
    # Either way the K long vectors end up contiguous in A_WORK.
    # Values are converted to float32 before the store.
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    row_mask = rows < ROWS
    a_base = A + batch * M * N
    aw_base = A_WORK + batch * K * ROWS

    if TALL:
        vals = tl.load(a_base + rows * N + col, mask=row_mask, other=0.0).to(tl.float32)
    else:
        vals = tl.load(a_base + col * N + rows, mask=row_mask, other=0.0).to(tl.float32)
    tl.store(aw_base + col * ROWS + rows, vals, mask=row_mask)

1594 

1595 

@libentry()
@triton.jit
def _cyclic_jacobi_init_kernel(
    A,
    A_WORK,
    V_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Same workspace copy as _cyclic_jacobi_init_a_kernel. In addition it
    # initialises the rotation accumulator:
    #   V_WORK[batch, col, j] = (j == col)
    # so V_WORK[batch] starts as the K x K identity.
    # Grid is (batch, col).
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    basis_cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    basis_mask = basis_cols < K
    a_base = A + batch * M * N
    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K

    if TALL:
        vals = tl.load(a_base + rows * N + col, mask=row_mask, other=0.0).to(tl.float32)
    else:
        vals = tl.load(a_base + col * N + rows, mask=row_mask, other=0.0).to(tl.float32)
    tl.store(aw_base + col * ROWS + rows, vals, mask=row_mask)

    ident = tl.where(basis_cols == col, 1.0, 0.0)
    tl.store(vw_base + col * K + basis_cols, ident, mask=basis_mask)

1628 

1629 

@libentry()
@triton.jit
def _cyclic_jacobi_pair_kernel(
    A_WORK,
    V_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # One step of a parallel cyclic one-sided Jacobi sweep.
    # Grid is (batch, pair). Each program rotates one disjoint column pair of
    # A_WORK[batch] and applies the same rotation to V_WORK[batch].
    #
    # Pair schedule (circle-method round robin):
    # - position 0 stays fixed;
    # - positions 1..ROUND-1 rotate by STEP around a ring of size ROUND-1;
    # - position `pair` is matched with position ROUND-1-pair.
    # Indices >= K are padding. ROUND is presumably K rounded up to an even
    # number (TODO confirm at the launcher), and padding pairs are fully masked.
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    ring = ROUND - 1

    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    # Always rotate with the lower index as p2.
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair
    col_mask = (cols < K) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    # Skip pairs that are already orthogonal to ~1e-7 relative accuracy.
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    # Smaller-magnitude root of t^2 + 2*tau*t - 1 = 0.
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    # V_WORK[batch, j, :] holds accumulated basis vector j; rotate it identically.
    vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
    vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
    new_vp = c * vp - s_rot * vq
    new_vq = s_rot * vp + c * vq
    tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
    tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)

1690 

1691 

@libentry()
@triton.jit
def _serial_cyclic_jacobi_kernel(
    A_WORK,
    V_WORK,
    K,
    ROUND,
    ROWS: tl.constexpr,
    SWEEPS,
    TAIL_STEPS,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # One program per batch entry runs the whole cyclic Jacobi schedule serially.
    # This is the single-launch counterpart of launching _cyclic_jacobi_pair_kernel
    # once per step. Work performed:
    #   - SWEEPS full rounds, each of ROUND-1 steps with ROUND//2 pairs per step;
    #   - then TAIL_STEPS extra steps (steps 0..TAIL_STEPS-1 of one more round).
    # K, ROUND, SWEEPS and TAIL_STEPS are runtime scalars (not tl.constexpr),
    # so the while-loops below are not unrolled at compile time.
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_base_mask = rows < ROWS
    col_base_mask = cols < K
    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    eps = 1.0e-20
    ring = ROUND - 1
    half_round = ROUND // 2

    sweep = 0
    while sweep < SWEEPS:
        step = 0
        while step < ROUND - 1:
            pair = 0
            while pair < half_round:
                # Circle-method pairing: position 0 fixed, the rest rotated by `step`.
                pos_p = pair
                pos_q = ROUND - 1 - pair
                p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
                q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
                # Indices >= K are padding; their loads/stores are masked off.
                valid_pair = (p < K) & (q < K)
                swap = p > q
                p2 = tl.where(swap, q, p)
                q2 = tl.where(swap, p, q)
                row_mask = row_base_mask & valid_pair
                col_mask = col_base_mask & valid_pair

                ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip already-orthogonal pairs (relative tolerance 1e-7).
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

                vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
                vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
                new_vp = c * vp - s_rot * vq
                new_vq = s_rot * vp + c * vq
                tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
                tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)
                pair += 1
            step += 1
        sweep += 1

    # Partial extra round: same per-pair body as above, only TAIL_STEPS steps.
    step = 0
    while step < TAIL_STEPS:
        pair = 0
        while pair < half_round:
            pos_p = pair
            pos_q = ROUND - 1 - pair
            p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
            q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
            valid_pair = (p < K) & (q < K)
            swap = p > q
            p2 = tl.where(swap, q, p)
            q2 = tl.where(swap, p, q)
            row_mask = row_base_mask & valid_pair
            col_mask = col_base_mask & valid_pair

            ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
            aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
            alpha = tl.sum(ap * ap)
            beta = tl.sum(aq * aq)
            gamma = tl.sum(ap * aq)
            abs_gamma = tl.abs(gamma)
            threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
            active = abs_gamma > threshold
            safe_gamma = tl.where(active, gamma, 1.0)
            tau = (beta - alpha) / (2.0 * safe_gamma)
            sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
            t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
            c = tl.rsqrt(1.0 + t * t)
            s_rot = t * c
            c = tl.where(active, c, 1.0)
            s_rot = tl.where(active, s_rot, 0.0)

            new_ap = c * ap - s_rot * aq
            new_aq = s_rot * ap + c * aq
            tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
            tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

            vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
            vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
            new_vp = c * vp - s_rot * vq
            new_vq = s_rot * vp + c * vq
            tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
            tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)
            pair += 1
        step += 1

1810 

1811 

@libentry()
@triton.jit
def _cyclic_jacobi_norm_kernel(
    A_WORK,
    S_WORK,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # S_WORK[batch, col] = Euclidean norm of workspace vector A_WORK[batch, col, :ROWS].
    # Grid is (batch, col). After convergence these norms are the unsorted
    # singular values.
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    aw_base = A_WORK + batch * K * ROWS
    vals = tl.load(aw_base + col * ROWS + rows, mask=mask, other=0.0)
    norm = tl.sqrt(tl.sum(vals * vals))
    tl.store(S_WORK + batch * K + col, norm)

1829 

1830 

@libentry()
@triton.jit
def _cyclic_jacobi_finalize_kernel(
    A_WORK,
    V_WORK,
    S_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Scatter one converged Jacobi column into the sorted thin-SVD outputs.
    # Grid is (batch, col).
    # The output slot `rank` is the column's position in descending order of
    # S_WORK, with ties broken by column index. That places S, U and V
    # columns in descending singular-value order.
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    basis_cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    basis_mask = basis_cols < K
    eps = 1.0e-20

    s_col = tl.load(S_WORK + batch * K + col)
    # rank = (#columns with a larger norm) + (#earlier columns with an equal norm).
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        s_other = tl.load(S_WORK + batch * K + other)
        rank += ((s_other > s_col) | ((s_other == s_col) & (other < col))).to(tl.int32)

    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    col_vals = tl.load(aw_base + col * ROWS + rows, mask=row_mask, other=0.0)
    # A zero norm yields a zero singular vector rather than a division by zero.
    inv_norm = tl.where(s_col > eps, 1.0 / s_col, 0.0)
    basis = tl.load(vw_base + col * K + basis_cols, mask=basis_mask, other=0.0)
    tl.store(S + batch * K + rank, s_col)

    if TALL:
        # U[batch, :, rank] = normalised workspace vector (length ROWS = M).
        tl.store(
            U + batch * M * K + rows * K + rank,
            col_vals * inv_norm,
            mask=row_mask,
        )
        # V[batch, :, rank] = accumulated rotation vector (length K = N).
        tl.store(
            V + batch * N * K + basis_cols * K + rank,
            basis,
            mask=basis_mask,
        )
    else:
        # Wide input was processed transposed, so the roles of U and V swap.
        tl.store(
            U + batch * M * K + basis_cols * K + rank,
            basis,
            mask=basis_mask,
        )
        tl.store(
            V + batch * N * K + rows * K + rank,
            col_vals * inv_norm,
            mask=row_mask,
        )

1891 

1892 

@libentry()
@triton.jit
def _blocked_jacobi_pair_svals_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Singular-values-only variant of one parallel Jacobi step.
    # It rotates the column pair selected by (STEP, pair) in A_WORK[batch].
    # No right-rotation accumulator is updated.
    # Grid is (batch, pair); the pair schedule matches _cyclic_jacobi_pair_kernel.
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Circle-method round robin: position 0 fixed, the others rotated by STEP.
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    # Indices >= K are padding.
    valid_pair = (p < K) & (q < K)
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    # Skip already-orthogonal pairs (relative tolerance 1e-7).
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    # Identity rotation for inactive or padding pairs.
    c = tl.where(active & valid_pair, c, 1.0)
    s_rot = tl.where(active & valid_pair, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

1941 

1942 

@libentry()
@triton.jit
def _blocked_jacobi_pair_a_kernel(
    A_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # First half of a split Jacobi step. It rotates the (STEP, pair) column
    # pair of A_WORK[batch] and records the rotation (c, s) at flat index
    #   batch * (ROUND // 2) + pair
    # in ROT_C / ROT_S, so the V update can run as a separate launch.
    # _blocked_jacobi_apply_v_kernel reads the same offsets.
    # Padding and inactive pairs record the identity (c = 1, s = 0).
    # Grid is (batch, pair).
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Circle-method round robin: position 0 fixed, the others rotated by STEP.
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    # Skip already-orthogonal pairs (relative tolerance 1e-7).
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active & valid_pair, c, 1.0)
    s_rot = tl.where(active & valid_pair, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    # Stored unconditionally: every (batch, pair) slot gets a valid rotation.
    rot_base = batch * (ROUND // 2) + pair
    tl.store(ROT_C + rot_base, c)
    tl.store(ROT_S + rot_base, s_rot)

1997 

1998 

@libentry()
@triton.jit
def _hier_block_jacobi_pair_a_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    K_BLOCKS: tl.constexpr,
    ROUND_BLOCKS: tl.constexpr,
    ROWS: tl.constexpr,
    TILE_B: tl.constexpr,
    TILE_COLS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    LOCAL_SWEEPS: tl.constexpr,
):
    # Block (hierarchical) Jacobi step on A_WORK only; no V accumulator is
    # passed in. Grid is (batch, pair).
    # - The round robin runs over K_BLOCKS tiles of TILE_B columns each,
    #   using the same circle-method schedule as the per-column kernels.
    # - Each program loads tile p2 into local columns [0, TILE_B) and tile q2
    #   into the remaining local columns. TILE_COLS is presumably 2 * TILE_B;
    #   TODO confirm at the launcher.
    # - It then runs LOCAL_SWEEPS cyclic sweeps over every column pair inside
    #   that combined tile and writes the tile back.
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    local_cols = tl.arange(0, TILE_COLS)
    ring = ROUND_BLOCKS - 1

    pos_p = pair
    pos_q = ROUND_BLOCKS - 1 - pair
    p_block = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q_block = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    # Block indices >= K_BLOCKS are padding.
    valid_pair = (p_block < K_BLOCKS) & (q_block < K_BLOCKS)
    p2 = tl.minimum(p_block, q_block)
    q2 = tl.maximum(p_block, q_block)

    # Global workspace column for each local column of the combined tile.
    col_ids = tl.where(
        local_cols < TILE_B,
        p2 * TILE_B + local_cols,
        q2 * TILE_B + local_cols - TILE_B,
    )
    row_mask = rows < ROWS
    # Columns past K (last partial tile) and padding pairs load as zeros and
    # are never stored.
    col_mask = (col_ids < K) & valid_pair
    aw_base = A_WORK + batch * K * ROWS
    vals = tl.load(
        aw_base + col_ids[:, None] * ROWS + rows[None, :],
        mask=col_mask[:, None] & row_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    col_axis = local_cols[:, None]
    eps = 1.0e-20

    for _ in tl.static_range(0, LOCAL_SWEEPS):
        for p in tl.static_range(0, TILE_COLS):
            for q in tl.static_range(p + 1, TILE_COLS):
                # A Triton tile cannot be indexed by position, so extract
                # columns p and q with a masked reduction over the column axis.
                ap = tl.sum(tl.where(col_axis == p, vals, 0.0), axis=0)
                aq = tl.sum(tl.where(col_axis == q, vals, 0.0), axis=0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip already-orthogonal pairs (relative tolerance 1e-7).
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active & valid_pair, c, 1.0)
                s_rot = tl.where(active & valid_pair, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                # Write the rotated columns back into the tile.
                vals = tl.where(col_axis == p, new_ap[None, :], vals)
                vals = tl.where(col_axis == q, new_aq[None, :], vals)

    tl.store(
        aw_base + col_ids[:, None] * ROWS + rows[None, :],
        vals,
        mask=col_mask[:, None] & row_mask[None, :],
    )

2073 

2074 

# Apply the (cos, sin) rotation stored for one (batch, pair slot) to the two
# matching rows of V_WORK, a (batch, K, K) buffer addressed as
# batch * K * K + row * K + col.
#
# Launch grid: (batch, ROUND // 2, cdiv(K, BLOCK_V)). The third axis tiles the
# row into BLOCK_V-wide chunks. The pair schedule is the same round-robin
# formula used to choose the rotated columns of A.
@libentry()
@triton.jit
def _blocked_jacobi_apply_v_kernel(
    V_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    b = tl.program_id(0)
    slot = tl.program_id(1)
    chunk = tl.program_id(2)
    offs = chunk * BLOCK_V + tl.arange(0, BLOCK_V)

    # Position 0 is fixed; positions 1..ROUND-1 rotate by STEP each step.
    cycle = ROUND - 1
    left_pos = slot
    right_pos = cycle - slot
    left = tl.where(left_pos == 0, 0, ((left_pos + cycle - STEP - 1) % cycle) + 1)
    right = tl.where(right_pos == 0, 0, ((right_pos + cycle - STEP - 1) % cycle) + 1)
    lo = tl.minimum(left, right)
    hi = tl.maximum(left, right)
    # An index >= K is the padding slot of an odd K; such pairs write nothing.
    ok = (offs < K) & ((left < K) & (right < K))

    slot_idx = b * (ROUND // 2) + slot
    cos_v = tl.load(ROT_C + slot_idx)
    sin_v = tl.load(ROT_S + slot_idx)

    v_base = V_WORK + b * K * K
    lo_ptrs = v_base + lo * K + offs
    hi_ptrs = v_base + hi * K + offs
    v_lo = tl.load(lo_ptrs, mask=ok, other=0.0)
    v_hi = tl.load(hi_ptrs, mask=ok, other=0.0)
    tl.store(lo_ptrs, cos_v * v_lo - sin_v * v_hi, mask=ok)
    tl.store(hi_ptrs, sin_v * v_lo + cos_v * v_hi, mask=ok)

2112 

2113 

# Sort singular values into descending order by rank scattering.
#
# Launch grid: (batch, K). Each program takes one value s_j of S_WORK.
# Its rank is the count of entries that are strictly larger, plus the
# count of equal entries with a smaller index (a stable tie-break).
# The rank is written to RANKS[b, j] and s_j to S[b, rank].
@libentry()
@triton.jit
def _blocked_jacobi_rank_kernel(
    S_WORK,
    RANKS,
    S,
    K,
):
    b = tl.program_id(0)
    j = tl.program_id(1)
    row_ptr = S_WORK + b * K
    value = tl.load(row_ptr + j)
    position = tl.full((), 0, dtype=tl.int32)
    for i in range(0, K):
        other = tl.load(row_ptr + i)
        beats = (other > value) | ((other == value) & (i < j))
        position += beats.to(tl.int32)
    tl.store(RANKS + b * K + j, position)
    tl.store(S + b * K + position, value)

2133 

2134 

# Normalize converged column `j` of A_WORK by its norm S_WORK[b, j].
# The result goes into column RANKS[b, j] of the row-major output
# PROJECTED (batch, OUT_ROWS, K).
#
# Columns whose norm is <= 1e-20 are replaced by the unit vector e_rank.
# Launch grid: (batch, K, cdiv(OUT_ROWS, BLOCK_R)).
@libentry()
@triton.jit
def _blocked_jacobi_store_projected_kernel(
    A_WORK,
    S_WORK,
    RANKS,
    PROJECTED,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    OUT_ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    b = tl.program_id(0)
    j = tl.program_id(1)
    r = tl.program_id(2) * BLOCK_R + tl.arange(0, BLOCK_R)
    inside = r < OUT_ROWS
    tiny = 1.0e-20
    dst_col = tl.load(RANKS + b * K + j)
    norm = tl.load(S_WORK + b * K + j)
    src = tl.load(A_WORK + (b * K + j) * ROWS + r, mask=inside, other=0.0)
    unit = tl.where(r == dst_col, 1.0, 0.0)
    out = tl.where(norm <= tiny, unit, src / tl.maximum(norm, tiny))
    tl.store(PROJECTED + b * OUT_ROWS * K + r * K + dst_col, out, mask=inside)

2168 

2169 

# Scatter column `j` of the accumulated rotations into the output.
# The column is stored contiguously in V_WORK, at
# batch * K * K + j * K + row. It is written to column RANKS[b, j] of the
# row-major (batch, K, K) output BASIS.
#
# Launch grid: (batch, K, cdiv(K, BLOCK_V)).
@libentry()
@triton.jit
def _blocked_jacobi_store_basis_kernel(
    V_WORK,
    RANKS,
    BASIS,
    K: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    b = tl.program_id(0)
    j = tl.program_id(1)
    r = tl.program_id(2) * BLOCK_V + tl.arange(0, BLOCK_V)
    inside = r < K
    dst_col = tl.load(RANKS + b * K + j)
    src = tl.load(V_WORK + (b * K + j) * K + r, mask=inside, other=0.0)
    tl.store(BASIS + b * K * K + r * K + dst_col, src, mask=inside)

2195 

2196 

# Re-orthonormalize, in place, the K columns of a row-major
# (batch, ROWS, K) matrix Q. One program handles one batch element.
#
# Each column goes through these steps, in order:
#   - two modified Gram-Schmidt passes against the already-finished columns;
#   - replacement by e_col if its norm is <= 1e-20;
#   - a third projection pass;
#   - normalization.
@libentry()
@triton.jit
def _thin_reorthogonalize_kernel(
    Q,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    b = tl.program_id(0)
    r = tl.arange(0, BLOCK_R)
    inside = r < ROWS
    q_base = Q + b * ROWS * K
    tiny = 1.0e-20

    for col in tl.static_range(0, K):
        x = tl.load(q_base + r * K + col, mask=inside, other=0.0).to(tl.float32)

        # Two projection passes; a second pass recovers orthogonality lost
        # to cancellation in the first.
        for _pass in tl.static_range(0, 2):
            for done in tl.static_range(0, col):
                q = tl.load(q_base + r * K + done, mask=inside, other=0.0).to(
                    tl.float32
                )
                x = x - tl.sum(x * q) * q

        length = tl.sqrt(tl.sum(x * x))
        x = tl.where(length > tiny, x, tl.where(r == col, 1.0, 0.0))

        # Project once more so a substituted unit vector is also orthogonalized.
        for done in tl.static_range(0, col):
            q = tl.load(q_base + r * K + done, mask=inside, other=0.0).to(
                tl.float32
            )
            x = x - tl.sum(x * q) * q

        length = tl.sqrt(tl.sum(x * x))
        tl.store(q_base + r * K + col, x / tl.maximum(length, tiny), mask=inside)

2245 

2246 

def _cyclic_jacobi_svd(input):
    """Thin SVD of a batch of small matrices via cyclic one-sided Jacobi.

    Returns ``(U, S, V)`` shaped ``(..., m, k)``, ``(..., k)`` and
    ``(..., n, k)`` with ``k = min(m, n)``; leading dims follow ``input``.

    Pipeline:
      1. init: fill the float32 work buffers ``a_work`` (batch, k, rows) and
         ``v_work`` (batch, k, k);
      2. Jacobi sweeps. For 16 <= k <= 32, rows <= 64 and batch <= 32 this is
         one serial kernel per batch element; otherwise it is one pair-kernel
         launch per schedule step;
      3. column norms go into ``s_work``, then finalize writes ``u``, ``s``
         and ``v``;
      4. for k <= 17 and rows <= 64, the short factor is re-orthonormalized;
      5. the long factor is recomputed as ``A @ V`` (tall) or ``A^T @ U``
         (wide) and normalized against ``s``.

    The per-kernel behavior in steps 1-3 and 5 is inferred from kernel names
    and arguments; those kernels are defined elsewhere (NOTE(review): confirm).
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    # Sweep budget per size class; k == 32 additionally gets TAIL_STEPS
    # (passed only to the serial kernel).
    sweeps = 6 if k == 32 else 8 if k < 32 else 12
    tail_steps = 20 if k == 32 else 0
    # Pad to an even column count so every column has a partner each step.
    round_size = k if k % 2 == 0 else k + 1
    serial_medium = 16 <= k <= 32 and rows <= 64 and batch <= 32
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
        if serial_medium:
            # Whole sweep schedule in one launch per batch element.
            _serial_cyclic_jacobi_kernel[(batch,)](
                a_work,
                v_work,
                K=k,
                ROUND=round_size,
                ROWS=rows,
                SWEEPS=sweeps,
                TAIL_STEPS=tail_steps,
                BLOCK_R=block_r,
                BLOCK_K=block_k,
                num_warps=1,
            )
        else:
            # One launch per step; each program handles one disjoint column pair.
            for _ in range(sweeps):
                for step in range(round_size - 1):
                    _cyclic_jacobi_pair_kernel[(batch, round_size // 2)](
                        a_work,
                        v_work,
                        step,
                        K=k,
                        ROUND=round_size,
                        ROWS=rows,
                        BLOCK_R=block_r,
                        BLOCK_K=block_k,
                        num_warps=1 if block_r <= 64 else 4,
                    )
        # Runs on both paths: the serial kernel is not given s_work.
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _cyclic_jacobi_finalize_kernel[(batch, k)](
            a_work,
            v_work,
            s_work,
            u,
            s,
            v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
    # Extra Gram-Schmidt cleanup of the short factor for very small problems.
    if k <= 17 and rows <= 64:
        with torch_device_fn.device(input.device):
            if m >= n:
                _thin_reorthogonalize_kernel[(batch,)](
                    v,
                    ROWS=n,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(n),
                    num_warps=1,
                )
            else:
                _thin_reorthogonalize_kernel[(batch,)](
                    u,
                    ROWS=m,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(m),
                    num_warps=1,
                )

    # Recompute the long factor from the short one; this overwrites the
    # finalize-kernel output for that factor.
    if m >= n:
        u = _triton_bmm(a, v, (batch, m, k))
        projected = u
        projected_rows = m
    else:
        a_t = a.transpose(1, 2).contiguous()
        v = _triton_bmm(a_t, u, (batch, n, k))
        projected = v
        projected_rows = n
    with torch_device_fn.device(input.device):
        # By their names: scale columns by 1/s, then replace columns whose
        # singular value vanished (NOTE(review): kernels defined elsewhere).
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1,
        )
        _complete_zero_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1,
        )
        if batch > 1 and k <= 16:
            _thin_reorthogonalize_kernel[(batch,)](
                projected,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(projected_rows),
                num_warps=1,
            )
    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

2387 

2388 

def _projected_jacobi_svd(input):
    """Thin SVD via blocked Jacobi without accumulating V.

    The Jacobi-orthogonalized, rank-sorted columns give the long-side factor
    directly (U when ``m >= n``, otherwise V). The other factor is recovered
    by projecting A onto it and normalizing against ``S``.

    Returns ``(U, S, V)`` shaped ``(..., m, k)``, ``(..., k)``, ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    k, rows = min(m, n), max(m, n)
    dev = input.device
    proj_rows = m if tall else n
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)
    projected = torch.empty((batch, proj_rows, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    warps = 1 if block_r <= 64 else 4
    n_sweeps = 10 if k >= 128 else 8
    round_size = k + (k % 2)  # even, so every column has a partner
    pairs = round_size // 2
    rot_c = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a, a_work, M=m, N=n, K=k, ROWS=rows, TALL=tall, BLOCK_R=block_r,
            num_warps=warps,
        )
        for _sweep in range(n_sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    a_work, rot_c, rot_s, step, K=k, ROUND=round_size,
                    ROWS=rows, BLOCK_R=block_r, num_warps=warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work, s_work, K=k, ROWS=rows, BLOCK_R=block_r, num_warps=warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](s_work, ranks, s, k, num_warps=1)
        _blocked_jacobi_store_projected_kernel[
            (batch, k, triton.cdiv(proj_rows, block_r))
        ](
            a_work, s_work, ranks, projected, K=k, ROWS=rows,
            OUT_ROWS=proj_rows, BLOCK_R=block_r, num_warps=warps,
        )

    if tall:
        u = projected
        v = _triton_bmm(a.transpose(1, 2).contiguous(), u, (batch, n, k))
        recovered, recovered_rows = v, n
    else:
        v = projected
        u = _triton_bmm(a, v, (batch, m, k))
        recovered, recovered_rows = u, m

    with torch_device_fn.device(dev):
        _normalize_projection_kernel[(batch, k)](
            recovered,
            s,
            ROWS=recovered_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(recovered_rows),
            num_warps=1 if recovered_rows <= 64 else 4,
        )

    lead = input.shape[:-2]
    return u.reshape(*lead, m, k), s.reshape(*lead, k), v.reshape(*lead, n, k)

2491 

2492 

def _blocked_jacobi_svd(input):
    """Thin SVD via parallel one-sided Jacobi with explicit V accumulation.

    Each schedule step rotates disjoint column pairs of the work copy of A,
    then applies the same recorded rotations to ``v_work``. Afterwards:
      - the normalized, rank-sorted columns fill the long-side factor;
      - the rank-sorted accumulated rotations fill the short-side factor.

    Returns ``(U, S, V)`` shaped ``(..., m, k)``, ``(..., k)``, ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    k, rows = min(m, n), max(m, n)
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    block_v = 64
    warps = 1 if block_r <= 64 else 4
    n_sweeps = 14 if k > 256 else 10
    round_size = k + (k % 2)  # even, so every column has a partner
    pairs = round_size // 2
    v_chunks = triton.cdiv(k, block_v)
    rot_c = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    # Long side receives the normalized Jacobi columns; short side (k rows)
    # receives the accumulated rotations.
    if tall:
        long_out, long_rows, short_out, short_rows = u, m, v, n
    else:
        long_out, long_rows, short_out, short_rows = v, n, u, m

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a, a_work, v_work, M=m, N=n, K=k, ROWS=rows, TALL=tall,
            BLOCK_R=block_r, BLOCK_K=block_k, num_warps=warps,
        )
        for _sweep in range(n_sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    a_work, rot_c, rot_s, step, K=k, ROUND=round_size,
                    ROWS=rows, BLOCK_R=block_r, num_warps=warps,
                )
                _blocked_jacobi_apply_v_kernel[(batch, pairs, v_chunks)](
                    v_work, rot_c, rot_s, step, K=k, ROUND=round_size,
                    BLOCK_V=block_v, num_warps=1,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work, s_work, K=k, ROWS=rows, BLOCK_R=block_r, num_warps=warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](s_work, ranks, s, k, num_warps=1)
        _blocked_jacobi_store_projected_kernel[
            (batch, k, triton.cdiv(long_rows, block_r))
        ](
            a_work, s_work, ranks, long_out, K=k, ROWS=rows,
            OUT_ROWS=long_rows, BLOCK_R=block_r, num_warps=warps,
        )
        _blocked_jacobi_store_basis_kernel[
            (batch, k, triton.cdiv(short_rows, block_v))
        ](
            v_work, ranks, short_out, K=k, BLOCK_V=block_v, num_warps=1,
        )

    lead = input.shape[:-2]
    return u.reshape(*lead, m, k), s.reshape(*lead, k), v.reshape(*lead, n, k)

2614 

2615 

def _blocked_jacobi_square_project_svd(input):
    """Thin SVD via blocked Jacobi on A's columns, projecting V from U.

    The work buffer is always initialized with ``TALL=True``. The normalized,
    rank-sorted columns become ``U``; ``V`` is computed as ``A^T @ U`` and
    then passed through ``_renorm_projection_update_s_kernel``, which by its
    name also refreshes ``S`` (NOTE(review): confirm).

    Returns ``(U, S, V)`` shaped ``(..., m, k)``, ``(..., k)``, ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    warps = 1 if block_r <= 64 else 4
    n_sweeps = 16 if k > 256 else 12
    round_size = k + (k % 2)  # even, so every column has a partner
    pairs = round_size // 2
    rot_c = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    rot_s = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a, a_work, M=m, N=n, K=k, ROWS=rows, TALL=True, BLOCK_R=block_r,
            num_warps=warps,
        )
        for _sweep in range(n_sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    a_work, rot_c, rot_s, step, K=k, ROUND=round_size,
                    ROWS=rows, BLOCK_R=block_r, num_warps=warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work, s_work, K=k, ROWS=rows, BLOCK_R=block_r, num_warps=warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](s_work, ranks, s, k, num_warps=1)
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work, s_work, ranks, u, K=k, ROWS=rows, OUT_ROWS=m,
            BLOCK_R=block_r, num_warps=warps,
        )

    v = _triton_bmm(a.transpose(1, 2).contiguous(), u, (batch, n, k))
    with torch_device_fn.device(dev):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    lead = input.shape[:-2]
    return u.reshape(*lead, m, k), s.reshape(*lead, k), v.reshape(*lead, n, k)

2702 

2703 

def _hier_block_jacobi_square_project_svd(input):
    """Thin SVD of square matrices via hierarchical (tile-pair) block Jacobi.

    Columns are grouped into blocks of ``tile_b`` columns: 4 when k == 512,
    otherwise 2. Each schedule step orthogonalizes all column pairs within
    pairs of blocks. Afterwards:
      - the normalized, rank-sorted columns form ``U``;
      - ``V`` is recovered as ``A^T @ U`` and renormalized.

    Returns ``(U, S, V)`` shaped ``(..., m, k)``, ``(..., k)``, ``(..., n, k)``.

    Raises:
        NotImplementedError: via ``_unsupported_svd`` when the matrix is not
            square or ``k`` is not a multiple of the tile width.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tile_b = 4 if k == 512 else 2
    # Validate before touching device memory. The unsupported path raises, so
    # the contiguous copy and work buffers below would only be discarded.
    if m != n or k % tile_b != 0:
        return _unsupported_svd(
            input,
            True,
            True,
            "Hierarchical block Jacobi supports square matrices with "
            "k divisible by two.",
        )

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    block_count = k // tile_b
    # Pad to an even block count so every block has a partner each step.
    round_blocks = block_count if block_count % 2 == 0 else block_count + 1
    half_round_blocks = round_blocks // 2
    sweep_count = 10 if k <= 256 else 12
    # Each pair kernel loads both blocks of its pair as one tile.
    tile_cols = tile_b * 2
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        for _ in range(sweep_count):
            for step in range(round_blocks - 1):
                _hier_block_jacobi_pair_a_kernel[(batch, half_round_blocks)](
                    a_work,
                    step,
                    K=k,
                    K_BLOCKS=block_count,
                    ROUND_BLOCKS=round_blocks,
                    ROWS=rows,
                    TILE_B=tile_b,
                    TILE_COLS=tile_cols,
                    BLOCK_R=block_r,
                    LOCAL_SWEEPS=1,
                    num_warps=4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )

    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )

2802 

2803 

def _blocked_jacobi_singular_values(input):
    """Return only the singular values of ``input`` via blocked Jacobi.

    No singular vectors are formed. The values are scattered into descending
    order by ``_blocked_jacobi_rank_kernel`` and returned with shape
    ``(..., min(m, n))``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    warps = 1 if block_r <= 64 else 4
    n_sweeps = 14 if k > 256 else 10
    round_size = k + (k % 2)  # even, so every column has a partner
    pairs = round_size // 2
    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a, a_work, M=m, N=n, K=k, ROWS=rows, TALL=m >= n,
            BLOCK_R=block_r, num_warps=warps,
        )
        for _sweep in range(n_sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_svals_kernel[(batch, pairs)](
                    a_work, step, K=k, ROUND=round_size, ROWS=rows,
                    BLOCK_R=block_r, num_warps=warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work, s_work, K=k, ROWS=rows, BLOCK_R=block_r, num_warps=warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](s_work, ranks, s, k, num_warps=1)

    return s.reshape(*input.shape[:-2], k)

2858 

2859 

def _small4_square_svd(input):
    """Thin SVD for a batch of 4x4 matrices using a dedicated fused kernel.

    Outputs are allocated as 4x4 / 4 regardless of ``m`` and ``n``, so the
    caller presumably guarantees ``m == n == 4`` (NOTE(review): confirm).
    Each program handles 16 matrices (``BLOCK_B``), and the kernel is
    launched with ``SWEEPS=4``.
    """
    batch, m, n = _svd_shape(input)
    dev = input.device
    dt = input.dtype
    per_program = 16
    a = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, 4, 4), dtype=dt, device=dev)
    s = torch.empty((batch, 4), dtype=dt, device=dev)
    v = torch.empty((batch, 4, 4), dtype=dt, device=dev)
    grid = (triton.cdiv(batch, per_program),)
    with torch_device_fn.device(dev):
        _small4_square_svd_kernel[grid](
            a, u, s, v, BATCH=batch, BLOCK_B=per_program, SWEEPS=4, num_warps=1
        )
    lead = input.shape[:-2]
    return u.reshape(*lead, 4, 4), s.reshape(*lead, 4), v.reshape(*lead, 4, 4)

2876 

2877 

# SVD of one matrix per program, where one of M, N equals 1.
#
# TALL (N == 1): S is the norm of the single column, U is that column
# divided by S, and V is the scalar 1. Otherwise (M == 1) the roles of U
# and V swap and the single row is used.
#
# The norm is accumulated in float32 over BLOCK_R-sized chunks, so vectors
# longer than BLOCK_R are supported. A vector with norm <= 1e-20 is treated
# as zero: its singular vector becomes e_0 so the factor stays unit-norm.
@libentry()
@triton.jit
def _rank1_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_R)
    # Tiny absolute threshold, consistent with the other Jacobi kernels in this
    # file. Clamping the denominator at float32 machine epsilon (~1.19e-7)
    # would shrink the singular vector of any input whose norm is below that
    # value, breaking both unit norm and U * S * V^T == A.
    eps = 1.0e-20
    a_base = A + pid * M * N
    norm_sq = tl.full((), 0.0, dtype=tl.float32)

    if TALL:
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        is_zero = norm <= eps
        denom = tl.maximum(norm, eps)
        tl.store(S + pid, norm)
        tl.store(V + pid, 1.0)

        u_base = U + pid * M
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(rows == 0, 1.0, 0.0)
            tl.store(u_base + rows, tl.where(is_zero, unit, vals / denom), mask=mask)
    else:
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        is_zero = norm <= eps
        denom = tl.maximum(norm, eps)
        tl.store(S + pid, norm)
        tl.store(U + pid, 1.0)

        v_base = V + pid * N
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(cols == 0, 1.0, 0.0)
            tl.store(v_base + cols, tl.where(is_zero, unit, vals / denom), mask=mask)

2932 

2933 

def _rank1_svd(input):
    """Thin SVD for inputs whose last two dims have min(m, n) == 1.

    Returns ``(U, S, V)`` shaped ``(..., m, 1)``, ``(..., 1)``, ``(..., n, 1)``.
    No kernel is launched when the flattened batch is empty.
    """
    batch, m, n = _svd_shape(input)
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 1), dtype=input.dtype, device=dev)
    s = torch.empty((batch, 1), dtype=input.dtype, device=dev)
    v = torch.empty((batch, n, 1), dtype=input.dtype, device=dev)
    if batch:
        length = max(m, n)
        # Short vectors fit in one padded block; longer ones are walked in
        # capped-size chunks by the kernel.
        block_r = (
            triton.next_power_of_2(length)
            if length <= _RANK1_BLOCK_R_MAX
            else _RANK1_BLOCK_R_MAX
        )
        with torch_device_fn.device(dev):
            _rank1_svd_kernel[(batch,)](
                a,
                u,
                s,
                v,
                m,
                n,
                TALL=n == 1,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
    lead = input.shape[:-2]
    return u.reshape(*lead, m, 1), s.reshape(*lead, 1), v.reshape(*lead, n, 1)

2962 

2963 

# Build the (2M, 2N) real embedding [[Re, -Im], [Im, Re]] of one complex
# (M, N) matrix. The input is stored as interleaved (real, imag) float pairs.
#
# One program per batch element writes the whole embedding from a single
# arange, so BLOCK_SIZE must be at least 4 * M * N.
@libentry()
@triton.jit
def _complex_to_real_embedding_kernel(
    A_RI,
    R,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0)
    idx = tl.arange(0, BLOCK_SIZE)
    valid = idx < 4 * M * N
    width = 2 * N
    r = idx // width
    c = idx - r * width
    top = r < M
    left = c < N
    src_r = tl.where(top, r, r - M)
    src_c = tl.where(left, c, c - N)
    # Off-diagonal quadrants read the imaginary component; the diagonal ones
    # read the real component.
    upper_right = top & (c >= N)
    lower_left = (r >= M) & left
    part = tl.where(upper_right | lower_left, 1, 0)
    x = tl.load(
        A_RI + b * M * N * 2 + (src_r * N + src_c) * 2 + part,
        mask=valid,
        other=0.0,
    )
    flip = tl.where(upper_right, -1.0, 1.0)
    tl.store(R + b * 4 * M * N + idx, x * flip, mask=valid)

2990 

2991 

# Copy columns 0, 2, 4, ... of a real (2*ROWS, REAL_K) factor into a complex
# (ROWS, K) factor, stored as interleaved (real, imag) pairs.
#
# Rows [0, ROWS) of the source supply the real parts and rows
# [ROWS, 2*ROWS) the imaginary parts. One program per batch element covers
# all ROWS * K outputs.
@libentry()
@triton.jit
def _complex_svd_pick_factor_kernel(
    REAL_FACTOR,
    OUT_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0)
    idx = tl.arange(0, BLOCK_SIZE)
    valid = idx < ROWS * K
    r = idx // K
    c = idx % K
    src = REAL_FACTOR + b * (2 * ROWS) * REAL_K + c * 2
    re = tl.load(src + r * REAL_K, mask=valid, other=0.0)
    im = tl.load(src + (ROWS + r) * REAL_K, mask=valid, other=0.0)
    dst = OUT_RI + b * ROWS * K * 2 + idx * 2
    tl.store(dst, re, mask=valid)
    tl.store(dst + 1, im, mask=valid)

3021 

3022 

# Collapse the real-embedding spectrum back to complex singular values.
# The embedding repeats each complex singular value twice, so S[j] is the
# mean of S_REAL entries 2j and 2j+1. One program per batch element.
@libentry()
@triton.jit
def _complex_svd_pick_s_kernel(
    S_REAL,
    S,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    b = tl.program_id(0)
    j = tl.arange(0, BLOCK_K)
    valid = j < K
    pair_ptr = S_REAL + b * REAL_K + j * 2
    first = tl.load(pair_ptr, mask=valid, other=0.0)
    second = tl.load(pair_ptr + 1, mask=valid, other=0.0)
    tl.store(S + b * K + j, 0.5 * (first + second), mask=valid)

3039 

3040 

# Extract a complex (ROWS, K) right factor from the real (2*ROWS, REAL_K)
# one and re-orthonormalize it.
#
# Column j comes from real column 2j: its top half holds the real parts and
# its bottom half the imaginary parts. All K columns are then orthonormalized
# with complex modified Gram-Schmidt, using <p, c> = sum(conj(p) * c), and
# stored as interleaved (real, imag) pairs. One program per batch element.
#
# The loop runs over all K columns. A fixed bound of 16 would leave columns
# >= 16 untouched when K > 16; iterations past K only touched masked-out
# columns. The loops are unrolled at compile time, so cost grows as K^2.
@libentry()
@triton.jit
def _complex_svd_pick_orthonormal_v_kernel(
    V_REAL,
    V_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    src_cols = cols * 2
    base = V_REAL + batch * (2 * ROWS) * REAL_K
    vr = tl.load(
        base + rows[:, None] * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )
    vi = tl.load(
        base + (ROWS + rows[:, None]) * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )

    for c in tl.static_range(0, K):
        cur_r = tl.sum(tl.where(cols[None, :] == c, vr, 0.0), axis=1)
        cur_i = tl.sum(tl.where(cols[None, :] == c, vi, 0.0), axis=1)
        for p in tl.static_range(0, c):
            prev_r = tl.sum(tl.where(cols[None, :] == p, vr, 0.0), axis=1)
            prev_i = tl.sum(tl.where(cols[None, :] == p, vi, 0.0), axis=1)
            # coeff = <prev, cur> = sum(conj(prev) * cur)
            coeff_r = tl.sum(
                tl.where(row_mask, prev_r * cur_r + prev_i * cur_i, 0.0), axis=0
            )
            coeff_i = tl.sum(
                tl.where(row_mask, prev_r * cur_i - prev_i * cur_r, 0.0), axis=0
            )
            # cur -= prev * coeff (complex multiply)
            cur_r -= prev_r * coeff_r - prev_i * coeff_i
            cur_i -= prev_r * coeff_i + prev_i * coeff_r
        norm_sq = tl.sum(tl.where(row_mask, cur_r * cur_r + cur_i * cur_i, 0.0), axis=0)
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur_r *= inv_norm
        cur_i *= inv_norm
        vr = tl.where(cols[None, :] == c, cur_r[:, None], vr)
        vi = tl.where(cols[None, :] == c, cur_i[:, None], vi)

    out_base = V_RI + batch * ROWS * K * 2
    offsets = rows[:, None] * K + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets * 2, vr, mask=mask)
    tl.store(out_base + offsets * 2 + 1, vi, mask=mask)

3097 

3098 

# Compute U = (A @ V) / S for one complex matrix per program.
#
# Every operand is stored as interleaved (real, imag) pairs: A is (M, N),
# V is (N, K) and U is (M, K). Columns whose singular value is <= 1e-20 are
# written as zeros. One program covers all M * K outputs, so BLOCK_SIZE must
# be at least M * K.
@libentry()
@triton.jit
def _complex_svd_project_u_kernel(
    A_RI,
    V_RI,
    S,
    U_RI,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0)
    idx = tl.arange(0, BLOCK_SIZE)
    valid = idx < M * K
    r = idx // K
    c = idx % K
    a_row = A_RI + b * M * N * 2 + r * N * 2
    v_col = V_RI + b * N * K * 2 + c * 2

    sum_re = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    sum_im = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # Complex dot product along the shared dimension, unrolled over N.
    for j in tl.static_range(0, N):
        a_ptr = a_row + j * 2
        v_ptr = v_col + j * K * 2
        a_re = tl.load(a_ptr, mask=valid, other=0.0)
        a_im = tl.load(a_ptr + 1, mask=valid, other=0.0)
        v_re = tl.load(v_ptr, mask=valid, other=0.0)
        v_im = tl.load(v_ptr + 1, mask=valid, other=0.0)
        sum_re += a_re * v_re - a_im * v_im
        sum_im += a_re * v_im + a_im * v_re

    sigma = tl.load(S + b * K + c, mask=valid, other=1.0)
    scale = tl.where(sigma > 1.0e-20, 1.0 / sigma, 0.0)
    dst = U_RI + b * M * K * 2 + idx * 2
    tl.store(dst, sum_re * scale, mask=valid)
    tl.store(dst + 1, sum_im * scale, mask=valid)

3134 

3135 

3136def _complex_svd_via_real_embedding(input): 

3137 batch, m, n = _svd_shape(input) 

3138 k = min(m, n) 

3139 a_ri = torch.view_as_real(input.contiguous()).reshape(batch, m, n, 2) 

3140 real_matrix = torch.empty( 

3141 (batch, 2 * m, 2 * n), dtype=torch.float32, device=input.device 

3142 ) 

3143 block_size = triton.next_power_of_2(4 * m * n) 

3144 with torch_device_fn.device(input.device): 

3145 _complex_to_real_embedding_kernel[(batch,)]( 

3146 a_ri, 

3147 real_matrix, 

3148 M=m, 

3149 N=n, 

3150 BLOCK_SIZE=block_size, 

3151 num_warps=1, 

3152 ) 

3153 _, s_real, v_real = svd(real_matrix, some=True, compute_uv=True) 

3154 s = torch.empty((batch, k), dtype=torch.float32, device=input.device) 

3155 u = torch.empty((*input.shape[:-2], m, k), dtype=input.dtype, device=input.device) 

3156 v = torch.empty((*input.shape[:-2], n, k), dtype=input.dtype, device=input.device) 

3157 u_ri = torch.view_as_real(u).reshape(batch, m, k, 2) 

3158 v_ri = torch.view_as_real(v).reshape(batch, n, k, 2) 

3159 with torch_device_fn.device(input.device): 

3160 _complex_svd_pick_s_kernel[(batch,)]( 

3161 s_real, 

3162 s, 

3163 K=k, 

3164 REAL_K=2 * k, 

3165 BLOCK_K=triton.next_power_of_2(k), 

3166 num_warps=1, 

3167 ) 

3168 _complex_svd_pick_orthonormal_v_kernel[(batch,)]( 

3169 v_real, 

3170 v_ri, 

3171 ROWS=n, 

3172 K=k, 

3173 REAL_K=2 * k, 

3174 BLOCK_ROWS=triton.next_power_of_2(n), 

3175 BLOCK_K=triton.next_power_of_2(k), 

3176 num_warps=1, 

3177 ) 

3178 _complex_svd_project_u_kernel[(batch,)]( 

3179 a_ri, 

3180 v_ri, 

3181 s, 

3182 u_ri, 

3183 M=m, 

3184 N=n, 

3185 K=k, 

3186 BLOCK_SIZE=triton.next_power_of_2(m * k), 

3187 num_warps=1, 

3188 ) 

3189 return ( 

3190 u, 

3191 s.reshape(*input.shape[:-2], k), 

3192 v, 

3193 ) 

3194 

3195 

3196def _gram_svd(input): 

3197 return _unsupported_svd(input, True, True) 

3198 

3199 

3200@libentry() 

3201@triton.jit 

3202def _gram16_finalize_kernel( 

3203 A, 

3204 EVALS, 

3205 EVECS, 

3206 U, 

3207 S, 

3208 V, 

3209 M: tl.constexpr, 

3210 N: tl.constexpr, 

3211 ROWS: tl.constexpr, 

3212 TALL: tl.constexpr, 

3213 EVECS_BATCH_STRIDE: tl.constexpr, 

3214 EVECS_ROW_STRIDE: tl.constexpr, 

3215 EVECS_COL_STRIDE: tl.constexpr, 

3216 BLOCK_R: tl.constexpr, 

3217): 

3218 batch = tl.program_id(0) 

3219 row_block = tl.program_id(1) 

3220 rows = row_block * BLOCK_R + tl.arange(0, BLOCK_R) 

3221 cols = tl.arange(0, 16) 

3222 src_cols = 15 - cols 

3223 row_mask = rows < ROWS 

3224 eps = 1.0e-20 

3225 

3226 vals = tl.load(EVALS + batch * 16 + src_cols) 

3227 s_vals = tl.sqrt(tl.maximum(vals, 0.0)) 

3228 inv_s = tl.where(s_vals > eps, 1.0 / s_vals, 0.0) 

3229 

3230 acc = tl.zeros((BLOCK_R, 16), dtype=tl.float32) 

3231 a_base = A + batch * M * N 

3232 e_base = EVECS + batch * EVECS_BATCH_STRIDE 

3233 for k in tl.static_range(0, 16): 

3234 eig = tl.load(e_base + k * EVECS_ROW_STRIDE + src_cols * EVECS_COL_STRIDE) 

3235 if TALL: 

3236 a_vals = tl.load( 

3237 a_base + rows * N + k, 

3238 mask=row_mask, 

3239 other=0.0, 

3240 ) 

3241 else: 

3242 a_vals = tl.load( 

3243 a_base + k * N + rows, 

3244 mask=row_mask, 

3245 other=0.0, 

3246 ) 

3247 acc += a_vals[:, None] * eig[None, :] 

3248 

3249 projected = acc * inv_s[None, :] 

3250 if TALL: 

3251 tl.store( 

3252 U + batch * M * 16 + rows[:, None] * 16 + cols[None, :], 

3253 projected, 

3254 mask=row_mask[:, None], 

3255 ) 

3256 else: 

3257 tl.store( 

3258 V + batch * N * 16 + rows[:, None] * 16 + cols[None, :], 

3259 projected, 

3260 mask=row_mask[:, None], 

3261 ) 

3262 

3263 head_mask = row_block == 0 

3264 tl.store(S + batch * 16 + cols, s_vals, mask=head_mask) 

3265 

3266 basis_rows = tl.arange(0, 16) 

3267 basis_cols = tl.arange(0, 16) 

3268 basis_src_cols = 15 - basis_cols 

3269 basis = tl.load( 

3270 e_base 

3271 + basis_rows[:, None] * EVECS_ROW_STRIDE 

3272 + basis_src_cols[None, :] * EVECS_COL_STRIDE 

3273 ) 

3274 if TALL: 

3275 tl.store( 

3276 V + batch * N * 16 + basis_rows[:, None] * 16 + basis_cols[None, :], 

3277 basis, 

3278 mask=head_mask, 

3279 ) 

3280 else: 

3281 tl.store( 

3282 U + batch * M * 16 + basis_rows[:, None] * 16 + basis_cols[None, :], 

3283 basis, 

3284 mask=head_mask, 

3285 ) 

3286 

3287 

3288def _gram16_svd(input): 

3289 return _unsupported_svd(input, True, True) 

3290 

3291 

3292def _large_native_svd(input): 

3293 if _can_use_hier_block_square_project_kernel(input, True, True): 

3294 return _hier_block_jacobi_square_project_svd(input) 

3295 if _can_use_blocked_square_project_kernel(input, True, True): 

3296 return _blocked_jacobi_square_project_svd(input) 

3297 if _can_use_blocked_jacobi_kernel(input, True, True): 

3298 return _blocked_jacobi_svd(input) 

3299 return _bidiagonal_qr_dqds_svd(input) 

3300 

3301 

3302def _bidiagonal_qr_dqds_svd(input): 

3303 return _unsupported_svd( 

3304 input, 

3305 True, 

3306 True, 

3307 "The k > 512 blocked-bidiagonalization plus QR/DQDS path is reserved " 

3308 "for the next native large-matrix solver stage.", 

3309 ) 

3310 

3311 

3312def _empty_svd_result(input, some=True, compute_uv=True): 

3313 _, m, n = _svd_shape(input) 

3314 k = min(m, n) 

3315 u_cols = k if compute_uv and some else m 

3316 v_cols = k if compute_uv and some else n 

3317 u = torch.empty( 

3318 (*input.shape[:-2], m, u_cols), dtype=input.dtype, device=input.device 

3319 ) 

3320 s = torch.empty((*input.shape[:-2], k), dtype=input.dtype, device=input.device) 

3321 v = torch.empty( 

3322 (*input.shape[:-2], n, v_cols), dtype=input.dtype, device=input.device 

3323 ) 

3324 return u, s, v 

3325 

3326 

3327@libentry() 

3328@triton.jit 

3329def _complete_svd_factor_kernel( 

3330 THIN, 

3331 FULL, 

3332 ROWS: tl.constexpr, 

3333 THIN_COLS: tl.constexpr, 

3334 FULL_COLS: tl.constexpr, 

3335 BLOCK_ROWS: tl.constexpr, 

3336 BLOCK_COLS: tl.constexpr, 

3337): 

3338 batch = tl.program_id(0) 

3339 rows = tl.arange(0, BLOCK_ROWS) 

3340 cols = tl.arange(0, BLOCK_COLS) 

3341 row_mask = rows < ROWS 

3342 col_mask = cols < FULL_COLS 

3343 vals = tl.load( 

3344 THIN + batch * ROWS * THIN_COLS + rows[:, None] * THIN_COLS + cols[None, :], 

3345 mask=row_mask[:, None] & (cols[None, :] < THIN_COLS), 

3346 other=0.0, 

3347 ) 

3348 identity = tl.where(rows[:, None] == cols[None, :], 1.0, 0.0) 

3349 vals = tl.where(cols[None, :] < THIN_COLS, vals, identity) 

3350 

3351 for c in tl.static_range(0, 64): 

3352 cur_mask = c < FULL_COLS 

3353 cur = tl.sum(tl.where(cols[None, :] == c, vals, 0.0), axis=1) 

3354 for p in tl.static_range(0, c): 

3355 prev = tl.sum(tl.where(cols[None, :] == p, vals, 0.0), axis=1) 

3356 coeff = tl.sum(tl.where(row_mask, prev * cur, 0.0), axis=0) 

3357 cur -= prev * coeff 

3358 norm_sq = tl.sum(tl.where(row_mask, cur * cur, 0.0), axis=0) 

3359 inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20)) 

3360 cur *= inv_norm 

3361 vals = tl.where((cols[None, :] == c) & cur_mask, cur[:, None], vals) 

3362 

3363 out_base = FULL + batch * ROWS * FULL_COLS 

3364 offsets = rows[:, None] * FULL_COLS + cols[None, :] 

3365 mask = row_mask[:, None] & col_mask[None, :] 

3366 tl.store(out_base + offsets, vals, mask=mask) 

3367 

3368 

3369def _low_precision_svd_via_float32(input, some=True, compute_uv=True): 

3370 u, s, v = svd(input.to(torch.float32), some=some, compute_uv=compute_uv) 

3371 return u.to(input.dtype), s.to(input.dtype), v.to(input.dtype) 

3372 

3373 

3374def _some_false_svd_via_thin(input): 

3375 batch, m, n = _svd_shape(input) 

3376 k = min(m, n) 

3377 thin_u, s, thin_v = svd(input, some=True, compute_uv=True) 

3378 u = torch.empty((*input.shape[:-2], m, m), dtype=input.dtype, device=input.device) 

3379 v = torch.empty((*input.shape[:-2], n, n), dtype=input.dtype, device=input.device) 

3380 with torch_device_fn.device(input.device): 

3381 _complete_svd_factor_kernel[(batch,)]( 

3382 thin_u, 

3383 u, 

3384 ROWS=m, 

3385 THIN_COLS=k, 

3386 FULL_COLS=m, 

3387 BLOCK_ROWS=triton.next_power_of_2(m), 

3388 BLOCK_COLS=triton.next_power_of_2(m), 

3389 num_warps=4, 

3390 ) 

3391 _complete_svd_factor_kernel[(batch,)]( 

3392 thin_v, 

3393 v, 

3394 ROWS=n, 

3395 THIN_COLS=k, 

3396 FULL_COLS=n, 

3397 BLOCK_ROWS=triton.next_power_of_2(n), 

3398 BLOCK_COLS=triton.next_power_of_2(n), 

3399 num_warps=4, 

3400 ) 

3401 return u, s, v 

3402 

3403 

3404def _compute_uv_false_result(input, s): 

3405 _, m, n = _svd_shape(input) 

3406 u = torch.empty((*input.shape[:-2], m, m), dtype=input.dtype, device=input.device) 

3407 v = torch.empty((*input.shape[:-2], n, n), dtype=input.dtype, device=input.device) 

3408 return u, s, v 

3409 

3410 

3411def _singular_values_only(input): 

3412 _, m, n = _svd_shape(input) 

3413 k = min(m, n) 

3414 largest = max(m, n) 

3415 if k == 2 and largest <= _RANK2_BLOCK_R_MAX: 

3416 return _rank2_singular_values(input) 

3417 if k <= 16 and largest <= 1024: 

3418 return _small_jacobi_singular_values(input) 

3419 if 16 < k <= 512 and largest <= 1024: 

3420 return _blocked_jacobi_singular_values(input) 

3421 return _unsupported_svd(input, True, False) 

3422 

3423 

3424def _should_use_gram16(batch, m, n): 

3425 return batch >= 16 and min(m, n) == 16 and max(m, n) <= 1024 

3426 

3427 

3428def _should_use_gram(batch, m, n): 

3429 k = min(m, n) 

3430 largest = max(m, n) 

3431 if k <= 32: 

3432 return True 

3433 if batch <= 4 and m == n and m <= 256: 

3434 return True 

3435 if (m, n) == (1024, 1024): 

3436 return True 

3437 if batch >= 128 and k <= 64 and largest <= 1024: 

3438 return False 

3439 return False 

3440 

3441 

3442def svd(input, some=True, compute_uv=True): 

3443 logger.debug("GEMS SVD") 

3444 if ( 

3445 input.is_cuda 

3446 and input.dtype == torch.complex64 

3447 and some 

3448 and compute_uv 

3449 and input.dim() >= 2 

3450 and 0 not in input.shape[-2:] 

3451 and max(input.shape[-2:]) <= 16 

3452 ): 

3453 return SVDResult(*_complex_svd_via_real_embedding(input)) 

3454 if _is_low_precision_cuda_matrix(input): 

3455 return SVDResult(*_low_precision_svd_via_float32(input, some, compute_uv)) 

3456 if _is_float32_cuda_matrix(input) and 0 in input.shape[-2:]: 

3457 return SVDResult(*_empty_svd_result(input, some, compute_uv)) 

3458 if _can_use_singular_values_only(input, some, compute_uv): 

3459 return SVDResult(*_compute_uv_false_result(input, _singular_values_only(input))) 

3460 if ( 

3461 _is_float32_cuda_matrix(input) 

3462 and not some 

3463 and compute_uv 

3464 and max(input.shape[-2:]) <= 64 

3465 ): 

3466 return SVDResult(*_some_false_svd_via_thin(input)) 

3467 if not _is_float32_cuda_matrix(input) or not some: 

3468 return SVDResult(*_unsupported_svd(input, some, compute_uv)) 

3469 batch, m, n = _svd_shape(input) 

3470 k = min(m, n) 

3471 try: 

3472 if k == 1: 

3473 return SVDResult(*_rank1_svd(input)) 

3474 if k == 2 and max(m, n) <= _RANK2_BLOCK_R_MAX: 

3475 return SVDResult(*_rank2_svd(input)) 

3476 if k == 4 and m == 4 and n == 4 and batch >= 16: 

3477 return SVDResult(*_small4_square_svd(input)) 

3478 if _can_use_tall_wide_gram_jacobi_kernel(input, some, compute_uv): 

3479 return SVDResult(*_gram_jacobi_svd(input)) 

3480 use_batched_cyclic16 = k == 16 and batch >= 8 and max(m, n) <= 64 

3481 if ( 

3482 _can_use_small_jacobi_kernel(input, some, compute_uv) 

3483 and not use_batched_cyclic16 

3484 ): 

3485 return SVDResult(*_small_jacobi_svd(input)) 

3486 if _can_use_tsqr_cholesky_kernel(input, some, compute_uv): 

3487 return SVDResult(*_tsqr_cholesky_svd(input)) 

3488 if _can_use_projected_jacobi_kernel(input, some, compute_uv): 

3489 return SVDResult(*_projected_jacobi_svd(input)) 

3490 if _can_use_cyclic_jacobi_kernel(input, some, compute_uv): 

3491 return SVDResult(*_cyclic_jacobi_svd(input)) 

3492 if _can_use_hier_block_square_project_kernel(input, some, compute_uv): 

3493 return SVDResult(*_hier_block_jacobi_square_project_svd(input)) 

3494 if _can_use_blocked_square_project_kernel(input, some, compute_uv): 

3495 return SVDResult(*_blocked_jacobi_square_project_svd(input)) 

3496 if _can_use_blocked_jacobi_kernel(input, some, compute_uv): 

3497 return SVDResult(*_blocked_jacobi_svd(input)) 

3498 if _should_use_gram16(batch, m, n): 

3499 return SVDResult(*_gram16_svd(input)) 

3500 if _should_use_gram(batch, m, n): 

3501 return SVDResult(*_gram_svd(input)) 

3502 return SVDResult(*_large_native_svd(input)) 

3503 except RuntimeError: 

3504 return SVDResult(*_unsupported_svd(input, some, compute_uv))