Coverage for src/flag_gems/ops/svd.py: 21%
1838 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2from collections import namedtuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
13SVDResult = namedtuple("SVDResult", ["U", "S", "V"])
15_GRAM_CONDITION_GUARD_MAX_BATCH = 16
16_GRAM_CONDITION_GUARD_MAX_K = 32
17_GRAM_CONDITION_EIGEN_RATIO = 1.0e-8
18_GRAM_TALL_WIDE_MAX_K = 32
19_GRAM_TALL_WIDE_MAX_ROWS = 1024
20_RANK1_BLOCK_R_MAX = 1024
21_RANK2_BLOCK_R_MAX = 2048
22_TSQR_CHOLESKY_MAX_BATCH = 32
23_TSQR_CHOLESKY_MAX_K = 128
24_TSQR_CHOLESKY_MAX_ROWS = 1024
def _unsupported_svd(input, some=True, compute_uv=True, reason=None):
    """Raise ``NotImplementedError`` for an SVD request outside native coverage.

    The message reports the flattened batch size, the matrix shape, the dtype,
    the device and the requested flags. When given, ``reason`` is appended
    after a single space.

    Raises:
        NotImplementedError: always.
    """
    batch, m, n = _svd_shape(input)
    if reason is None:
        suffix = ""
    else:
        suffix = f" {reason}"
    message = (
        "FlagGems native SVD currently supports float32 CUDA matrices with "
        "some=True, compute_uv=True, non-empty inputs, and native Triton "
        f"rank/Jacobi shape coverage; got batch={batch}, m={m}, n={n}, "
        f"dtype={input.dtype}, device={input.device}, some={some}, "
        f"compute_uv={compute_uv}.{suffix}"
    )
    raise NotImplementedError(message)
def _is_iluvatar_backend():
    """Return True when the FlagGems runtime device reports vendor "iluvatar"."""
    vendor = device.vendor_name
    return vendor == "iluvatar"
43def _svd_shape(input):
44 if input.dim() < 2:
45 return 0, 0, 0
46 m = input.shape[-2]
47 n = input.shape[-1]
48 batch = 1
49 for dim in input.shape[:-2]:
50 batch *= dim
51 return batch, m, n
def _should_guard_gram_spectrum(batch, k):
    """Return True when ``batch`` and ``k`` are both within the Gram condition-guard limits."""
    if not (batch <= _GRAM_CONDITION_GUARD_MAX_BATCH):
        return False
    return k <= _GRAM_CONDITION_GUARD_MAX_K
58def _is_float32_cuda_matrix(input):
59 return input.is_cuda and input.dtype == torch.float32 and input.dim() >= 2
62def _is_low_precision_cuda_matrix(input):
63 return (
64 input.is_cuda
65 and input.dtype in (torch.float16, torch.bfloat16)
66 and input.dim() >= 2
67 )
def _can_use_rank1_kernel(input, some=True, compute_uv=True):
    """Whether the rank-1 kernel applies.

    Requires a float32 CUDA matrix, truthy ``some`` and ``compute_uv``, and a
    smaller matrix dimension equal to 1.
    """
    _, height, width = _svd_shape(input)
    is_vector_like = min(height, width) == 1
    return _is_float32_cuda_matrix(input) and some and compute_uv and is_vector_like
def _can_use_rank2_kernel(input, some=True, compute_uv=True):
    """Whether the rank-2 kernel applies.

    Requires a float32 CUDA matrix, truthy ``some`` and ``compute_uv``, a
    smaller dimension of exactly 2, and a larger dimension of at most
    ``_RANK2_BLOCK_R_MAX``.
    """
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    shape_ok = short_side == 2 and long_side <= _RANK2_BLOCK_R_MAX
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_2x2_kernel(input):
    """Whether the input is a (batch of) 2x2 matrix eligible for the rank-2 kernel."""
    _, height, width = _svd_shape(input)
    rank2_ok = _can_use_rank2_kernel(input, True, True)
    return rank2_ok and (height, width) == (2, 2)
def _can_use_4x4_kernel(input, some=True, compute_uv=True):
    """Whether the dedicated 4x4 kernel applies.

    Requires a float32 CUDA matrix of shape ``(..., 4, 4)`` and truthy
    ``some`` / ``compute_uv``.
    """
    _, height, width = _svd_shape(input)
    is_4x4 = (height, width) == (4, 4)
    return _is_float32_cuda_matrix(input) and some and compute_uv and is_4x4
def _can_use_small_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the small one-sided Jacobi kernel applies.

    Requires a float32 CUDA matrix, truthy ``some`` / ``compute_uv``, a
    non-Iluvatar backend, ``min(m, n) <= 16`` and ``max(m, n) <= 1024``.
    """
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    fits = short_side <= 16 and long_side <= 1024
    return (
        _is_float32_cuda_matrix(input)
        and some
        and compute_uv
        and not _is_iluvatar_backend()
        and fits
    )
def _can_use_cyclic_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the cyclic Jacobi path applies: float32 CUDA, ``16 <= min(m, n) <= 64``, ``max(m, n) <= 1024``."""
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    fits = 16 <= short_side <= 64 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and fits
def _can_use_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the Gram-Jacobi path applies: float32 CUDA, ``16 <= min(m, n) <= 32``, ``max(m, n) <= 64``."""
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    fits = 16 <= short_side <= 32 and long_side <= 64
    return _is_float32_cuda_matrix(input) and some and compute_uv and fits
def _can_use_tall_wide_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the tall/wide Gram-Jacobi path applies.

    Requires a float32 CUDA matrix, truthy ``some`` / ``compute_uv``, a batch
    of at least 128, ``16 <= k <= _GRAM_TALL_WIDE_MAX_K`` for ``k = min(m, n)``,
    and a long side of at least ``2 * k`` and at most
    ``_GRAM_TALL_WIDE_MAX_ROWS``.
    """
    batch, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    shape_ok = (
        batch >= 128
        and 16 <= short_side <= _GRAM_TALL_WIDE_MAX_K
        and 2 * short_side <= long_side <= _GRAM_TALL_WIDE_MAX_ROWS
    )
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
147def _can_use_tsqr_cholesky_kernel(input, some=True, compute_uv=True):
148 # Input-dependent TSQR safety needs a native device-side guard before dispatch.
149 return False
def _can_use_blocked_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the blocked Jacobi path applies: float32 CUDA, ``64 < min(m, n) <= 512``, ``max(m, n) <= 1024``."""
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    fits = 64 < short_side <= 512 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and fits
def _can_use_blocked_square_project_kernel(input, some=True, compute_uv=True):
    """Whether the blocked square-projection path applies.

    Requires a single (batch == 1) float32 CUDA square matrix of size 128..512
    and truthy ``some`` / ``compute_uv``.
    """
    batch, height, width = _svd_shape(input)
    shape_ok = batch == 1 and height == width and 128 <= min(height, width) <= 512
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_hier_block_square_project_kernel(input, some=True, compute_uv=True):
    """Whether the hierarchical block square-projection path applies.

    Requires at most two float32 CUDA square matrices of size exactly 256 or 512
    and truthy ``some`` / ``compute_uv``.
    """
    batch, height, width = _svd_shape(input)
    shape_ok = batch <= 2 and height == width and min(height, width) in {256, 512}
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_projected_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the projected Jacobi path applies.

    Requires a float32 CUDA batch of 4..32 matrices with ``min(m, n) == 64``
    and ``max(m, n) <= 128``, plus truthy ``some`` / ``compute_uv``.
    """
    batch, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    shape_ok = 4 <= batch <= 32 and short_side == 64 and long_side <= 128
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_singular_values_only(input, some=True, compute_uv=False):
    """Whether the singular-values-only path applies.

    Requires a float32 CUDA matrix, falsy ``compute_uv``, ``min(m, n) <= 512``
    and ``max(m, n) <= 1024``. ``some`` is accepted for signature parity but is
    not consulted.
    """
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    shape_ok = short_side <= 512 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and not compute_uv and shape_ok
# One-sided Jacobi SVD of small matrices, one matrix per program.
#
# Program ``pid`` reads the M x N matrix at ``A + pid * M * N``. The
# ``rows * N + j`` / ``j * N + rows`` indexing assumes row-major contiguous storage.
# When TALL, the K columns of A (each of length ROWS) are orthogonalized.
# Otherwise the K rows of A (each of length ROWS) are, i.e. the columns of A^T.
# NOTE(review): K / ROWS / TALL are presumably min(M, N) / max(M, N) / (M >= N)
# as chosen by the host launcher (not visible here) — confirm at the call site.
#
# Per-program scratch:
#   A_WORK: K vectors of length ROWS being rotated (values cast to float32 on load).
#   V_WORK: K x K; row j accumulates the rotations applied to working vector j.
# Per-program outputs: U (M x K), S (K), V (N x K), with singular values written
# in descending order (ties broken by the lower original index).
# SWEEPS is a fixed sweep count; there is no convergence test.
@libentry()
@triton.jit
def _small_jacobi_svd_kernel(
    A,
    A_WORK,
    V_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    # Floor used to avoid sqrt/division blow-ups for (near-)zero norms.
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS
    vw_base = V_WORK + pid * K * K

    # Copy the K working vectors into A_WORK and set V_WORK to the identity.
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)
        ident_col = tl.where(cols == j, 1.0, 0.0)
        tl.store(vw_base + j * K + cols, ident_col, mask=col_mask)

    # Cyclic sweeps over every pair (p, q), p < q.
    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Rotate only when the pair is not already orthogonal to a
                # relative tolerance of 1e-7; otherwise use c=1, s=0.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                # Rotation chosen so the rotated pair has zero dot product;
                # safe_gamma avoids dividing by zero on inactive pairs.
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

                # Apply the same rotation to the accumulated basis.
                vp = tl.load(vw_base + p * K + cols, mask=col_mask, other=0.0)
                vq = tl.load(vw_base + q * K + cols, mask=col_mask, other=0.0)
                new_vp = c * vp - s_rot * vq
                new_vq = s_rot * vp + c * vq
                tl.store(vw_base + p * K + cols, new_vp, mask=col_mask)
                tl.store(vw_base + q * K + cols, new_vq, mask=col_mask)

    # Singular values are the norms of the rotated working vectors.
    s_idx = tl.arange(0, BLOCK_K)
    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # ranks[j] = number of entries that beat j (larger value, or equal value
    # with a lower index) = output position of j in descending order.
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    # Scatter each singular triplet to its sorted position. Vectors whose norm
    # is <= eps get inv_norm = 0, so their normalized column is written as zeros.
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        inv_norm = tl.where(norm > eps, 1.0 / norm, 0.0)
        tl.store(S + pid * K + rank, norm)

        basis = tl.load(vw_base + j * K + cols, mask=col_mask, other=0.0)
        if TALL:
            tl.store(U + pid * M * K + rows * K + rank, col * inv_norm, mask=row_mask)
            tl.store(V + pid * N * K + cols * K + rank, basis, mask=col_mask)
        else:
            # Wide case: the Jacobi basis spans U and the normalized vectors form V.
            tl.store(U + pid * M * K + cols * K + rank, basis, mask=col_mask)
            tl.store(V + pid * N * K + rows * K + rank, col * inv_norm, mask=row_mask)
# Singular values only: the same one-sided Jacobi iteration as
# _small_jacobi_svd_kernel, but without accumulating a basis or writing U/V.
# Program ``pid`` reads the M x N matrix at ``A + pid * M * N`` (row-major
# contiguous layout assumed). It uses A_WORK (K vectors of length ROWS per
# program) as scratch and writes the K norms of the rotated vectors to
# S[pid, :] in descending order. SWEEPS is fixed; there is no convergence test.
@libentry()
@triton.jit
def _small_jacobi_svals_kernel(
    A,
    A_WORK,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    s_idx = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    eps = 1.0e-20

    a_base = A + pid * M * N
    aw_base = A_WORK + pid * K * ROWS

    # Columns of A (TALL) or rows of A (otherwise) become the working vectors.
    for j in tl.static_range(0, K):
        if TALL:
            vals = tl.load(a_base + rows * N + j, mask=row_mask, other=0.0).to(
                tl.float32
            )
        else:
            vals = tl.load(a_base + j * N + rows, mask=row_mask, other=0.0).to(
                tl.float32
            )
        tl.store(aw_base + j * ROWS + rows, vals, mask=row_mask)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(aw_base + p * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip pairs already orthogonal to a relative tolerance of 1e-7.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold

                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q * ROWS + rows, new_aq, mask=row_mask)

    s_vals = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        col = tl.load(aw_base + j * ROWS + rows, mask=row_mask, other=0.0)
        norm = tl.sqrt(tl.sum(col * col))
        s_vals = tl.where(s_idx == j, norm, s_vals)

    # Descending rank with ties broken by the lower index.
    ranks = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        si = tl.sum(tl.where(s_idx == i, s_vals, 0.0))
        beats = ((si > s_vals) | ((si == s_vals) & (i < s_idx))) & (s_idx < K)
        ranks = ranks + beats.to(tl.int32)

    for j in tl.static_range(0, K):
        norm = tl.sum(tl.where(s_idx == j, s_vals, 0.0))
        rank = tl.sum(tl.where(s_idx == j, ranks, 0))
        tl.store(S + pid * K + rank, norm)
def _can_use_streaming_jacobi_kernel(input, some=True, compute_uv=True):
    """Whether the streaming Jacobi path applies: float32 CUDA, ``16 < min(m, n) <= 64``, ``max(m, n) <= 1024``."""
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    fits = 16 < short_side <= 64 and long_side <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and fits
def _can_use_gram_kernel(input, some=True, compute_uv=True):
    """Whether the generic Gram path applies: float32 CUDA, truthy flags, ``min(m, n) <= 1024``."""
    _, height, width = _svd_shape(input)
    fits = min(height, width) <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and fits
# Batched matmul C[b] = A[b] @ B[b] with a float32 accumulator
# (tl.dot with allow_tf32=False).
# Grid axis 0 enumerates the cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) output tiles,
# row-major over tiles. Grid axis 1 is the batch index.
# A and B are read through the explicit strides. C is written as a contiguous
# (batch, M, N) array.
@libentry()
@triton.jit
def _triton_bmm_kernel(
    A,
    B,
    C,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    tile = tl.program_id(0)
    batch = tl.program_id(1)
    # Decode the flat tile index into (tile_m, tile_n).
    tiles_n = tl.cdiv(N, BLOCK_N)
    tile_m = tile // tiles_n
    tile_n = tile - tile_m * tiles_n

    offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    a_base = A + batch * stride_ab
    b_base = B + batch * stride_bb
    # Reduce over K in BLOCK_K chunks; out-of-range elements load as 0.
    for k_start in range(0, K, BLOCK_K):
        k = k_start + offs_k
        a = tl.load(
            a_base + offs_m[:, None] * stride_am + k[None, :] * stride_ak,
            mask=(offs_m[:, None] < M) & (k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            b_base + k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=(k[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        C + batch * M * N + offs_m[:, None] * N + offs_n[None, :],
        acc,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )
466def _triton_bmm(left, right, out_shape):
467 batch, m, k = left.shape
468 right_batch, right_k, n = right.shape
469 assert batch == right_batch, "Batch dim mismatch"
470 assert k == right_k, "K dim mismatch"
471 out = torch.empty((batch, m, n), dtype=left.dtype, device=left.device)
472 block_m = 16 if m <= 16 else 32
473 block_n = 16 if n <= 16 else 32
474 block_k = 32
475 grid = (triton.cdiv(m, block_m) * triton.cdiv(n, block_n), batch)
476 with torch_device_fn.device(left.device):
477 _triton_bmm_kernel[grid](
478 left,
479 right,
480 out,
481 left.stride(0),
482 left.stride(1),
483 left.stride(2),
484 right.stride(0),
485 right.stride(1),
486 right.stride(2),
487 M=m,
488 N=n,
489 K=k,
490 BLOCK_M=block_m,
491 BLOCK_N=block_n,
492 BLOCK_K=block_k,
493 num_warps=1 if block_m == 16 and block_n == 16 else 4,
494 )
495 return out.reshape(out_shape)
# Tiled Gram-matrix build.
# GRAM[b] = A[b]^T A[b] when TALL, else A[b] A[b]^T, stored as a contiguous K x K
# matrix per batch element.
# Grid: (cdiv(K, BLOCK_I), cdiv(K, BLOCK_J), batch). Each program owns one
# BLOCK_I x BLOCK_J output tile and reduces over the ROWS dimension in BLOCK_R
# chunks, accumulating in float32 with TF32 disabled.
# A is addressed as a contiguous (batch, M, N) row-major array.
@libentry()
@triton.jit
def _gram_build_tiled_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_I: tl.constexpr,
    BLOCK_J: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    tile_i = tl.program_id(0)
    tile_j = tl.program_id(1)
    batch = tl.program_id(2)
    offs_i = tile_i * BLOCK_I + tl.arange(0, BLOCK_I)
    offs_j = tile_j * BLOCK_J + tl.arange(0, BLOCK_J)
    rows = tl.arange(0, BLOCK_R)
    i_mask = offs_i < K
    j_mask = offs_j < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS
        if TALL:
            # lhs[i, r] = A[r, i], rhs[r, j] = A[r, j]  ->  (A^T A)[i, j]
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + offs_i[:, None],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + offs_j[None, :],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r], rhs[r, j] = A[j, r]  ->  (A A^T)[i, j]
            lhs = tl.load(
                a_base + offs_i[:, None] * N + chunk_rows[None, :],
                mask=i_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + offs_j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & j_mask[None, :],
                other=0.0,
            ).to(tl.float32)
        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + offs_i[:, None] * K + offs_j[None, :],
        acc,
        mask=i_mask[:, None] & j_mask[None, :],
    )
# Upper-triangular Cholesky factor R with GRAM[b] = R^T R, one program per
# batch element. Only the diagonal and the upper triangle of GRAM are read.
# Entries of R below the diagonal are written as 0.
# Pivots are clamped from below to tol = max(1e-8 * max|diag(GRAM)|, 1e-20).
# STATUS[b] is set to 1 if any of the following occurs:
#   - a pre-clamp pivot is <= tol, NaN or infinite;
#   - any written row entry is NaN or has |value| >= FLT_MAX.
# Otherwise STATUS[b] is 0.
# NOTE(review): each row j reads R entries stored in earlier iterations by this
# same program (scalar loads of R[p, j]) without a tl.debug_barrier(). Confirm
# these global-memory writes are visible across the program's threads.
@libentry()
@triton.jit
def _cholesky_upper_kernel(
    GRAM,
    R,
    STATUS,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    cols = tl.arange(0, BLOCK_K)
    col_mask = cols < K
    base_g = GRAM + batch * K * K
    base_r = R + batch * K * K

    diag_vals = tl.load(
        base_g + cols * K + cols,
        mask=col_mask,
        other=0.0,
    ).to(tl.float32)
    max_diag = tl.max(tl.abs(diag_vals), axis=0)
    tol = tl.maximum(max_diag * 1.0e-8, 1.0e-20)
    status = tl.full((), 0, dtype=tl.int32)
    # Largest finite float32, used as an overflow/inf threshold.
    finite_limit = 3.4028234663852886e38

    j = 0
    while j < K:
        # Row j of R only has entries at columns >= j.
        row_mask = col_mask & (cols >= j)
        gram_row = tl.load(
            base_g + j * K + cols,
            mask=row_mask,
            other=0.0,
        ).to(tl.float32)
        diag = tl.load(base_g + j * K + j).to(tl.float32)

        # Subtract contributions of previously computed rows:
        # G[j, c] - sum_p R[p, j] * R[p, c]  and  G[j, j] - sum_p R[p, j]^2.
        p = 0
        while p < j:
            r_pj = tl.load(base_r + p * K + j).to(tl.float32)
            r_pcols = tl.load(
                base_r + p * K + cols,
                mask=row_mask,
                other=0.0,
            ).to(tl.float32)
            gram_row -= r_pj * r_pcols
            diag -= r_pj * r_pj
            p += 1

        # (diag == diag) is False for NaN.
        good_diag = (diag == diag) & (tl.abs(diag) < finite_limit) & (diag > tol)
        pivot = tl.sqrt(tl.maximum(diag, tol))
        r_vals = gram_row / pivot
        r_vals = tl.where(cols == j, pivot, r_vals)
        r_vals = tl.where(row_mask, r_vals, 0.0)
        bad_vals = tl.sum(
            (((r_vals != r_vals) | (tl.abs(r_vals) >= finite_limit)) & row_mask).to(
                tl.int32
            ),
            axis=0,
        )
        # Once set to 1, the status stays 1.
        status = tl.where(good_diag & (bad_vals == 0), status, 1)
        tl.store(base_r + j * K + cols, r_vals, mask=col_mask)
        j += 1

    tl.store(STATUS + batch, status)
def _tsqr_guard_fallback_svd(input):
    """Route an input rejected by the TSQR/Cholesky guard to a native Jacobi solver.

    With ``k = min(m, n)`` and the long side at most 1024:
      - ``16 <= k <= 64`` uses ``_cyclic_jacobi_svd``;
      - ``64 < k <= 512`` uses ``_blocked_jacobi_svd``.
    Any other shape raises via ``_unsupported_svd``.
    """
    _, height, width = _svd_shape(input)
    short_side, long_side = sorted((height, width))
    if long_side <= 1024:
        if 16 <= short_side <= 64:
            return _cyclic_jacobi_svd(input)
        if 64 < short_side <= 512:
            return _blocked_jacobi_svd(input)
    return _unsupported_svd(
        input,
        True,
        True,
        "TSQR/Cholesky guard could not find a native Jacobi fallback.",
    )
def _tsqr_cholesky_svd(input):
    """Thin SVD via the Cholesky factor of the Gram matrix.

    For each batch element:
      1. Build ``G = A^T A`` (``m >= n``) or ``G = A A^T`` (otherwise).
      2. Factor ``G = R^T R`` with ``_cholesky_upper_kernel``.
      3. Take the SVD of the small ``k x k`` factor ``R``.
      4. Recover the long factor by projecting A onto R's right singular
         vectors and dividing each column by its singular value.

    Args:
        input: tensor of shape ``(*batch, m, n)``.

    Returns:
        ``(U, S, V)`` with shapes ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)``, where ``k = min(m, n)``.

    NOTE(review): ``status`` (the per-matrix Cholesky breakdown flag) is written
    by the kernel but never read here. A clamped or failed factorization is
    therefore not detected, and ``_tsqr_guard_fallback_svd`` is never called
    from this function. The visible ``_can_use_tsqr_cholesky_kernel`` always
    returns False, presumably for this reason. Confirm before enabling the path.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    r = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    status = torch.empty((batch,), dtype=torch.int32, device=input.device)
    block_k = triton.next_power_of_2(k)
    block_tile = 32
    block_r = 64

    with torch_device_fn.device(input.device):
        _gram_build_tiled_kernel[
            (
                triton.cdiv(k, block_tile),
                triton.cdiv(k, block_tile),
                batch,
            )
        ](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_I=block_tile,
            BLOCK_J=block_tile,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _cholesky_upper_kernel[(batch,)](
            gram,
            r,
            status,
            K=k,
            BLOCK_K=block_k,
            num_warps=4,
        )

    # ``svd`` is not defined in the visible part of this file. It is presumably
    # this module's public entry point, called here on the k x k factor.
    _, s, basis = svd(r, some=True, compute_uv=True)
    basis = basis.reshape(batch, k, k)
    s = s.reshape(batch, k)

    # With A = Q R and R = Ur S Vr^T:  A Vr = (Q Ur) S, so U = A Vr / S, V = Vr.
    # The wide case applies the same argument to A^T.
    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        projected = u
        projected_rows = m
    else:
        u = basis
        v = _triton_bmm(a.transpose(1, 2).contiguous(), basis, (batch, n, k))
        projected = v
        projected_rows = n

    # Divide each projected column by its singular value (clamped at 1e-20).
    with torch_device_fn.device(input.device):
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1 if projected_rows <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
# Single-tile Gram-matrix build, one program per batch element.
# GRAM[b] = A[b]^T A[b] when TALL, else A[b] A[b]^T, stored as a contiguous
# K x K matrix. The whole K x K result lives in one BLOCK_K x BLOCK_K
# accumulator, so BLOCK_K must be >= K. Rows are reduced in BLOCK_R chunks in
# float32 with TF32 disabled. A is addressed as a contiguous (batch, M, N)
# row-major array.
@libentry()
@triton.jit
def _gram_build_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    i = tl.arange(0, BLOCK_K)
    j = tl.arange(0, BLOCK_K)
    rows = tl.arange(0, BLOCK_R)
    k_mask = i < K
    a_base = A + batch * M * N
    acc = tl.zeros((BLOCK_K, BLOCK_K), dtype=tl.float32)

    for row_start in range(0, ROWS, BLOCK_R):
        chunk_rows = row_start + rows
        row_mask = chunk_rows < ROWS

        if TALL:
            # lhs[i, r] = A[r, i], rhs[r, j] = A[r, j]  ->  (A^T A)[i, j]
            lhs = tl.load(
                a_base + chunk_rows[None, :] * N + i[:, None],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + chunk_rows[:, None] * N + j[None, :],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r], rhs[r, j] = A[j, r]  ->  (A A^T)[i, j]
            lhs = tl.load(
                a_base + i[:, None] * N + chunk_rows[None, :],
                mask=k_mask[:, None] & row_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                a_base + j[None, :] * N + chunk_rows[:, None],
                mask=row_mask[:, None] & (j[None, :] < K),
                other=0.0,
            ).to(tl.float32)

        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch * K * K + i[:, None] * K + j[None, :],
        acc,
        mask=k_mask[:, None] & (j[None, :] < K),
    )
# Cyclic two-sided Jacobi eigen-decomposition of a symmetric K x K matrix,
# one program per batch element. The matrix is held entirely in a
# BLOCK_K x BLOCK_K tile (BLOCK_K >= K).
# Each rotation J acts on indices (p, q) and updates G <- J^T G J and V <- V J.
# After SWEEPS sweeps (fixed count, no convergence test):
#   - EVALS[b] holds diag(G), unsorted;
#   - EVECS[b] holds V, stored row-major, whose columns are the eigenvectors.
# K and SWEEPS are runtime values (hence the while loops).
@libentry()
@triton.jit
def _gram_jacobi_sym_kernel(
    GRAM,
    EVECS,
    EVALS,
    K,
    SWEEPS,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    r = tl.arange(0, BLOCK_K)
    cidx = tl.arange(0, BLOCK_K)
    rr = r[:, None]
    cc = cidx[None, :]
    mask = (rr < K) & (cc < K)
    base = GRAM + batch * K * K

    g = tl.load(base + rr * K + cc, mask=mask, other=0.0).to(tl.float32)
    # V starts as the identity on the valid K x K region.
    v = tl.where((rr == cc) & mask, 1.0, 0.0)
    eps = 1.0e-20

    sweep = 0
    while sweep < SWEEPS:
        p = 0
        while p < K - 1:
            q = p + 1
            while q < K:
                # Extract g[p, p], g[q, q] and g[p, q] from the register tile.
                diag_p = tl.sum(tl.where((rr == p) & (cc == p), g, 0.0))
                diag_q = tl.sum(tl.where((rr == q) & (cc == q), g, 0.0))
                off = tl.sum(tl.where((rr == p) & (cc == q), g, 0.0))
                abs_off = tl.abs(off)
                # Skip rotations whose off-diagonal is already negligible.
                threshold = 1.0e-7 * tl.sqrt(tl.abs(diag_p * diag_q) + eps)
                active = abs_off > threshold
                safe_off = tl.where(active, off, 1.0)
                tau = (diag_q - diag_p) / (2.0 * safe_off)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                crot = tl.rsqrt(1.0 + t * t)
                srot = t * crot
                crot = tl.where(active, crot, 1.0)
                srot = tl.where(active, srot, 0.0)

                # All four slices are taken from the pre-rotation g.
                col_p = tl.sum(tl.where(cc == p, g, 0.0), axis=1)
                col_q = tl.sum(tl.where(cc == q, g, 0.0), axis=1)
                row_p = tl.sum(tl.where(rr == p, g, 0.0), axis=0)
                row_q = tl.sum(tl.where(rr == q, g, 0.0), axis=0)

                # Column update (G J), then row update (J^T G) for rows p, q.
                # The 2x2 (p, q) block is then set explicitly below.
                new_col_p = crot * col_p - srot * col_q
                new_col_q = srot * col_p + crot * col_q
                new_row_p = crot * row_p - srot * row_q
                new_row_q = srot * row_p + crot * row_q
                g = tl.where(cc == p, new_col_p[:, None], g)
                g = tl.where(cc == q, new_col_q[:, None], g)
                g = tl.where(rr == p, new_row_p[None, :], g)
                g = tl.where(rr == q, new_row_q[None, :], g)

                new_pp = (
                    crot * crot * diag_p
                    - 2.0 * crot * srot * off
                    + srot * srot * diag_q
                )
                new_qq = (
                    srot * srot * diag_p
                    + 2.0 * crot * srot * off
                    + crot * crot * diag_q
                )
                g = tl.where((rr == p) & (cc == p), new_pp, g)
                g = tl.where((rr == q) & (cc == q), new_qq, g)
                g = tl.where(((rr == p) & (cc == q)) | ((rr == q) & (cc == p)), 0.0, g)

                # Accumulate eigenvectors: V <- V J.
                vec_p = tl.sum(tl.where(cc == p, v, 0.0), axis=1)
                vec_q = tl.sum(tl.where(cc == q, v, 0.0), axis=1)
                new_vec_p = crot * vec_p - srot * vec_q
                new_vec_q = srot * vec_p + crot * vec_q
                v = tl.where(cc == p, new_vec_p[:, None], v)
                v = tl.where(cc == q, new_vec_q[:, None], v)
                q += 1
            p += 1
        sweep += 1

    diag = tl.sum(tl.where(rr == cc, g, 0.0), axis=1)
    tl.store(EVALS + batch * K + r, diag, mask=r < K)
    tl.store(EVECS + batch * K * K + rr * K + cc, v, mask=mask)
# Sort eigenpairs into descending order and convert them to singular values.
# Grid (batch, K): program (b, col) handles one eigenpair. It first clamps
# eigenvalue `col` to >= 0 and computes its rank: the number of eigenvalues
# that are larger, or equal with a lower index. It then writes:
#   S[b, rank]        = sqrt(clamped eigenvalue)
#   BASIS[b, :, rank] = EVECS[b, :, col]
# All arrays are addressed as contiguous (batch, K[, K]) buffers.
@libentry()
@triton.jit
def _gram_sort_basis_kernel(
    EVALS,
    EVECS,
    BASIS,
    S,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_K)
    row_mask = rows < K
    # Negative eigenvalues (round-off on a PSD Gram matrix) are clamped to 0.
    eval_col = tl.maximum(tl.load(EVALS + batch * K + col), 0.0)
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        eval_other = tl.maximum(tl.load(EVALS + batch * K + other), 0.0)
        rank += (
            (eval_other > eval_col) | ((eval_other == eval_col) & (other < col))
        ).to(tl.int32)

    vec = tl.load(
        EVECS + batch * K * K + rows * K + col,
        mask=row_mask,
        other=0.0,
    )
    tl.store(S + batch * K + rank, tl.sqrt(eval_col))
    tl.store(
        BASIS + batch * K * K + rows * K + rank,
        vec,
        mask=row_mask,
    )
# In-place column scaling: Q[b, :, col] /= max(S[b, col], 1e-20).
# Grid (batch, K), one program per column. Q is addressed as a contiguous
# (batch, ROWS, K) array and S as (batch, K). BLOCK_R must be >= ROWS.
@libentry()
@triton.jit
def _normalize_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-20
    sval = tl.load(S + batch * K + col)
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals = vals / tl.maximum(sval, eps)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)
# Normalize column `col` of Q[b] in place and overwrite S[b, col] with the
# column's measured Euclidean norm (computed in float32).
# Columns with norm <= 1e-20 are replaced by the unit vector e_col (1 at row
# `col`, 0 elsewhere). NOTE(review): e_col is not orthogonalized against the
# other columns here.
# Grid (batch, K). Q is addressed as a contiguous (batch, ROWS, K) array and
# BLOCK_R must be >= ROWS.
@libentry()
@triton.jit
def _renorm_projection_update_s_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    vals = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals_f32 = vals.to(tl.float32)
    norm = tl.sqrt(tl.sum(vals_f32 * vals_f32, axis=0))
    # NOTE(review): 1.0e-40 and norm * norm for norm just above 1e-20 lie in
    # float32's subnormal range. On backends that flush subnormals to zero this
    # clamp can become 0 (rsqrt -> inf) for norms slightly above the 1e-20
    # cutoff used below. Confirm behavior.
    inv_norm = tl.rsqrt(tl.maximum(norm * norm, 1.0e-40))
    basis = tl.where(rows == col, 1.0, 0.0)
    vals = tl.where(norm <= 1.0e-20, basis, vals * inv_norm)
    tl.store(S + batch * K + col, norm)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)
# Replace column `col` of Q[b] with the unit vector e_col (1 at row `col`,
# 0 elsewhere) when S[b, col] <= 1e-12. Otherwise the column is rewritten
# unchanged.
# Grid (batch, K). Q is addressed as a contiguous (batch, ROWS, K) array.
# NOTE(review): e_col is not orthogonalized against the other columns here.
@libentry()
@triton.jit
def _complete_zero_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    col = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    mask = rows < ROWS
    eps = 1.0e-12
    sval = tl.load(S + batch * K + col)
    basis = tl.where(rows == col, 1.0, 0.0)
    old = tl.load(Q + batch * ROWS * K + rows * K + col, mask=mask, other=0.0)
    vals = tl.where(sval <= eps, basis, old)
    tl.store(Q + batch * ROWS * K + rows * K + col, vals, mask=mask)
def _gram_jacobi_svd(input):
    """Thin SVD via the eigen-decomposition of the Gram matrix.

    For each batch element:
      1. Build ``G = A^T A`` (``m >= n``) or ``G = A A^T`` (otherwise).
      2. Diagonalize ``G`` with a fixed number of cyclic Jacobi sweeps.
      3. Sort the eigenpairs in descending order and take
         ``S = sqrt(max(eigenvalue, 0))``.
      4. Use the eigenvectors as V (tall) or U (wide). Obtain the other factor
         by projecting A onto them and renormalizing each column; this also
         replaces ``S`` with the measured column norms.

    Forming ``G`` squares the condition number, so small singular values lose
    relative accuracy compared with a one-sided Jacobi on A.

    Args:
        input: tensor of shape ``(*batch, m, n)``. NOTE(review): ``s`` is
            allocated with ``input.dtype`` while the intermediates are float32.
            The visible ``_can_use_gram_*`` predicates require float32 input,
            which presumably guarantees they match. Confirm for other callers.

    Returns:
        ``(U, S, V)`` with shapes ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)``, where ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    eigvecs = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    evals = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    basis = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    block_k = triton.next_power_of_2(k)
    # Row-chunk size for the Gram build: capped at 64 when k > 32, else 128.
    block_r = min(triton.next_power_of_2(rows), 64 if k > 32 else 128)
    # Fixed Jacobi sweep count: 12 for k <= 17, otherwise 10.
    sweeps = 12 if k <= 17 else 10

    with torch_device_fn.device(input.device):
        _gram_build_kernel[(batch,)](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_K=block_k,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _gram_jacobi_sym_kernel[(batch,)](
            gram,
            eigvecs,
            evals,
            k,
            sweeps,
            BLOCK_K=block_k,
            num_warps=4,
        )
    # basis columns become eigenvectors in descending eigenvalue order, and
    # s[:, i] = sqrt(max(eigenvalue_i, 0)).
    with torch_device_fn.device(input.device):
        _gram_sort_basis_kernel[(batch, k)](
            evals,
            eigvecs,
            basis,
            s,
            K=k,
            BLOCK_K=block_k,
            num_warps=1,
        )

    # Project A onto the basis: U = A V (tall) or V = A^T U (wide).
    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        proj_rows = m
    else:
        a_t = a.transpose(1, 2).contiguous()
        v = _triton_bmm(a_t, basis, (batch, n, k))
        u = basis
        proj_rows = n

    with torch_device_fn.device(input.device):
        # Normalize the projected columns and overwrite s with their norms.
        # Columns with norm <= 1e-20 become unit vectors.
        _renorm_projection_update_s_kernel[(batch, k)](
            u if tall else v,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(proj_rows),
            num_warps=1 if proj_rows <= 64 else 4,
        )
        # Columns whose singular value is <= 1e-12 are replaced by unit vectors.
        _complete_zero_projection_kernel[(batch, k)](
            u if tall else v,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(proj_rows),
            num_warps=1 if proj_rows <= 64 else 4,
        )
        # _thin_reorthogonalize_kernel is defined outside the visible part of
        # this file. Presumably it re-orthogonalizes the projected columns.
        # Confirm.
        if k <= _GRAM_TALL_WIDE_MAX_K:
            _thin_reorthogonalize_kernel[(batch,)](
                u if tall else v,
                ROWS=proj_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(proj_rows),
                num_warps=1 if proj_rows <= 64 else 4,
            )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
# One one-sided Jacobi rotation applied to a batch of vector pairs.
# ap/aq are 2-D tiles holding one vector per row; _small4_square_svd_kernel
# passes (BLOCK_B, 4) tiles. vp/vq are the matching accumulated basis vectors.
# Each tile row gets its own (c, s), computed from dot products along axis 1.
# Rows whose vectors are already orthogonal to a relative tolerance of 1e-7 get
# the identity rotation. Returns the rotated (ap, aq, vp, vq).
@triton.jit
def _rotate_pair_4(ap, aq, vp, vq):
    eps = 1.0e-20
    alpha = tl.sum(ap * ap, axis=1)
    beta = tl.sum(aq * aq, axis=1)
    gamma = tl.sum(ap * aq, axis=1)
    abs_gamma = tl.abs(gamma)
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    # safe_gamma avoids a division by zero on inactive rows.
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)
    new_ap = c[:, None] * ap - s_rot[:, None] * aq
    new_aq = s_rot[:, None] * ap + c[:, None] * aq
    new_vp = c[:, None] * vp - s_rot[:, None] * vq
    new_vq = s_rot[:, None] * vp + c[:, None] * vq
    return new_ap, new_aq, new_vp, new_vq
# Batched one-sided Jacobi SVD of 4x4 matrices, BLOCK_B matrices per program.
# A, U and V are addressed as contiguous (BATCH, 4, 4) arrays
# (offset b*16 + row*4 + col); S is addressed as (BATCH, 4).
# c0..c3 hold the four columns of each matrix as (BLOCK_B, 4) tiles.
# v0..v3 start as identity columns and accumulate the rotations.
# Each of the SWEEPS fixed sweeps visits all 6 column pairs.
# Singular values are emitted in descending order, ties broken by column index.
# A column whose norm is ~0 is divided by max(s, 1e-20), so its U column is
# ~0 rather than a unit vector.
@libentry()
@triton.jit
def _small4_square_svd_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    BLOCK_B: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, 4)
    bb = b[:, None]
    rr = r[None, :]
    mask = b < BATCH
    full_mask = (bb < BATCH) & (rr < 4)
    base = A + bb * 16 + rr * 4

    # cj[b, row] = A[b, row, j]
    c0 = tl.load(base, mask=full_mask, other=0.0).to(tl.float32)
    c1 = tl.load(base + 1, mask=full_mask, other=0.0).to(tl.float32)
    c2 = tl.load(base + 2, mask=full_mask, other=0.0).to(tl.float32)
    c3 = tl.load(base + 3, mask=full_mask, other=0.0).to(tl.float32)

    v0 = tl.where(rr == 0, 1.0, 0.0)
    v1 = tl.where(rr == 1, 1.0, 0.0)
    v2 = tl.where(rr == 2, 1.0, 0.0)
    v3 = tl.where(rr == 3, 1.0, 0.0)

    for _ in tl.static_range(0, SWEEPS):
        c0, c1, v0, v1 = _rotate_pair_4(c0, c1, v0, v1)
        c0, c2, v0, v2 = _rotate_pair_4(c0, c2, v0, v2)
        c0, c3, v0, v3 = _rotate_pair_4(c0, c3, v0, v3)
        c1, c2, v1, v2 = _rotate_pair_4(c1, c2, v1, v2)
        c1, c3, v1, v3 = _rotate_pair_4(c1, c3, v1, v3)
        c2, c3, v2, v3 = _rotate_pair_4(c2, c3, v2, v3)

    s0 = tl.sqrt(tl.sum(c0 * c0, axis=1))
    s1 = tl.sqrt(tl.sum(c1 * c1, axis=1))
    s2 = tl.sqrt(tl.sum(c2 * c2, axis=1))
    s3 = tl.sqrt(tl.sum(c3 * c3, axis=1))
    # Descending rank of column i counts:
    #   - lower-index columns with s_j >= s_i, plus
    #   - higher-index columns with s_j > s_i.
    r0 = (s1 > s0).to(tl.int32) + (s2 > s0).to(tl.int32) + (s3 > s0).to(tl.int32)
    r1 = ((s0 >= s1).to(tl.int32)) + (s2 > s1).to(tl.int32) + (s3 > s1).to(tl.int32)
    r2 = ((s0 >= s2).to(tl.int32)) + ((s1 >= s2).to(tl.int32)) + (s3 > s2).to(tl.int32)
    r3 = (
        ((s0 >= s3).to(tl.int32))
        + ((s1 >= s3).to(tl.int32))
        + ((s2 >= s3).to(tl.int32))
    )
    eps = 1.0e-20

    tl.store(S + b * 4 + r0, s0, mask=mask)
    tl.store(S + b * 4 + r1, s1, mask=mask)
    tl.store(S + b * 4 + r2, s2, mask=mask)
    tl.store(S + b * 4 + r3, s3, mask=mask)

    # U[b, :, rank_j] = c_j / max(s_j, eps)
    tl.store(
        U + bb * 16 + rr * 4 + r0[:, None],
        c0 / tl.maximum(s0[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r1[:, None],
        c1 / tl.maximum(s1[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r2[:, None],
        c2 / tl.maximum(s2[:, None], eps),
        mask=full_mask,
    )
    tl.store(
        U + bb * 16 + rr * 4 + r3[:, None],
        c3 / tl.maximum(s3[:, None], eps),
        mask=full_mask,
    )

    # V[b, :, rank_j] = v_j
    tl.store(V + bb * 16 + rr * 4 + r0[:, None], v0, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r1[:, None], v1, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r2[:, None], v2, mask=full_mask)
    tl.store(V + bb * 16 + rr * 4 + r3[:, None], v3, mask=full_mask)
# Closed-form thin SVD for a tile of BLOCK_B rank-2 problems.
#
# Each program handles BLOCK_B consecutive entries of a contiguous
# (BATCH, M, N) input whose smaller dimension is 2.  The two long vectors x and
# y (the two columns when TALL, the two rows otherwise) are reduced to the 2x2
# Gram matrix [[aa, ab], [ab, bbv]]; its eigenvalues l0 >= l1 are the squared
# singular values and its eigenvectors form the 2x2 factor.  The long factor is
# recovered as (x * vx + y * vy) / s.
#
# Outputs (row-major): S is (BATCH, 2), U is (BATCH, M, 2), V is (BATCH, N, 2).
# Degenerate long-factor columns (s0 == 0, or s1 <= 5e-4 * s0) are replaced by
# the same orthonormal completion that _rank2_svd_kernel uses, so zero and
# rank-1 inputs no longer produce all-zero singular-vector columns.
@libentry()
@triton.jit
def _rank2_svd_tiny_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, BLOCK_R)
    bb = b[:, None]
    rr = r[None, :]
    bmask = b < BATCH
    eps = 1.0e-20

    # Load the two long vectors of every matrix in the tile as (BLOCK_B, BLOCK_R).
    if TALL:
        mask = (bb < BATCH) & (rr < M)
        base = A + bb * M * N + rr * N
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = (bb < BATCH) & (rr < N)
        base = A + bb * M * N + rr
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N, mask=mask, other=0.0).to(tl.float32)

    # Gram matrix entries and its eigenvalues.
    aa = tl.sum(x * x, axis=1)
    bbv = tl.sum(y * y, axis=1)
    ab = tl.sum(x * y, axis=1)
    diff = aa - bbv
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    l0 = tl.maximum(0.0, 0.5 * (aa + bbv + root))
    # l1 = det / l0 avoids the cancellation in 0.5 * (aa + bbv - root).
    det = tl.maximum(0.0, aa * bbv - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0.  (ab, l0 - aa) and (l0 - bbv, ab) are parallel, but
    # l0 - aa = 0.5 * (root - diff) cancels when aa >= bbv and |ab| << diff
    # (in fp32 it can round to exactly 0).  Use 0.5 * (root + diff) = l0 - bbv
    # in that regime, multiplied by sign(ab) to keep the (ab, l0 - aa)
    # orientation.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bbv
    ab_sign = tl.where(ab >= 0.0, 1.0, -1.0)
    gap_a = 0.5 * (root - diff)
    gap_b = 0.5 * (root + diff)
    vx0 = tl.where(aa_ge_bb, ab_sign * gap_b, ab)
    vy0 = tl.where(aa_ge_bb, ab_abs, gap_a)
    # A (numerically) diagonal Gram matrix has the coordinate axes as eigenvectors.
    vx0 = tl.where(ab_abs > eps, vx0, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, vy0, tl.where(aa_ge_bb, 0.0, 1.0))
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0 + eps)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    vx1 = -vy0
    vy1 = vx0

    tl.store(S + b * 2, s0, mask=bmask)
    tl.store(S + b * 2 + 1, s1, mask=bmask)
    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    # Long-factor columns A v / s.  If s0 vanishes, use e0.  If s1 is tiny
    # relative to s0, use a unit vector orthogonal to the first column, built
    # from e0 or e1 (whichever is less aligned with it).
    long0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
    long1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]
    basis0 = tl.where(rr == 0, 1.0, 0.0)
    basis1 = tl.where(rr == 1, 1.0, 0.0)
    s0_ok = s0 > eps
    long0 = tl.where(s0_ok[:, None], long0, basis0)
    long0_first = tl.sum(tl.where(rr == 0, long0, 0.0), axis=1)
    use_e0 = tl.abs(long0_first) < 0.70710678
    anchor = tl.where(use_e0[:, None], basis0, basis1)
    proj = tl.sum(anchor * long0, axis=1)
    fallback = anchor - proj[:, None] * long0
    fallback_norm = tl.sum(fallback * fallback, axis=1)
    fallback = fallback * tl.rsqrt(fallback_norm + eps)[:, None]
    s1_ok = s1 > s0 * 5.0e-4
    long1 = tl.where(s1_ok[:, None], long1, fallback)

    if TALL:
        ubase = U + bb * M * 2 + rr * 2
        tl.store(ubase, long0, mask=mask)
        tl.store(ubase + 1, long1, mask=mask)
        vbase = V + b * 4
        tl.store(vbase, vx0, mask=bmask)
        tl.store(vbase + 1, vx1, mask=bmask)
        tl.store(vbase + 2, vy0, mask=bmask)
        tl.store(vbase + 3, vy1, mask=bmask)
    else:
        ubase = U + b * 4
        tl.store(ubase, vx0, mask=bmask)
        tl.store(ubase + 1, vx1, mask=bmask)
        tl.store(ubase + 2, vy0, mask=bmask)
        tl.store(ubase + 3, vy1, mask=bmask)
        vbase = V + bb * N * 2 + rr * 2
        tl.store(vbase, long0, mask=mask)
        tl.store(vbase + 1, long1, mask=mask)
# Singular values only, for a tile of BLOCK_B rank-2 matrices.
#
# Program `pid` covers batch entries [pid * BLOCK_B, (pid + 1) * BLOCK_B) of a
# contiguous (BATCH, M, N) input with min(M, N) == 2.  The two long vectors
# (columns when TALL, rows otherwise) give a 2x2 Gram matrix whose eigenvalues
# are the squared singular values.  S is written as (BATCH, 2), largest first.
@libentry()
@triton.jit
def _rank2_svals_tiny_kernel(
    A,
    S,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch_ids = tl.program_id(0) * BLOCK_B + tl.arange(0, BLOCK_B)
    lanes = tl.arange(0, BLOCK_R)
    batch_col = batch_ids[:, None]
    lane_row = lanes[None, :]
    batch_ok = batch_ids < BATCH

    if TALL:
        tile_mask = (batch_col < BATCH) & (lane_row < M)
        first_ptr = A + batch_col * M * N + lane_row * N
        second_ptr = first_ptr + 1
    else:
        tile_mask = (batch_col < BATCH) & (lane_row < N)
        first_ptr = A + batch_col * M * N + lane_row
        second_ptr = first_ptr + N
    x = tl.load(first_ptr, mask=tile_mask, other=0.0).to(tl.float32)
    y = tl.load(second_ptr, mask=tile_mask, other=0.0).to(tl.float32)

    gram_xx = tl.sum(x * x, axis=1)
    gram_yy = tl.sum(y * y, axis=1)
    gram_xy = tl.sum(x * y, axis=1)
    gap = gram_xx - gram_yy
    spread = tl.sqrt(gap * gap + 4.0 * gram_xy * gram_xy)
    lam_hi = tl.maximum(0.0, 0.5 * (gram_xx + gram_yy + spread))
    # The small eigenvalue comes from det / lam_hi to dodge cancellation.
    gram_det = tl.maximum(0.0, gram_xx * gram_yy - gram_xy * gram_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    out_ptr = S + batch_ids * 2
    tl.store(out_ptr, tl.sqrt(lam_hi), mask=batch_ok)
    tl.store(out_ptr + 1, tl.sqrt(lam_lo), mask=batch_ok)
# Singular values only, one rank-2 matrix per program.
#
# Program `pid` reads matrix `pid` of a contiguous (batch, M, N) input with
# min(M, N) == 2, forms the 2x2 Gram matrix of its two long vectors (columns
# when TALL, rows otherwise), and writes sqrt of its eigenvalues, largest
# first, to S[pid, 0:2].
@libentry()
@triton.jit
def _rank2_svals_kernel(
    A,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_R)
    matrix = A + pid * M * N

    if TALL:
        valid = lanes < M
        x_ptr = matrix + lanes * N
        y_ptr = x_ptr + 1
    else:
        valid = lanes < N
        x_ptr = matrix + lanes
        y_ptr = x_ptr + N
    x = tl.load(x_ptr, mask=valid, other=0.0).to(tl.float32)
    y = tl.load(y_ptr, mask=valid, other=0.0).to(tl.float32)

    gram_xx = tl.sum(x * x)
    gram_yy = tl.sum(y * y)
    gram_xy = tl.sum(x * y)
    gap = gram_xx - gram_yy
    spread = tl.sqrt(gap * gap + 4.0 * gram_xy * gram_xy)
    lam_hi = tl.maximum(0.0, 0.5 * (gram_xx + gram_yy + spread))
    # The small eigenvalue comes from det / lam_hi to dodge cancellation.
    gram_det = tl.maximum(0.0, gram_xx * gram_yy - gram_xy * gram_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    out_ptr = S + pid * 2
    tl.store(out_ptr, tl.sqrt(lam_hi))
    tl.store(out_ptr + 1, tl.sqrt(lam_lo))
# Closed-form thin SVD of one rank-2 matrix per program.
#
# Program `pid` reads matrix `pid` of a contiguous (batch, M, N) input with
# min(M, N) == 2.  The two long vectors x, y (columns when TALL, rows
# otherwise) give the 2x2 Gram matrix [[aa, ab], [ab, bb]].  Its eigenvalues
# are the squared singular values and its eigenvectors form the 2x2 factor.
# The long factor is (x * vx + y * vy) / s, with an orthonormal fallback for
# vanishing or tiny singular values.
# Outputs (row-major): S is (batch, 2), U is (batch, M, 2), V is (batch, N, 2).
@libentry()
@triton.jit
def _rank2_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_R)
    eps = 1.0e-20

    if TALL:
        mask = offs < M
        base = A + pid * M * N
        x = tl.load(base + offs * N, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + offs * N + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = offs < N
        base = A + pid * M * N
        x = tl.load(base + offs, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N + offs, mask=mask, other=0.0).to(tl.float32)

    aa = tl.sum(x * x)
    bb = tl.sum(y * y)
    ab = tl.sum(x * y)
    diff = aa - bb
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    l0 = tl.maximum(0.0, 0.5 * (aa + bb + root))
    # l1 = det / l0 avoids the cancellation in 0.5 * (aa + bb - root).
    det = tl.maximum(0.0, aa * bb - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0.  (ab, l0 - aa) and (l0 - bb, ab) are parallel, but
    # l0 - aa = 0.5 * (root - diff) cancels when aa >= bb and |ab| << diff
    # (in fp32 it can round to exactly 0 and lose orthogonality).  Use
    # 0.5 * (root + diff) = l0 - bb in that regime, multiplied by sign(ab) to
    # keep the (ab, l0 - aa) orientation.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bb
    ab_sign = tl.where(ab >= 0.0, 1.0, -1.0)
    gap_a = 0.5 * (root - diff)
    gap_b = 0.5 * (root + diff)
    vx0 = tl.where(aa_ge_bb, ab_sign * gap_b, ab)
    vy0 = tl.where(aa_ge_bb, ab_abs, gap_a)
    # A (numerically) diagonal Gram matrix has the coordinate axes as eigenvectors.
    vx0 = tl.where(ab_abs > eps, vx0, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, vy0, tl.where(aa_ge_bb, 0.0, 1.0))
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0 + eps)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    vx1 = -vy0
    vy1 = vx0

    sbase = S + pid * 2
    tl.store(sbase, s0)
    tl.store(sbase + 1, s1)

    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        ubase = U + pid * M * 2
        u0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        u0 = tl.where(s0 > eps, u0, basis0)

        # When s1 is tiny relative to s0, replace u1 by a unit vector orthogonal
        # to u0, built from e0 or e1 (whichever is less aligned with u0).
        u1 = (x * vx1 + y * vy1) * inv_s1
        u0_first = tl.sum(tl.where(offs == 0, u0, 0.0))
        anchor = tl.where(tl.abs(u0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * u0)
        fallback_u1 = anchor - dot * u0
        fallback_norm = tl.sum(fallback_u1 * fallback_u1)
        fallback_u1 = fallback_u1 * tl.rsqrt(fallback_norm + eps)
        u1 = tl.where(s1 > s0 * 5.0e-4, u1, fallback_u1)
        tl.store(ubase + offs * 2, u0, mask=mask)
        tl.store(ubase + offs * 2 + 1, u1, mask=mask)

        vbase = V + pid * 4
        tl.store(vbase, vx0)
        tl.store(vbase + 1, vx1)
        tl.store(vbase + 2, vy0)
        tl.store(vbase + 3, vy1)
    else:
        ubase = U + pid * 4
        tl.store(ubase, vx0)
        tl.store(ubase + 1, vx1)
        tl.store(ubase + 2, vy0)
        tl.store(ubase + 3, vy1)

        vbase = V + pid * N * 2
        v0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        v0 = tl.where(s0 > eps, v0, basis0)

        # Same orthonormal completion as the TALL branch, applied to V.
        v1 = (x * vx1 + y * vy1) * inv_s1
        v0_first = tl.sum(tl.where(offs == 0, v0, 0.0))
        anchor = tl.where(tl.abs(v0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * v0)
        fallback_v1 = anchor - dot * v0
        fallback_norm = tl.sum(fallback_v1 * fallback_v1)
        fallback_v1 = fallback_v1 * tl.rsqrt(fallback_norm + eps)
        v1 = tl.where(s1 > s0 * 5.0e-4, v1, fallback_v1)
        tl.store(vbase + offs * 2, v0, mask=mask)
        tl.store(vbase + offs * 2 + 1, v1, mask=mask)
def _rank2_svd(input):
    """Compute the thin SVD of a (batched) matrix whose smaller side is 2.

    The input is flattened to ``(batch, m, n)`` through ``_svd_shape``.
    Batches of at least 16 matrices with ``max(m, n) <= 16`` are processed
    ``block_b`` matrices per program by ``_rank2_svd_tiny_kernel``.  Every
    other input launches one ``_rank2_svd_kernel`` program per matrix.

    Returns:
        ``(U, S, V)`` in ``input.dtype`` with shapes ``(*batch, m, 2)``,
        ``(*batch, 2)`` and ``(*batch, n, 2)``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    flat = input.contiguous().reshape(batch, m, n)
    out_u = torch.empty((batch, m, 2), dtype=input.dtype, device=input.device)
    out_s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)
    out_v = torch.empty((batch, n, 2), dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if longest > 16 or batch < 16:
            _rank2_svd_kernel[(batch,)](
                flat,
                out_u,
                out_s,
                out_v,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=4 if block_r > 64 else 1,
            )
        else:
            # Matrices per program, tuned per size class.
            if longest <= 2:
                block_b = 8
            elif longest != 16:
                block_b = 16
            elif tall:
                block_b = 2
            else:
                block_b = 8
            _rank2_svd_tiny_kernel[(triton.cdiv(batch, block_b),)](
                flat,
                out_u,
                out_s,
                out_v,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    lead = input.shape[:-2]
    return (
        out_u.reshape(*lead, m, 2),
        out_s.reshape(*lead, 2),
        out_v.reshape(*lead, n, 2),
    )
def _rank2_singular_values(input):
    """Return both singular values of each matrix with ``min(m, n) == 2``.

    The input is flattened to ``(batch, m, n)`` through ``_svd_shape``.
    Batches of at least 16 matrices with ``max(m, n) <= 16`` use the tiled
    ``_rank2_svals_tiny_kernel``.  Other inputs launch one
    ``_rank2_svals_kernel`` program per matrix.

    Returns:
        Tensor of shape ``(*batch, 2)`` in ``input.dtype``.
    """
    batch, m, n = _svd_shape(input)
    tall = m >= n
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    flat = input.contiguous().reshape(batch, m, n)
    out_s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        if longest > 16 or batch < 16:
            _rank2_svals_kernel[(batch,)](
                flat,
                out_s,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_R=block_r,
                num_warps=4 if block_r > 64 else 1,
            )
        else:
            # Matrices per program, tuned per size class.
            if longest <= 2:
                block_b = 8
            elif longest != 16:
                block_b = 16
            elif tall:
                block_b = 2
            else:
                block_b = 8
            _rank2_svals_tiny_kernel[(triton.cdiv(batch, block_b),)](
                flat,
                out_s,
                BATCH=batch,
                M=m,
                N=n,
                TALL=tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    return out_s.reshape(*input.shape[:-2], 2)
def _small_jacobi_singular_values(input):
    """Return the singular values of small (batched) matrices.

    Launches ``_small_jacobi_svals_kernel`` with one program per matrix, a
    float32 scratch buffer of shape ``(batch, k, rows)`` and a fixed sweep
    count: 3 when ``k <= 4``, otherwise 5.

    Returns:
        Tensor of shape ``(*batch, min(m, n))`` in ``input.dtype``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    device = input.device
    flat = input.contiguous().reshape(batch, m, n)
    scratch = torch.empty((batch, k, rows), dtype=torch.float32, device=device)
    out_s = torch.empty((batch, k), dtype=input.dtype, device=device)
    block_r = triton.next_power_of_2(rows)

    with torch_device_fn.device(device):
        _small_jacobi_svals_kernel[(batch,)](
            flat,
            scratch,
            out_s,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=triton.next_power_of_2(k),
            SWEEPS=5 if k > 4 else 3,
            num_warps=4 if block_r > 64 else 1,
        )
    return out_s.reshape(*input.shape[:-2], k)
def _small_jacobi_svd(input):
    """Compute the thin SVD of small (batched) matrices.

    Launches ``_small_jacobi_svd_kernel`` with one program per matrix.  Two
    float32 scratch buffers are passed, shaped ``(batch, k, rows)`` and
    ``(batch, k, k)``.  The sweep count is 3 when ``k <= 4``, otherwise 5.

    Returns:
        ``(U, S, V)`` in ``input.dtype`` with shapes ``(*batch, m, k)``,
        ``(*batch, k)`` and ``(*batch, n, k)``, where ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    device = input.device
    dtype = input.dtype
    flat = input.contiguous().reshape(batch, m, n)
    a_scratch = torch.empty((batch, k, rows), dtype=torch.float32, device=device)
    v_scratch = torch.empty((batch, k, k), dtype=torch.float32, device=device)
    out_u = torch.empty((batch, m, k), dtype=dtype, device=device)
    out_s = torch.empty((batch, k), dtype=dtype, device=device)
    out_v = torch.empty((batch, n, k), dtype=dtype, device=device)
    block_r = triton.next_power_of_2(rows)

    with torch_device_fn.device(device):
        _small_jacobi_svd_kernel[(batch,)](
            flat,
            a_scratch,
            v_scratch,
            out_u,
            out_s,
            out_v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=triton.next_power_of_2(k),
            SWEEPS=5 if k > 4 else 3,
            num_warps=4 if block_r > 64 else 1,
        )

    lead = input.shape[:-2]
    return (
        out_u.reshape(*lead, m, k),
        out_s.reshape(*lead, k),
        out_v.reshape(*lead, n, k),
    )
# Copy one long vector of A into the float32 Jacobi workspace.
#
# Program (batch, col) reads column `col` of A (when TALL) or row `col` (when
# not TALL) of the contiguous (M, N) matrix `batch`.  It writes the values,
# cast to float32, as row `col` of the (K, ROWS) slab A_WORK[batch].
@libentry()
@triton.jit
def _cyclic_jacobi_init_a_kernel(
    A,
    A_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    idx = tl.arange(0, BLOCK_R)
    in_range = idx < ROWS
    matrix = A + pid_batch * M * N

    if TALL:
        src = matrix + idx * N + pid_col
    else:
        src = matrix + pid_col * N + idx
    vector = tl.load(src, mask=in_range, other=0.0).to(tl.float32)

    dst = A_WORK + pid_batch * K * ROWS + pid_col * ROWS + idx
    tl.store(dst, vector, mask=in_range)
# Seed the Jacobi workspaces for one (batch, col) pair.
#
# Copies column `col` (TALL) or row `col` (not TALL) of matrix `batch` into
# row `col` of A_WORK[batch] as float32.  It also writes row `col` of the
# K x K identity into V_WORK[batch], which starts the rotation accumulator.
@libentry()
@triton.jit
def _cyclic_jacobi_init_kernel(
    A,
    A_WORK,
    V_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    idx = tl.arange(0, BLOCK_R)
    k_idx = tl.arange(0, BLOCK_K)
    in_range = idx < ROWS
    k_in_range = k_idx < K
    matrix = A + pid_batch * M * N

    if TALL:
        src = matrix + idx * N + pid_col
    else:
        src = matrix + pid_col * N + idx
    vector = tl.load(src, mask=in_range, other=0.0).to(tl.float32)
    tl.store(A_WORK + pid_batch * K * ROWS + pid_col * ROWS + idx, vector, mask=in_range)

    identity_row = tl.where(k_idx == pid_col, 1.0, 0.0)
    tl.store(V_WORK + pid_batch * K * K + pid_col * K + k_idx, identity_row, mask=k_in_range)
# One round-robin step of one-sided (Hestenes) Jacobi for every matrix.
#
# _cyclic_jacobi_svd launches this with grid (batch, ROUND // 2), once per
# STEP.  Program (batch, pair) rotates one column pair of that batch entry.
# A_WORK[batch] is a (K, ROWS) slab whose row i is the i-th long vector.
# V_WORK[batch] is a (K, K) slab whose row i accumulates every rotation
# applied to vector i.  ROUND is K rounded up to an even number; indices >= K
# are padding slots, and pairs involving them are masked out.
@libentry()
@triton.jit
def _cyclic_jacobi_pair_kernel(
    A_WORK,
    V_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    ring = ROUND - 1

    # Circle-method schedule: slot 0 is fixed and slots 1..ring are rotated by
    # STEP.  Slot `pair` is matched with its mirror slot ROUND - 1 - pair, so
    # all pairs of one STEP are disjoint.
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    # Pairs touching a padding index (only possible when K is odd) do nothing.
    valid_pair = (p < K) & (q < K)
    # Always rotate with the lower index first (p2 < q2).
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair
    col_mask = (cols < K) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    # Entries of the 2x2 Gram matrix of the pair.
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    # Skip the rotation when the pair is already orthogonal to ~1e-7 relative.
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    # t = tan(theta) is the smaller-magnitude root of t^2 + 2*tau*t - 1 = 0.
    # Rotating by it makes the new pair's inner product zero.
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    # Inactive pairs use the identity rotation.
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    # Apply the same rotation to the accumulated rows of V_WORK.
    vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
    vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
    new_vp = c * vp - s_rot * vq
    new_vq = s_rot * vp + c * vq
    tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
    tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)
# Run the whole cyclic one-sided Jacobi iteration inside one program per matrix.
#
# _cyclic_jacobi_svd uses this on its `serial_medium` path (16 <= k <= 32,
# rows <= 64, batch <= 32).  It replaces one _cyclic_jacobi_pair_kernel launch
# per step with loops inside a single launch.  Grid is (batch,).  Every pair
# update is the same rotation as in the per-step kernel, but here the pairs are
# applied one after another.  K, ROUND, SWEEPS and TAIL_STEPS are not
# tl.constexpr, so the loop bounds below are runtime values.
#
# Layout: A_WORK[batch] is (K, ROWS) with one long vector per row.
# V_WORK[batch] is (K, K) and accumulates the rotations row-wise.
@libentry()
@triton.jit
def _serial_cyclic_jacobi_kernel(
    A_WORK,
    V_WORK,
    K,
    ROUND,
    ROWS: tl.constexpr,
    SWEEPS,
    TAIL_STEPS,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_base_mask = rows < ROWS
    col_base_mask = cols < K
    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    eps = 1.0e-20
    ring = ROUND - 1
    half_round = ROUND // 2

    # SWEEPS full sweeps, each made of ROUND - 1 circle-method steps of
    # ROUND // 2 pairs.
    sweep = 0
    while sweep < SWEEPS:
        step = 0
        while step < ROUND - 1:
            pair = 0
            while pair < half_round:
                # Same slot -> column mapping as _cyclic_jacobi_pair_kernel.
                pos_p = pair
                pos_q = ROUND - 1 - pair
                p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
                q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
                # Pairs that touch a padding index (>= K) are masked out.
                valid_pair = (p < K) & (q < K)
                swap = p > q
                p2 = tl.where(swap, q, p)
                q2 = tl.where(swap, p, q)
                row_mask = row_base_mask & valid_pair
                col_mask = col_base_mask & valid_pair

                ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
                aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Rotate only when the pair is not yet orthogonal to ~1e-7 relative.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold
                safe_gamma = tl.where(active, gamma, 1.0)
                # Smaller root of t^2 + 2*tau*t - 1 = 0 zeroes the new inner product.
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active, c, 1.0)
                s_rot = tl.where(active, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
                tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

                vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
                vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
                new_vp = c * vp - s_rot * vq
                new_vq = s_rot * vp + c * vq
                tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
                tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)
                pair += 1
            step += 1
        sweep += 1

    # Partial extra sweep: the first TAIL_STEPS steps of the schedule again
    # (the host passes 20 when k == 32 and 0 otherwise).  The body is identical
    # to the one above.
    step = 0
    while step < TAIL_STEPS:
        pair = 0
        while pair < half_round:
            pos_p = pair
            pos_q = ROUND - 1 - pair
            p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
            q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
            valid_pair = (p < K) & (q < K)
            swap = p > q
            p2 = tl.where(swap, q, p)
            q2 = tl.where(swap, p, q)
            row_mask = row_base_mask & valid_pair
            col_mask = col_base_mask & valid_pair

            ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
            aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
            alpha = tl.sum(ap * ap)
            beta = tl.sum(aq * aq)
            gamma = tl.sum(ap * aq)
            abs_gamma = tl.abs(gamma)
            threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
            active = abs_gamma > threshold
            safe_gamma = tl.where(active, gamma, 1.0)
            tau = (beta - alpha) / (2.0 * safe_gamma)
            sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
            t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
            c = tl.rsqrt(1.0 + t * t)
            s_rot = t * c
            c = tl.where(active, c, 1.0)
            s_rot = tl.where(active, s_rot, 0.0)

            new_ap = c * ap - s_rot * aq
            new_aq = s_rot * ap + c * aq
            tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
            tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

            vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
            vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
            new_vp = c * vp - s_rot * vq
            new_vq = s_rot * vp + c * vq
            tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
            tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)
            pair += 1
        step += 1
# Store the Euclidean norm of one Jacobi vector.
#
# Program (batch, col) reduces row `col` of the (K, ROWS) slab A_WORK[batch]
# and writes its 2-norm to S_WORK[batch, col].
@libentry()
@triton.jit
def _cyclic_jacobi_norm_kernel(
    A_WORK,
    S_WORK,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    idx = tl.arange(0, BLOCK_R)
    vector_ptr = A_WORK + pid_batch * K * ROWS + pid_col * ROWS + idx
    vector = tl.load(vector_ptr, mask=idx < ROWS, other=0.0)
    squared = tl.sum(vector * vector)
    tl.store(S_WORK + pid_batch * K + pid_col, tl.sqrt(squared))
# Scatter one converged Jacobi vector into the sorted SVD outputs.
#
# Program (batch, col) ranks vector `col` by its norm in S_WORK (descending;
# ties go to the lower index).  It writes:
#   - the norm to S[batch, rank];
#   - the normalized long vector to column `rank` of the long factor;
#   - row `col` of V_WORK to column `rank` of the short factor.
# When TALL the long factor is U (M x K) and the short one is V (N x K);
# otherwise the roles swap.  Vectors with norm <= 1e-20 are written as zeros
# in the long factor.
@libentry()
@triton.jit
def _cyclic_jacobi_finalize_kernel(
    A_WORK,
    V_WORK,
    S_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    idx = tl.arange(0, BLOCK_R)
    short_idx = tl.arange(0, BLOCK_K)
    idx_ok = idx < ROWS
    short_ok = short_idx < K
    eps = 1.0e-20

    norms = S_WORK + pid_batch * K
    norm = tl.load(norms + pid_col)
    rank = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        other_norm = tl.load(norms + other)
        ahead = (other_norm > norm) | ((other_norm == norm) & (other < pid_col))
        rank += ahead.to(tl.int32)

    long_vals = tl.load(
        A_WORK + pid_batch * K * ROWS + pid_col * ROWS + idx,
        mask=idx_ok,
        other=0.0,
    )
    short_vals = tl.load(
        V_WORK + pid_batch * K * K + pid_col * K + short_idx,
        mask=short_ok,
        other=0.0,
    )
    scale = tl.where(norm > eps, 1.0 / norm, 0.0)
    tl.store(S + pid_batch * K + rank, norm)

    if TALL:
        long_out = U + pid_batch * M * K
        short_out = V + pid_batch * N * K
    else:
        long_out = V + pid_batch * N * K
        short_out = U + pid_batch * M * K
    tl.store(long_out + idx * K + rank, long_vals * scale, mask=idx_ok)
    tl.store(short_out + short_idx * K + rank, short_vals, mask=short_ok)
# One round-robin one-sided Jacobi step that updates only A_WORK.
#
# Program (batch, pair) rotates the column pair chosen by the circle-method
# schedule for STEP, in the (K, ROWS) slab A_WORK[batch].  Nothing about the
# rotation is recorded.  The name suggests the singular-values-only path;
# NOTE(review): confirm at the call site.  Indices >= K are padding slots.
@libentry()
@triton.jit
def _blocked_jacobi_pair_svals_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Circle method: slot 0 fixed, slots 1..ring rotated by STEP; slot `pair`
    # meets its mirror slot ROUND - 1 - pair.
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    # Skip pairs that are already orthogonal to ~1e-7 relative.
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    # Smaller root of t^2 + 2*tau*t - 1 = 0; the rotation zeroes ap . aq.
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active & valid_pair, c, 1.0)
    s_rot = tl.where(active & valid_pair, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)
# One round-robin one-sided Jacobi step on A_WORK, recording each rotation.
#
# Same pair selection and rotation as _blocked_jacobi_pair_svals_kernel.  In
# addition, the (c, s) used by program (batch, pair) is written to
# ROT_C/ROT_S[batch * (ROUND // 2) + pair].  _blocked_jacobi_apply_v_kernel
# reads the rotation back from that slot for the same STEP.  Pairs that are
# invalid or already orthogonal record the identity (c = 1, s = 0).
@libentry()
@triton.jit
def _blocked_jacobi_pair_a_kernel(
    A_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Circle method: slot 0 fixed, slots 1..ring rotated by STEP.
    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = (rows < ROWS) & valid_pair

    aw_base = A_WORK + batch * K * ROWS
    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    eps = 1.0e-20
    abs_gamma = tl.abs(gamma)
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    # Smaller root of t^2 + 2*tau*t - 1 = 0; the rotation zeroes ap . aq.
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active & valid_pair, c, 1.0)
    s_rot = tl.where(active & valid_pair, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    # Stored unmasked, so invalid pairs also write their identity rotation.
    rot_base = batch * (ROUND // 2) + pair
    tl.store(ROT_C + rot_base, c)
    tl.store(ROT_S + rot_base, s_rot)
# Block-level round-robin step: orthogonalize two column tiles against each other.
#
# Vectors are grouped into K_BLOCKS tiles of TILE_B consecutive indices.
# Program (batch, pair) picks two tiles with the circle-method schedule over
# ROUND_BLOCKS slots.  It loads both tiles of A_WORK[batch] as one
# (TILE_COLS, BLOCK_R) array and runs LOCAL_SWEEPS cyclic sweeps over every
# column pair inside it, then stores the tile back.
# NOTE(review): col_ids covers both tiles exactly only when
# TILE_COLS == 2 * TILE_B.  That is assumed here; confirm at the call site.
# No rotations are recorded, so the other factor must be recovered elsewhere
# (presumably by a projection step; confirm).
@libentry()
@triton.jit
def _hier_block_jacobi_pair_a_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    K_BLOCKS: tl.constexpr,
    ROUND_BLOCKS: tl.constexpr,
    ROWS: tl.constexpr,
    TILE_B: tl.constexpr,
    TILE_COLS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    LOCAL_SWEEPS: tl.constexpr,
):
    batch = tl.program_id(0)
    pair = tl.program_id(1)
    rows = tl.arange(0, BLOCK_R)
    local_cols = tl.arange(0, TILE_COLS)
    ring = ROUND_BLOCKS - 1

    # Circle-method schedule over tile indices; slots >= K_BLOCKS are padding.
    pos_p = pair
    pos_q = ROUND_BLOCKS - 1 - pair
    p_block = tl.where(pos_p == 0, 0, ((pos_p + ring - STEP - 1) % ring) + 1)
    q_block = tl.where(pos_q == 0, 0, ((pos_q + ring - STEP - 1) % ring) + 1)
    valid_pair = (p_block < K_BLOCKS) & (q_block < K_BLOCKS)
    p2 = tl.minimum(p_block, q_block)
    q2 = tl.maximum(p_block, q_block)

    # The first TILE_B local columns come from tile p2, the rest from tile q2.
    col_ids = tl.where(
        local_cols < TILE_B,
        p2 * TILE_B + local_cols,
        q2 * TILE_B + local_cols - TILE_B,
    )
    row_mask = rows < ROWS
    # Vectors past K (a partial last tile) load as zeros and are never stored.
    col_mask = (col_ids < K) & valid_pair
    aw_base = A_WORK + batch * K * ROWS
    vals = tl.load(
        aw_base + col_ids[:, None] * ROWS + rows[None, :],
        mask=col_mask[:, None] & row_mask[None, :],
        other=0.0,
    ).to(tl.float32)
    col_axis = local_cols[:, None]
    eps = 1.0e-20

    # Fully unrolled cyclic Jacobi over all local pairs (p, q), p < q.
    for _ in tl.static_range(0, LOCAL_SWEEPS):
        for p in tl.static_range(0, TILE_COLS):
            for q in tl.static_range(p + 1, TILE_COLS):
                # Extract local rows p and q of the tile via masked reductions.
                ap = tl.sum(tl.where(col_axis == p, vals, 0.0), axis=0)
                aq = tl.sum(tl.where(col_axis == q, vals, 0.0), axis=0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                abs_gamma = tl.abs(gamma)
                # Skip pairs that are already orthogonal to ~1e-7 relative.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = abs_gamma > threshold
                safe_gamma = tl.where(active, gamma, 1.0)
                # Smaller root of t^2 + 2*tau*t - 1 = 0; the rotation zeroes ap . aq.
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                c = tl.rsqrt(1.0 + t * t)
                s_rot = t * c
                c = tl.where(active & valid_pair, c, 1.0)
                s_rot = tl.where(active & valid_pair, s_rot, 0.0)

                new_ap = c * ap - s_rot * aq
                new_aq = s_rot * ap + c * aq
                # Write the rotated vectors back into local rows p and q.
                vals = tl.where(col_axis == p, new_ap[None, :], vals)
                vals = tl.where(col_axis == q, new_aq[None, :], vals)

    tl.store(
        aw_base + col_ids[:, None] * ROWS + rows[None, :],
        vals,
        mask=col_mask[:, None] & row_mask[None, :],
    )
# Apply recorded Jacobi rotations to the accumulated V_WORK rows.
#
# Program (batch, pair, block) reads the (c, s) that
# _blocked_jacobi_pair_a_kernel stored for the same batch, pair and STEP.  It
# rotates entries [block * BLOCK_V, (block + 1) * BLOCK_V) of rows p and q
# (p < q) of the (K, K) slab V_WORK[batch].  Pairs touching a padding index
# (>= K) are masked out.
@libentry()
@triton.jit
def _blocked_jacobi_apply_v_kernel(
    V_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_pair = tl.program_id(1)
    pid_block = tl.program_id(2)
    idx = pid_block * BLOCK_V + tl.arange(0, BLOCK_V)
    ring = ROUND - 1

    # Circle-method slot -> column mapping for this STEP.
    slot_lo = pid_pair
    slot_hi = ROUND - 1 - pid_pair
    col_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    col_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    lo = tl.minimum(col_a, col_b)
    hi = tl.maximum(col_a, col_b)
    active = (idx < K) & ((col_a < K) & (col_b < K))

    slot = pid_batch * (ROUND // 2) + pid_pair
    cos_t = tl.load(ROT_C + slot)
    sin_t = tl.load(ROT_S + slot)

    rows_base = V_WORK + pid_batch * K * K
    lo_ptr = rows_base + lo * K + idx
    hi_ptr = rows_base + hi * K + idx
    lo_vals = tl.load(lo_ptr, mask=active, other=0.0)
    hi_vals = tl.load(hi_ptr, mask=active, other=0.0)
    tl.store(lo_ptr, cos_t * lo_vals - sin_t * hi_vals, mask=active)
    tl.store(hi_ptr, sin_t * lo_vals + cos_t * hi_vals, mask=active)
# Rank one singular value and scatter it into the sorted output.
#
# Program (batch, col) counts the entries of S_WORK[batch, :K] that must come
# before entry `col`: strictly larger ones, plus equal ones with a lower index.
# The count is stored in RANKS[batch, col] and the value at S[batch, rank], so
# S ends up in descending order.  K is a runtime argument, so the scan is a
# runtime while loop.
@libentry()
@triton.jit
def _blocked_jacobi_rank_kernel(
    S_WORK,
    RANKS,
    S,
    K,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    values = S_WORK + pid_batch * K
    mine = tl.load(values + pid_col)

    rank = tl.full((), 0, dtype=tl.int32)
    j = 0
    while j < K:
        theirs = tl.load(values + j)
        ahead = (theirs > mine) | ((theirs == mine) & (j < pid_col))
        rank += ahead.to(tl.int32)
        j += 1

    tl.store(RANKS + pid_batch * K + pid_col, rank)
    tl.store(S + pid_batch * K + rank, mine)
# Write one normalized Jacobi vector into its sorted output column.
#
# Program (batch, col, block) handles rows [block * BLOCK_R, (block + 1) *
# BLOCK_R) of vector `col`.  The vector is read from A_WORK[batch] (stride
# ROWS) and divided by its norm S_WORK[batch, col].  The result goes to column
# RANKS[batch, col] of the (OUT_ROWS, K) slab PROJECTED[batch].  A vector with
# norm <= 1e-20 is replaced by the unit vector e_rank.
# NOTE(review): reads are masked by OUT_ROWS, not ROWS, so this assumes
# OUT_ROWS <= ROWS. Confirm at the call site.
@libentry()
@triton.jit
def _blocked_jacobi_store_projected_kernel(
    A_WORK,
    S_WORK,
    RANKS,
    PROJECTED,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    OUT_ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    pid_block = tl.program_id(2)
    idx = pid_block * BLOCK_R + tl.arange(0, BLOCK_R)
    in_range = idx < OUT_ROWS
    dest_col = tl.load(RANKS + pid_batch * K + pid_col)
    norm = tl.load(S_WORK + pid_batch * K + pid_col)
    eps = 1.0e-20

    column = tl.load(
        A_WORK + pid_batch * K * ROWS + pid_col * ROWS + idx,
        mask=in_range,
        other=0.0,
    )
    unit = tl.where(idx == dest_col, 1.0, 0.0)
    column = tl.where(norm <= eps, unit, column / tl.maximum(norm, eps))

    dst = PROJECTED + pid_batch * OUT_ROWS * K + idx * K + dest_col
    tl.store(dst, column, mask=in_range)
# Write one accumulated rotation row into its sorted output column.
#
# Program (batch, col, block) copies entries [block * BLOCK_V,
# (block + 1) * BLOCK_V) of row `col` of the (K, K) slab V_WORK[batch] into
# column RANKS[batch, col] of the (K, K) slab BASIS[batch].
@libentry()
@triton.jit
def _blocked_jacobi_store_basis_kernel(
    V_WORK,
    RANKS,
    BASIS,
    K: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    pid_col = tl.program_id(1)
    idx = tl.program_id(2) * BLOCK_V + tl.arange(0, BLOCK_V)
    in_range = idx < K
    dest_col = tl.load(RANKS + pid_batch * K + pid_col)
    slab = pid_batch * K * K

    entries = tl.load(V_WORK + slab + pid_col * K + idx, mask=in_range, other=0.0)
    tl.store(BASIS + slab + idx * K + dest_col, entries, mask=in_range)
# Re-orthonormalize the K columns of a row-major (ROWS, K) matrix in place.
#
# One program per batch entry processes the columns left to right:
#   1. Two passes of modified Gram-Schmidt against the already-finished columns.
#   2. A column whose residual norm is <= 1e-20 is replaced by the unit vector
#      e_j, and that replacement gets one more projection pass.
#   3. The column is normalized.
# Computation is in float32; stores cast back to Q's element type.
@libentry()
@triton.jit
def _thin_reorthogonalize_kernel(
    Q,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    idx = tl.arange(0, BLOCK_R)
    in_range = idx < ROWS
    matrix = Q + pid * ROWS * K
    eps = 1.0e-20

    for col in tl.static_range(0, K):
        vec = tl.load(matrix + idx * K + col, mask=in_range, other=0.0).to(tl.float32)

        # Two rounds of modified Gram-Schmidt against columns 0..col-1.
        for _round in tl.static_range(0, 2):
            for prev in tl.static_range(0, K):
                if prev < col:
                    q_prev = tl.load(
                        matrix + idx * K + prev, mask=in_range, other=0.0
                    ).to(tl.float32)
                    vec = vec - tl.sum(vec * q_prev) * q_prev

        # A vanished column is replaced by e_col.
        residual = tl.sqrt(tl.sum(vec * vec))
        vec = tl.where(residual > eps, vec, tl.where(idx == col, 1.0, 0.0))

        # Final projection pass, needed when the fallback above was taken.
        for prev in tl.static_range(0, K):
            if prev < col:
                q_prev = tl.load(
                    matrix + idx * K + prev, mask=in_range, other=0.0
                ).to(tl.float32)
                vec = vec - tl.sum(vec * q_prev) * q_prev

        length = tl.sqrt(tl.sum(vec * vec))
        tl.store(matrix + idx * K + col, vec / tl.maximum(length, eps), mask=in_range)
def _cyclic_jacobi_svd(input):
    """Thin SVD (U, S, V) via one-sided cyclic Jacobi on a float32 workspace.

    The input is flattened to ``(batch, m, n)``. With ``k = min(m, n)`` and
    ``rows = max(m, n)``, the sweeps operate on ``a_work`` of shape
    ``(batch, k, rows)`` and a rotation accumulator ``v_work`` of shape
    ``(batch, k, k)``.

    After the finalize kernel, the factor on the long side (U when
    ``m >= n``, otherwise V) is recomputed from A (or A^T) and the other
    factor with ``_triton_bmm``. The recomputed factor is then passed with
    ``s`` through ``_normalize_projection_kernel`` and
    ``_complete_zero_projection_kernel``. Presumably these divide by the
    singular value and patch zero columns; confirm in those kernels.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, k)``,
        ``(..., k)`` and ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    # Tuned sweep budget; tail_steps is only consumed by the serial kernel.
    sweeps = 6 if k == 32 else 8 if k < 32 else 12
    tail_steps = 20 if k == 32 else 0
    # Odd k is padded to an even "round" so each step can launch
    # round_size // 2 column pairs.
    round_size = k if k % 2 == 0 else k + 1
    # Medium, small-batch problems run all sweeps inside one launch per batch.
    serial_medium = 16 <= k <= 32 and rows <= 64 and batch <= 32
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
        if serial_medium:
            _serial_cyclic_jacobi_kernel[(batch,)](
                a_work,
                v_work,
                K=k,
                ROUND=round_size,
                ROWS=rows,
                SWEEPS=sweeps,
                TAIL_STEPS=tail_steps,
                BLOCK_R=block_r,
                BLOCK_K=block_k,
                num_warps=1,
            )
        else:
            # One launch per (sweep, step); each launch handles one set of
            # round_size // 2 column pairs for every batch entry.
            for _ in range(sweeps):
                for step in range(round_size - 1):
                    _cyclic_jacobi_pair_kernel[(batch, round_size // 2)](
                        a_work,
                        v_work,
                        step,
                        K=k,
                        ROUND=round_size,
                        ROWS=rows,
                        BLOCK_R=block_r,
                        BLOCK_K=block_k,
                        num_warps=1 if block_r <= 64 else 4,
                    )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _cyclic_jacobi_finalize_kernel[(batch, k)](
            a_work,
            v_work,
            s_work,
            u,
            s,
            v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
    # For small problems, re-orthonormalize the short-side basis (V when
    # tall, U when wide) before it is used for the projection below.
    if k <= 17 and rows <= 64:
        with torch_device_fn.device(input.device):
            if m >= n:
                _thin_reorthogonalize_kernel[(batch,)](
                    v,
                    ROWS=n,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(n),
                    num_warps=1,
                )
            else:
                _thin_reorthogonalize_kernel[(batch,)](
                    u,
                    ROWS=m,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(m),
                    num_warps=1,
                )

    # Recompute the long-side factor by projection: A @ V (tall) or A^T @ U.
    if m >= n:
        u = _triton_bmm(a, v, (batch, m, k))
        projected = u
        projected_rows = m
    else:
        a_t = a.transpose(1, 2).contiguous()
        v = _triton_bmm(a_t, u, (batch, n, k))
        projected = v
        projected_rows = n
    with torch_device_fn.device(input.device):
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1,
        )
        _complete_zero_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1,
        )
        # Batched small-k inputs get one more orthonormalization pass.
        if batch > 1 and k <= 16:
            _thin_reorthogonalize_kernel[(batch,)](
                projected,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(projected_rows),
                num_warps=1,
            )
    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _projected_jacobi_svd(input):
    """Thin SVD via one-sided Jacobi without accumulating the rotation matrix.

    Only the ``(batch, k, rows)`` workspace is rotated. Its final columns,
    divided by their norms, give the long-side factor directly (U when
    ``m >= n``, else V), stored in descending-rank order by
    ``_blocked_jacobi_store_projected_kernel``. That kernel divides by
    ``max(s, 1e-20)`` and substitutes the unit vector ``e_rank`` for
    zero-norm columns. The short-side factor is then recovered as
    ``A^T @ U`` (tall) or ``A @ V`` (wide) and normalized with ``s``.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, k)``,
        ``(..., k)`` and ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    tall = m >= n
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    # The factor read straight out of the rotated workspace.
    projected_rows = m if tall else n
    projected = torch.empty(
        (batch, projected_rows, k), dtype=input.dtype, device=input.device
    )
    block_r = triton.next_power_of_2(rows)
    sweeps = 10 if k >= 128 else 8
    # Odd k is padded so each step launches round_size // 2 column pairs.
    round_size = k if k % 2 == 0 else k + 1
    half_round = round_size // 2
    # Per-pair rotation scratch written by the pair kernel (cos / sin).
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)

    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=1 if block_r <= 64 else 4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[
            (batch, k, triton.cdiv(projected_rows, block_r))
        ](
            a_work,
            s_work,
            ranks,
            projected,
            K=k,
            ROWS=rows,
            OUT_ROWS=projected_rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )

    # Recover the other factor by projecting A onto the stored one.
    if tall:
        u = projected
        v = _triton_bmm(a.transpose(1, 2).contiguous(), u, (batch, n, k))
        normalized = v
        normalized_rows = n
    else:
        v = projected
        u = _triton_bmm(a, v, (batch, m, k))
        normalized = u
        normalized_rows = m

    with torch_device_fn.device(input.device):
        _normalize_projection_kernel[(batch, k)](
            normalized,
            s,
            ROWS=normalized_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(normalized_rows),
            num_warps=1 if normalized_rows <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _blocked_jacobi_svd(input):
    """Thin SVD via one-sided Jacobi with an explicitly accumulated rotation matrix.

    Each Jacobi step launches two kernels:

    * ``_blocked_jacobi_pair_a_kernel`` rotates column pairs of the
      ``(batch, k, rows)`` workspace and records each pair's rotation in
      ``rot_c`` / ``rot_s``.
    * ``_blocked_jacobi_apply_v_kernel`` replays those rotations on the
      ``(batch, k, k)`` accumulator, tiled by ``block_v`` rows.

    After the sweeps, both factors are scattered into descending-rank order.
    The long side comes from the normalized workspace columns; the short side
    comes from the accumulated rotations.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, k)``,
        ``(..., k)`` and ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    # Row tile used when touching the (k, k) rotation accumulator.
    block_v = 64
    sweeps = 14 if k > 256 else 10
    round_size = k if k % 2 == 0 else k + 1
    half_round = round_size // 2
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=1 if block_r <= 64 else 4,
                )
                # Must follow the pair kernel: it consumes rot_c / rot_s for
                # the same step.
                _blocked_jacobi_apply_v_kernel[
                    (batch, half_round, triton.cdiv(k, block_v))
                ](
                    v_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    BLOCK_V=block_v,
                    num_warps=1,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        # Tall: U from the normalized workspace, V from the rotations.
        # Wide: the roles are swapped.
        if m >= n:
            _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
                a_work,
                s_work,
                ranks,
                u,
                K=k,
                ROWS=rows,
                OUT_ROWS=m,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
            _blocked_jacobi_store_basis_kernel[(batch, k, triton.cdiv(n, block_v))](
                v_work,
                ranks,
                v,
                K=k,
                BLOCK_V=block_v,
                num_warps=1,
            )
        else:
            _blocked_jacobi_store_basis_kernel[(batch, k, triton.cdiv(m, block_v))](
                v_work,
                ranks,
                u,
                K=k,
                BLOCK_V=block_v,
                num_warps=1,
            )
            _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(n, block_r))](
                a_work,
                s_work,
                ranks,
                v,
                K=k,
                ROWS=rows,
                OUT_ROWS=n,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _blocked_jacobi_square_project_svd(input):
    """Thin SVD via column-pair Jacobi on A, then V recovered as A^T @ U.

    Only the ``(batch, k, rows)`` workspace is rotated (no rotation
    accumulator). U is read from the normalized, rank-ordered workspace
    columns. V is projected and then passed with ``s`` to
    ``_renorm_projection_update_s_kernel``, which presumably renormalizes V
    and refreshes ``s`` from the projected norms (confirm in that kernel).

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, k)``,
        ``(..., k)`` and ``(..., n, k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    sweeps = 12 if k <= 256 else 16
    round_size = k if k % 2 == 0 else k + 1
    half_round = round_size // 2
    rot_c = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)
    rot_s = torch.empty((batch, half_round), dtype=torch.float32, device=input.device)
    with torch_device_fn.device(input.device):
        # NOTE(review): TALL is hard-coded to True, so this assumes m >= n
        # (intended for square inputs); confirm the caller's gate
        # (_can_use_blocked_square_project_kernel) enforces it.
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, half_round)](
                    a_work,
                    rot_c,
                    rot_s,
                    step,
                    K=k,
                    ROUND=round_size,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=1 if block_r <= 64 else 4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )

    # V = A^T @ U, then renormalized alongside s.
    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _hier_block_jacobi_square_project_svd(input):
    """Thin SVD of square matrices via hierarchical block Jacobi, V projected.

    Columns of the ``(batch, k, rows)`` workspace are grouped into tiles of
    ``tile_b`` columns. Each step launches one program per pair of tiles,
    and each program runs ``LOCAL_SWEEPS`` inner sweeps over its
    ``2 * tile_b`` columns. U is read from the normalized, rank-ordered
    workspace columns. V is recovered as ``A^T @ U`` and renormalized
    together with ``s``.

    The shape is validated before any device workspace is allocated, so a
    rejected input no longer pays for five unused buffers.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, k)``,
        ``(..., k)`` and ``(..., n, k)``.

    Raises:
        NotImplementedError: via ``_unsupported_svd`` when the input is not
            square or ``k`` is not a multiple of the tile width.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)

    # Columns per tile; k must split evenly into tiles.
    tile_b = 4 if k == 512 else 2
    if m != n or k % tile_b != 0:
        return _unsupported_svd(
            input,
            True,
            True,
            "Hierarchical block Jacobi supports square matrices with "
            "k divisible by two.",
        )

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    block_count = k // tile_b
    # Odd tile counts are padded so each step launches round_blocks // 2 pairs.
    round_blocks = block_count if block_count % 2 == 0 else block_count + 1
    half_round_blocks = round_blocks // 2
    sweep_count = 10 if k <= 256 else 12
    # Each program works on the columns of two tiles at once.
    tile_cols = tile_b * 2
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        for _ in range(sweep_count):
            for step in range(round_blocks - 1):
                _hier_block_jacobi_pair_a_kernel[(batch, half_round_blocks)](
                    a_work,
                    step,
                    K=k,
                    K_BLOCKS=block_count,
                    ROUND_BLOCKS=round_blocks,
                    ROWS=rows,
                    TILE_B=tile_b,
                    TILE_COLS=tile_cols,
                    BLOCK_R=block_r,
                    LOCAL_SWEEPS=1,
                    num_warps=4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )

    # V = A^T @ U, then renormalized alongside s.
    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _blocked_jacobi_singular_values(input):
    """Singular values only, via one-sided Jacobi on a float32 workspace.

    No singular vectors are formed. ``_blocked_jacobi_pair_svals_kernel``
    rotates column pairs of the ``(batch, k, rows)`` workspace. The column
    norms are then ranked into ``s``.

    Returns:
        ``S`` in the input dtype with shape ``(..., min(m, n))``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    dev = input.device
    matrix = input.contiguous().reshape(batch, m, n)
    work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    norms = torch.empty((batch, k), dtype=torch.float32, device=dev)
    order = torch.empty((batch, k), dtype=torch.int32, device=dev)
    svals = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    warps = 4 if block_r > 64 else 1
    sweep_total = 10 if k <= 256 else 14
    # Pad odd k to an even round so every step pairs whole column couples.
    padded = k + (k % 2)
    pairs = padded // 2

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            matrix,
            work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            num_warps=warps,
        )
        for _sweep in range(sweep_total):
            for step in range(padded - 1):
                _blocked_jacobi_pair_svals_kernel[(batch, pairs)](
                    work,
                    step,
                    K=k,
                    ROUND=padded,
                    ROWS=rows,
                    BLOCK_R=block_r,
                    num_warps=warps,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            work,
            norms,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            norms,
            order,
            svals,
            k,
            num_warps=1,
        )

    return svals.reshape(*input.shape[:-2], k)
def _small4_square_svd(input):
    """Full SVD of a batch of 4x4 matrices with a dedicated batched kernel.

    Each program handles ``block_b`` matrices and runs a fixed 4 sweeps.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., 4, 4)``,
        ``(..., 4)`` and ``(..., 4, 4)``.
    """
    batch, m, n = _svd_shape(input)
    lead = input.shape[:-2]
    matrix = input.contiguous().reshape(batch, m, n)
    u_out, s_out, v_out = (
        torch.empty(shape, dtype=input.dtype, device=input.device)
        for shape in ((batch, 4, 4), (batch, 4), (batch, 4, 4))
    )
    block_b = 16
    grid = (triton.cdiv(batch, block_b),)
    with torch_device_fn.device(input.device):
        _small4_square_svd_kernel[grid](
            matrix,
            u_out,
            s_out,
            v_out,
            BATCH=batch,
            BLOCK_B=block_b,
            SWEEPS=4,
            num_warps=1,
        )
    return (
        u_out.reshape(*lead, 4, 4),
        s_out.reshape(*lead, 4),
        v_out.reshape(*lead, 4, 4),
    )
@libentry()
@triton.jit
def _rank1_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Thin SVD of one row-major (M, N) matrix with min(M, N) == 1 per program.
    #   TALL (N == 1): S = ||A[:, 0]||, V = [[1]], U[:, 0] = A[:, 0] / S
    #   otherwise (M == 1): S = ||A[0, :]||, U = [[1]], V[:, 0] = A[0, :] / S
    # The vector is divided by its exact norm, so the singular vector is unit
    # norm for every non-zero input. The previous max(norm, FLT_EPSILON) clamp
    # shrank vectors whose norm is below 1.19e-7. A zero (or fully
    # underflowed) vector yields e_0, matching the basis-vector fallback used
    # for zero singular values elsewhere in this file.
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_R)
    a_base = A + pid * M * N
    norm_sq = tl.full((), 0.0, dtype=tl.float32)

    if TALL:
        # Column 0 of A: consecutive elements are N apart.
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        nonzero = norm > 0.0
        # Avoid 0/0 in lanes that are discarded by the where below.
        safe_norm = tl.where(nonzero, norm, 1.0)
        tl.store(S + pid, norm)
        tl.store(V + pid, 1.0)

        u_base = U + pid * M
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(rows == 0, 1.0, 0.0)
            tl.store(
                u_base + rows, tl.where(nonzero, vals / safe_norm, unit), mask=mask
            )
    else:
        # Row 0 of A is contiguous.
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        nonzero = norm > 0.0
        safe_norm = tl.where(nonzero, norm, 1.0)
        tl.store(S + pid, norm)
        tl.store(U + pid, 1.0)

        v_base = V + pid * N
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            unit = tl.where(cols == 0, 1.0, 0.0)
            tl.store(
                v_base + cols, tl.where(nonzero, vals / safe_norm, unit), mask=mask
            )
def _rank1_svd(input):
    """Thin SVD for inputs with ``min(m, n) == 1`` (a single column or row).

    The kernel launch is skipped when there are no matrices. In that case the
    outputs are still allocated with the correct (empty) shapes.

    Returns:
        ``(U, S, V)`` in the input dtype, shaped ``(..., m, 1)``,
        ``(..., 1)`` and ``(..., n, 1)``.
    """
    batch, m, n = _svd_shape(input)
    lead = input.shape[:-2]
    matrix = input.contiguous().reshape(batch, m, n)
    u_out = torch.empty((batch, m, 1), dtype=input.dtype, device=input.device)
    s_out = torch.empty((batch, 1), dtype=input.dtype, device=input.device)
    v_out = torch.empty((batch, n, 1), dtype=input.dtype, device=input.device)
    if batch:
        longest = max(m, n)
        # Cap the tile; longer vectors are walked in BLOCK_R chunks.
        block_r = (
            triton.next_power_of_2(longest)
            if longest <= _RANK1_BLOCK_R_MAX
            else _RANK1_BLOCK_R_MAX
        )
        with torch_device_fn.device(input.device):
            _rank1_svd_kernel[(batch,)](
                matrix,
                u_out,
                s_out,
                v_out,
                m,
                n,
                TALL=n == 1,
                BLOCK_R=block_r,
                num_warps=4 if block_r > 64 else 1,
            )
    return (
        u_out.reshape(*lead, m, 1),
        s_out.reshape(*lead, 1),
        v_out.reshape(*lead, n, 1),
    )
@libentry()
@triton.jit
def _complex_to_real_embedding_kernel(
    A_RI,
    R,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Build the real (2M, 2N) embedding of one complex (M, N) matrix:
    #     R = [[Re A, -Im A],
    #          [Im A,  Re A]]
    # A_RI holds interleaved (real, imag) float pairs, as produced by
    # torch.view_as_real. One program per batch entry handles the whole 4*M*N
    # output in a single block, so BLOCK_SIZE must be >= 4 * M * N.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    total = 4 * M * N
    mask = offsets < total
    # Output position inside the row-major (2M, 2N) matrix.
    row = offsets // (2 * N)
    col = offsets - row * (2 * N)
    # Source element of A for the quadrant this position falls in.
    src_row = tl.where(row < M, row, row - M)
    src_col = tl.where(col < N, col, col - N)
    # comp selects the imaginary part in the off-diagonal quadrants.
    comp = tl.where((row < M) & (col >= N), 1, 0)
    comp = tl.where((row >= M) & (col < N), 1, comp)
    vals = tl.load(
        A_RI + batch * M * N * 2 + (src_row * N + src_col) * 2 + comp,
        mask=mask,
        other=0.0,
    )
    # Only the top-right quadrant is negated (-Im A).
    sign = tl.where((row < M) & (col >= N), -1.0, 1.0)
    tl.store(R + batch * 4 * M * N + offsets, vals * sign, mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_factor_kernel(
    REAL_FACTOR,
    OUT_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Read a complex (ROWS, K) factor out of a real (2*ROWS, REAL_K) factor of
    # the 2x2 real embedding. Complex column c is taken from real column 2*c:
    # the top ROWS entries are the real part and the bottom ROWS entries are
    # the imaginary part. The output is written as interleaved (real, imag)
    # pairs. One program per batch entry; BLOCK_SIZE must be >= ROWS * K.
    # No re-orthonormalization is applied here. This kernel is not launched
    # by _complex_svd_via_real_embedding, which uses
    # _complex_svd_pick_orthonormal_v_kernel and projects U instead.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < ROWS * K
    row = offsets // K
    col = offsets % K
    src_col = col * 2
    real = tl.load(
        REAL_FACTOR + batch * (2 * ROWS) * REAL_K + row * REAL_K + src_col,
        mask=mask,
        other=0.0,
    )
    imag = tl.load(
        REAL_FACTOR + batch * (2 * ROWS) * REAL_K + (ROWS + row) * REAL_K + src_col,
        mask=mask,
        other=0.0,
    )
    out_base = OUT_RI + batch * ROWS * K * 2 + offsets * 2
    tl.store(out_base, real, mask=mask)
    tl.store(out_base + 1, imag, mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_s_kernel(
    S_REAL,
    S,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Collapse the singular values of the 2x2 real embedding back to K complex
    # singular values. The embedding's sorted spectrum is expected to hold each
    # value twice (entries 2*c and 2*c + 1). They are averaged to smooth out
    # rounding differences between the two copies.
    # One program per batch entry; BLOCK_K must be >= K.
    batch = tl.program_id(0)
    cols = tl.arange(0, BLOCK_K)
    mask = cols < K
    src = cols * 2
    vals_a = tl.load(S_REAL + batch * REAL_K + src, mask=mask, other=0.0)
    vals_b = tl.load(S_REAL + batch * REAL_K + src + 1, mask=mask, other=0.0)
    tl.store(S + batch * K + cols, 0.5 * (vals_a + vals_b), mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_orthonormal_v_kernel(
    V_REAL,
    V_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Read a complex (ROWS, K) V from the real (2*ROWS, REAL_K) V of the 2x2
    # real embedding and re-orthonormalize its columns with complex
    # Gram-Schmidt.
    # Complex column c comes from real column 2*c: the top half is the real
    # part and the bottom half the imaginary part. The result is written as
    # interleaved (real, imag) pairs.
    # The Gram-Schmidt loop is unrolled over a fixed 16 columns, so only
    # K <= 16 is handled; the complex entry point in svd() gates on
    # max(m, n) <= 16.
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    src_cols = cols * 2
    base = V_REAL + batch * (2 * ROWS) * REAL_K
    vr = tl.load(
        base + rows[:, None] * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )
    vi = tl.load(
        base + (ROWS + rows[:, None]) * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )

    for c in tl.static_range(0, 16):
        cur_mask = c < K
        # Extract column c from the register tile.
        cur_r = tl.sum(tl.where(cols[None, :] == c, vr, 0.0), axis=1)
        cur_i = tl.sum(tl.where(cols[None, :] == c, vi, 0.0), axis=1)
        for p in tl.static_range(0, c):
            prev_r = tl.sum(tl.where(cols[None, :] == p, vr, 0.0), axis=1)
            prev_i = tl.sum(tl.where(cols[None, :] == p, vi, 0.0), axis=1)
            # coeff = <prev, cur> = sum(conj(prev) * cur)
            coeff_r = tl.sum(
                tl.where(row_mask, prev_r * cur_r + prev_i * cur_i, 0.0), axis=0
            )
            coeff_i = tl.sum(
                tl.where(row_mask, prev_r * cur_i - prev_i * cur_r, 0.0), axis=0
            )
            # cur -= prev * coeff (complex product)
            cur_r -= prev_r * coeff_r - prev_i * coeff_i
            cur_i -= prev_r * coeff_i + prev_i * coeff_r
        norm_sq = tl.sum(tl.where(row_mask, cur_r * cur_r + cur_i * cur_i, 0.0), axis=0)
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur_r *= inv_norm
        cur_i *= inv_norm
        # Write column c back into the tile (no-op for c >= K).
        vr = tl.where((cols[None, :] == c) & cur_mask, cur_r[:, None], vr)
        vi = tl.where((cols[None, :] == c) & cur_mask, cur_i[:, None], vi)

    out_base = V_RI + batch * ROWS * K * 2
    offsets = rows[:, None] * K + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets * 2, vr, mask=mask)
    tl.store(out_base + offsets * 2 + 1, vi, mask=mask)
@libentry()
@triton.jit
def _complex_svd_project_u_kernel(
    A_RI,
    V_RI,
    S,
    U_RI,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # U = (A @ V) / S for one complex matrix per program, with A (M, N),
    # V (N, K) and U (M, K) all stored as interleaved (real, imag) pairs.
    # The whole (M, K) output is one block, so BLOCK_SIZE must be >= M * K.
    # The reduction over N is fully unrolled (static_range), which is only
    # practical for small N.
    # NOTE(review): columns with S <= 1e-20 are written as zeros rather than
    # completed to unit vectors, so U is not orthonormal for rank-deficient
    # input. Confirm whether callers tolerate this.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M * K
    row = offsets // K
    col = offsets % K

    acc_r = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    acc_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for j in tl.static_range(0, N):
        a_base = A_RI + batch * M * N * 2 + (row * N + j) * 2
        v_base = V_RI + batch * N * K * 2 + (j * K + col) * 2
        ar = tl.load(a_base, mask=mask, other=0.0)
        ai = tl.load(a_base + 1, mask=mask, other=0.0)
        vr = tl.load(v_base, mask=mask, other=0.0)
        vi = tl.load(v_base + 1, mask=mask, other=0.0)
        # Complex multiply-accumulate: acc += a * v
        acc_r += ar * vr - ai * vi
        acc_i += ar * vi + ai * vr

    s = tl.load(S + batch * K + col, mask=mask, other=1.0)
    inv_s = tl.where(s > 1.0e-20, 1.0 / s, 0.0)
    out_base = U_RI + batch * M * K * 2 + offsets * 2
    tl.store(out_base, acc_r * inv_s, mask=mask)
    tl.store(out_base + 1, acc_i * inv_s, mask=mask)
def _complex_svd_via_real_embedding(input):
    """Thin SVD of complex matrices through the real 2x2 block embedding.

    The algorithm:

    1. A is embedded as ``[[Re A, -Im A], [Im A, Re A]]`` (float32).
    2. The embedding goes through the real ``svd``.
    3. Its duplicated singular values are averaged into ``k`` values.
    4. Complex V is read from every other real column and re-orthonormalized.
    5. U is rebuilt as ``A @ V / S``.

    Returns:
        ``(U, S, V)``. U and V have the input's complex dtype and shapes
        ``(..., m, k)`` and ``(..., n, k)``; S is float32 with shape
        ``(..., k)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    dev = input.device
    lead = input.shape[:-2]
    a_pairs = torch.view_as_real(input.contiguous()).reshape(batch, m, n, 2)
    embedded = torch.empty((batch, 2 * m, 2 * n), dtype=torch.float32, device=dev)
    embed_block = triton.next_power_of_2(4 * m * n)
    with torch_device_fn.device(dev):
        _complex_to_real_embedding_kernel[(batch,)](
            a_pairs,
            embedded,
            M=m,
            N=n,
            BLOCK_SIZE=embed_block,
            num_warps=1,
        )
    # The real U is discarded; complex U is reconstructed by projection.
    _, real_s, real_v = svd(embedded, some=True, compute_uv=True)

    s_out = torch.empty((batch, k), dtype=torch.float32, device=dev)
    u_out = torch.empty((*lead, m, k), dtype=input.dtype, device=dev)
    v_out = torch.empty((*lead, n, k), dtype=input.dtype, device=dev)
    u_pairs = torch.view_as_real(u_out).reshape(batch, m, k, 2)
    v_pairs = torch.view_as_real(v_out).reshape(batch, n, k, 2)
    block_k = triton.next_power_of_2(k)
    with torch_device_fn.device(dev):
        _complex_svd_pick_s_kernel[(batch,)](
            real_s,
            s_out,
            K=k,
            REAL_K=2 * k,
            BLOCK_K=block_k,
            num_warps=1,
        )
        _complex_svd_pick_orthonormal_v_kernel[(batch,)](
            real_v,
            v_pairs,
            ROWS=n,
            K=k,
            REAL_K=2 * k,
            BLOCK_ROWS=triton.next_power_of_2(n),
            BLOCK_K=block_k,
            num_warps=1,
        )
        _complex_svd_project_u_kernel[(batch,)](
            a_pairs,
            v_pairs,
            s_out,
            u_pairs,
            M=m,
            N=n,
            K=k,
            BLOCK_SIZE=triton.next_power_of_2(m * k),
            num_warps=1,
        )
    return u_out, s_out.reshape(*lead, k), v_out
def _gram_svd(input):
    """Placeholder for the Gram-matrix SVD path; always reports it as unsupported.

    Raises:
        NotImplementedError: via ``_unsupported_svd``.
    """
    return _unsupported_svd(input, some=True, compute_uv=True)
@libentry()
@triton.jit
def _gram16_finalize_kernel(
    A,
    EVALS,
    EVECS,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    EVECS_BATCH_STRIDE: tl.constexpr,
    EVECS_ROW_STRIDE: tl.constexpr,
    EVECS_COL_STRIDE: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Turn a 16x16 Gram eigendecomposition into thin SVD factors (k == 16 is
    # hard-coded throughout).
    # - S = sqrt(max(eigenvalue, 0)), taken in reversed order (15 - c). That
    #   yields descending S only if EVALS is ascending (as from
    #   torch.linalg.eigh); TODO confirm with the caller.
    # - Long-side factor: TALL -> U = A @ E / S; otherwise V = A^T @ E / S.
    #   Rows are tiled by BLOCK_R across program_id(1). Columns with
    #   S <= 1e-20 become zero.
    # - Short-side factor (V when TALL, U otherwise) is the reversed
    #   eigenvector matrix itself. It is written, together with S, only by
    #   the row_block == 0 program.
    # EVECS is read through explicit strides, so transposed or non-contiguous
    # eigenvector layouts are supported.
    # Note: _gram16_svd currently reports unsupported, so no visible caller
    # launches this kernel.
    batch = tl.program_id(0)
    row_block = tl.program_id(1)
    rows = row_block * BLOCK_R + tl.arange(0, BLOCK_R)
    cols = tl.arange(0, 16)
    src_cols = 15 - cols
    row_mask = rows < ROWS
    eps = 1.0e-20

    vals = tl.load(EVALS + batch * 16 + src_cols)
    # Clamp tiny negative eigenvalues (rounding) before the square root.
    s_vals = tl.sqrt(tl.maximum(vals, 0.0))
    inv_s = tl.where(s_vals > eps, 1.0 / s_vals, 0.0)

    acc = tl.zeros((BLOCK_R, 16), dtype=tl.float32)
    a_base = A + batch * M * N
    e_base = EVECS + batch * EVECS_BATCH_STRIDE
    # acc[r, c] = sum_k A'[r, k] * E[k, 15 - c], where A' = A (TALL) or A^T.
    for k in tl.static_range(0, 16):
        eig = tl.load(e_base + k * EVECS_ROW_STRIDE + src_cols * EVECS_COL_STRIDE)
        if TALL:
            a_vals = tl.load(
                a_base + rows * N + k,
                mask=row_mask,
                other=0.0,
            )
        else:
            a_vals = tl.load(
                a_base + k * N + rows,
                mask=row_mask,
                other=0.0,
            )
        acc += a_vals[:, None] * eig[None, :]

    projected = acc * inv_s[None, :]
    if TALL:
        tl.store(
            U + batch * M * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )
    else:
        tl.store(
            V + batch * N * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )

    # S and the 16x16 eigenvector factor are written only by the first row block.
    head_mask = row_block == 0
    tl.store(S + batch * 16 + cols, s_vals, mask=head_mask)

    basis_rows = tl.arange(0, 16)
    basis_cols = tl.arange(0, 16)
    basis_src_cols = 15 - basis_cols
    basis = tl.load(
        e_base
        + basis_rows[:, None] * EVECS_ROW_STRIDE
        + basis_src_cols[None, :] * EVECS_COL_STRIDE
    )
    if TALL:
        tl.store(
            V + batch * N * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )
    else:
        tl.store(
            U + batch * M * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )
def _gram16_svd(input):
    """Placeholder for the k == 16 Gram-matrix SVD path; always reports unsupported.

    Raises:
        NotImplementedError: via ``_unsupported_svd``.
    """
    return _unsupported_svd(input, some=True, compute_uv=True)
def _large_native_svd(input):
    """Dispatch a large thin SVD to the first native solver whose gate accepts it.

    Candidates are tried in priority order:

    1. hierarchical block Jacobi (square),
    2. blocked Jacobi with projection (square),
    3. blocked Jacobi with an accumulated rotation matrix.

    If none of these gates accepts the input, the bidiagonal QR/DQDS path
    is used.
    """
    solvers = (
        (
            _can_use_hier_block_square_project_kernel,
            _hier_block_jacobi_square_project_svd,
        ),
        (_can_use_blocked_square_project_kernel, _blocked_jacobi_square_project_svd),
        (_can_use_blocked_jacobi_kernel, _blocked_jacobi_svd),
    )
    for accepts, solve in solvers:
        if accepts(input, True, True):
            return solve(input)
    return _bidiagonal_qr_dqds_svd(input)
def _bidiagonal_qr_dqds_svd(input):
    """Reserved entry point for the k > 512 bidiagonalization + QR/DQDS solver.

    Not implemented yet.

    Raises:
        NotImplementedError: via ``_unsupported_svd``, with an explanatory reason.
    """
    reason = (
        "The k > 512 blocked-bidiagonalization plus QR/DQDS path is reserved "
        "for the next native large-matrix solver stage."
    )
    return _unsupported_svd(input, True, True, reason)
def _empty_svd_result(input, some=True, compute_uv=True):
    """Return a ``torch.svd``-compatible result for an input with no elements.

    Shapes follow ``torch.svd``:

    * ``compute_uv and some``: U ``(..., m, k)`` and V ``(..., n, k)``.
    * otherwise: U ``(..., m, m)`` and V ``(..., n, n)``.

    Contents also follow PyTorch rather than leaving uninitialized memory.
    With ``compute_uv=False``, U and V are zero-filled. With full matrices
    (``compute_uv=True, some=False``), U and V are identities, which is a
    valid SVD of an empty matrix. S has ``k = min(m, n)`` entries and is
    zero-filled.
    """
    _, m, n = _svd_shape(input)
    k = min(m, n)
    batch_shape = input.shape[:-2]
    full = not (compute_uv and some)
    u_cols = m if full else k
    v_cols = n if full else k
    u = torch.zeros((*batch_shape, m, u_cols), dtype=input.dtype, device=input.device)
    s = torch.zeros((*batch_shape, k), dtype=input.dtype, device=input.device)
    v = torch.zeros((*batch_shape, n, v_cols), dtype=input.dtype, device=input.device)
    if compute_uv and not some:
        # e.g. a (3, 0) input with full matrices needs U == I_3.
        u.diagonal(dim1=-2, dim2=-1).fill_(1)
        v.diagonal(dim1=-2, dim2=-1).fill_(1)
    return u, s, v
@libentry()
@triton.jit
def _complete_svd_factor_kernel(
    THIN,
    FULL,
    ROWS: tl.constexpr,
    THIN_COLS: tl.constexpr,
    FULL_COLS: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    # Extend a thin orthonormal factor THIN (batch, ROWS, THIN_COLS) to an
    # orthonormal FULL (batch, ROWS, FULL_COLS), with FULL_COLS <= ROWS
    # (callers pass FULL_COLS == ROWS).
    # One program per batch entry holds the whole tile in registers, so
    # BLOCK_ROWS >= ROWS and BLOCK_COLS >= FULL_COLS are required.
    #
    # Columns < THIN_COLS are re-orthonormalized copies of THIN. Each new
    # column c >= THIN_COLS starts from the identity column e_c, which is
    # what the previous version always used. If e_c lies (mostly) inside the
    # span already built, it falls back to e_r for the row r with the
    # smallest leverage. Otherwise Gram-Schmidt would collapse it to zero:
    # e.g. THIN = [0, 1]^T used to give FULL = [[0, 0], [1, 0]].
    # The leverages of c orthonormal columns sum to c, so
    # min(leverage) <= c / ROWS and the residual is never zero.
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_COLS)
    row_mask = rows < ROWS
    col_mask = cols < FULL_COLS
    vals = tl.load(
        THIN + batch * ROWS * THIN_COLS + rows[:, None] * THIN_COLS + cols[None, :],
        mask=row_mask[:, None] & (cols[None, :] < THIN_COLS),
        other=0.0,
    )
    identity = tl.where(rows[:, None] == cols[None, :], 1.0, 0.0)
    vals = tl.where(cols[None, :] < THIN_COLS, vals, identity)

    # Modified Gram-Schmidt over every output column. The bound used to be a
    # fixed 64, which left columns >= 64 as raw, non-orthogonal e_c.
    for c in tl.static_range(0, FULL_COLS):
        cur = tl.sum(tl.where(cols[None, :] == c, vals, 0.0), axis=1)
        if c >= THIN_COLS:
            # leverage[r] = ||row r of columns 0..c-1||^2; the residual of e_r
            # after projecting out those columns has squared norm
            # 1 - leverage[r].
            leverage = tl.sum(tl.where(cols[None, :] < c, vals * vals, 0.0), axis=1)
            leverage = tl.where(row_mask, leverage, 2.0)
            own_leverage = tl.sum(tl.where(rows == c, leverage, 0.0), axis=0)
            pivot = tl.argmin(leverage, axis=0)
            fallback = tl.where(rows == pivot, 1.0, 0.0)
            cur = tl.where(own_leverage > 0.5, fallback, cur)
        for p in tl.static_range(0, c):
            prev = tl.sum(tl.where(cols[None, :] == p, vals, 0.0), axis=1)
            coeff = tl.sum(tl.where(row_mask, prev * cur, 0.0), axis=0)
            cur -= prev * coeff
        norm_sq = tl.sum(tl.where(row_mask, cur * cur, 0.0), axis=0)
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur *= inv_norm
        vals = tl.where(cols[None, :] == c, cur[:, None], vals)

    out_base = FULL + batch * ROWS * FULL_COLS
    offsets = rows[:, None] * FULL_COLS + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets, vals, mask=mask)
def _low_precision_svd_via_float32(input, some=True, compute_uv=True):
    """Run the SVD in float32 and cast (U, S, V) back to the input's dtype.

    Used for float16/bfloat16 inputs, which the native kernels do not take
    directly.
    """
    target_dtype = input.dtype
    factors = svd(input.to(torch.float32), some=some, compute_uv=compute_uv)
    return tuple(factor.to(target_dtype) for factor in factors)
def _some_false_svd_via_thin(input):
    """Full-matrix SVD (``some=False``): thin SVD plus orthonormal completion.

    The thin U ``(..., m, k)`` and V ``(..., n, k)`` are extended to square
    ``(..., m, m)`` and ``(..., n, n)`` factors by
    ``_complete_svd_factor_kernel``. S is the thin result unchanged.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    thin_u, s, thin_v = svd(input, some=True, compute_uv=True)
    lead = input.shape[:-2]
    full_u = torch.empty((*lead, m, m), dtype=input.dtype, device=input.device)
    full_v = torch.empty((*lead, n, n), dtype=input.dtype, device=input.device)
    with torch_device_fn.device(input.device):
        # U first, then V: one program per matrix, whole tile in registers.
        for thin, full, size in ((thin_u, full_u, m), (thin_v, full_v, n)):
            tile = triton.next_power_of_2(size)
            _complete_svd_factor_kernel[(batch,)](
                thin,
                full,
                ROWS=size,
                THIN_COLS=k,
                FULL_COLS=size,
                BLOCK_ROWS=tile,
                BLOCK_COLS=tile,
                num_warps=4,
            )
    return full_u, s, full_v
def _compute_uv_false_result(input, s):
    """Package singular values as a ``compute_uv=False`` SVD result.

    Follows the ``torch.svd`` contract: when ``compute_uv=False`` the returned
    ``U`` and ``V`` are zero-filled matrices of shape ``(*batch, m, m)`` and
    ``(*batch, n, n)`` with the input's dtype and device. The previous
    ``torch.empty`` leaked uninitialized memory into the result.

    Args:
        input: The matrix (or batch of matrices) that was decomposed.
        s: Singular values already computed for ``input``.

    Returns:
        ``(U, S, V)`` with zero-filled ``U``/``V`` and ``S`` passed through.
    """
    _, m, n = _svd_shape(input)
    leading = input.shape[:-2]
    u = torch.zeros((*leading, m, m), dtype=input.dtype, device=input.device)
    v = torch.zeros((*leading, n, n), dtype=input.dtype, device=input.device)
    return u, s, v
def _singular_values_only(input):
    """Compute singular values only, dispatching on the matrix shape.

    With ``rank = min(m, n)`` and ``extent = max(m, n)``:

    * ``rank == 2`` and ``extent <= _RANK2_BLOCK_R_MAX``: closed-form rank-2 path;
    * ``extent <= 1024`` and ``rank <= 16``: small Jacobi path;
    * ``extent <= 1024`` and ``16 < rank <= 512``: blocked Jacobi path;
    * anything else raises NotImplementedError via ``_unsupported_svd``.
    """
    _, rows, cols = _svd_shape(input)
    rank = min(rows, cols)
    extent = max(rows, cols)
    if rank == 2 and extent <= _RANK2_BLOCK_R_MAX:
        return _rank2_singular_values(input)
    if extent <= 1024:
        if rank <= 16:
            return _small_jacobi_singular_values(input)
        if rank <= 512:
            return _blocked_jacobi_singular_values(input)
    return _unsupported_svd(input, True, False)
3424def _should_use_gram16(batch, m, n):
3425 return batch >= 16 and min(m, n) == 16 and max(m, n) <= 1024
3428def _should_use_gram(batch, m, n):
3429 k = min(m, n)
3430 largest = max(m, n)
3431 if k <= 32:
3432 return True
3433 if batch <= 4 and m == n and m <= 256:
3434 return True
3435 if (m, n) == (1024, 1024):
3436 return True
3437 if batch >= 128 and k <= 64 and largest <= 1024:
3438 return False
3439 return False
def svd(input, some=True, compute_uv=True):
    """FlagGems native implementation of ``torch.svd``.

    Returns an ``SVDResult(U, S, V)``. Inputs are dispatched in this order:

    * small complex64 CUDA matrices (``max(m, n) <= 16``, ``some`` and
      ``compute_uv`` both True) via a real embedding;
    * float16/bfloat16 CUDA matrices by upcasting to float32;
    * empty float32 CUDA matrices;
    * singular-values-only requests;
    * ``some=False`` float32 requests with ``max(m, n) <= 64`` by completing
      the thin factors;
    * otherwise float32 CUDA with ``some=True``, through a cascade of
      specialized native kernels chosen by batch size and shape.

    Raises:
        NotImplementedError: via ``_unsupported_svd`` when no native path
            covers the input, or when a native path raises a RuntimeError.
            In that case the original error is summarized in the message and
            kept as the exception context.
        torch.cuda.OutOfMemoryError: propagated unchanged instead of being
            reported as an unsupported configuration.
    """
    logger.debug("GEMS SVD")
    if (
        input.is_cuda
        and input.dtype == torch.complex64
        and some
        and compute_uv
        and input.dim() >= 2
        and 0 not in input.shape[-2:]
        and max(input.shape[-2:]) <= 16
    ):
        return SVDResult(*_complex_svd_via_real_embedding(input))
    if _is_low_precision_cuda_matrix(input):
        return SVDResult(*_low_precision_svd_via_float32(input, some, compute_uv))
    if _is_float32_cuda_matrix(input) and 0 in input.shape[-2:]:
        return SVDResult(*_empty_svd_result(input, some, compute_uv))
    if _can_use_singular_values_only(input, some, compute_uv):
        return SVDResult(*_compute_uv_false_result(input, _singular_values_only(input)))
    if (
        _is_float32_cuda_matrix(input)
        and not some
        and compute_uv
        and max(input.shape[-2:]) <= 64
    ):
        return SVDResult(*_some_false_svd_via_thin(input))
    if not _is_float32_cuda_matrix(input) or not some:
        return SVDResult(*_unsupported_svd(input, some, compute_uv))
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    try:
        if k == 1:
            return SVDResult(*_rank1_svd(input))
        if k == 2 and max(m, n) <= _RANK2_BLOCK_R_MAX:
            return SVDResult(*_rank2_svd(input))
        if k == 4 and m == 4 and n == 4 and batch >= 16:
            return SVDResult(*_small4_square_svd(input))
        if _can_use_tall_wide_gram_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_gram_jacobi_svd(input))
        use_batched_cyclic16 = k == 16 and batch >= 8 and max(m, n) <= 64
        if (
            _can_use_small_jacobi_kernel(input, some, compute_uv)
            and not use_batched_cyclic16
        ):
            return SVDResult(*_small_jacobi_svd(input))
        if _can_use_tsqr_cholesky_kernel(input, some, compute_uv):
            return SVDResult(*_tsqr_cholesky_svd(input))
        if _can_use_projected_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_projected_jacobi_svd(input))
        if _can_use_cyclic_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_cyclic_jacobi_svd(input))
        if _can_use_hier_block_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_hier_block_jacobi_square_project_svd(input))
        if _can_use_blocked_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_square_project_svd(input))
        if _can_use_blocked_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_svd(input))
        if _should_use_gram16(batch, m, n):
            return SVDResult(*_gram16_svd(input))
        if _should_use_gram(batch, m, n):
            return SVDResult(*_gram_svd(input))
        return SVDResult(*_large_native_svd(input))
    except torch.cuda.OutOfMemoryError:
        # OutOfMemoryError subclasses RuntimeError; surface it as-is rather
        # than disguising an allocation failure as a coverage gap.
        raise
    except RuntimeError as err:
        # Other RuntimeErrors from the native paths are reported as an
        # unsupported configuration; keep the original message visible.
        return SVDResult(
            *_unsupported_svd(
                input,
                some,
                compute_uv,
                reason=f"Native kernel raised {type(err).__name__}: {err}",
            )
        )