Coverage for src/flag_gems/runtime/backend/_sunrise/ops/svd.py: 0%
1842 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2from collections import namedtuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import libentry, tensor_wrapper
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Result container with the three fields U, S and V (in that order).
SVDResult = namedtuple("SVDResult", ["U", "S", "V"])

# Limits read by _should_guard_gram_spectrum: largest batch count and largest
# min(m, n) for which the Gram-spectrum guard is considered.
_GRAM_CONDITION_GUARD_MAX_BATCH = 16
_GRAM_CONDITION_GUARD_MAX_K = 32
# NOTE(review): presumably an eigenvalue-ratio threshold for the same guard;
# its use is not visible in this part of the file -- confirm against callers.
_GRAM_CONDITION_EIGEN_RATIO = 1.0e-8

# Limits read by _can_use_tall_wide_gram_jacobi_kernel (and the K limit also by
# _gram_jacobi_svd): largest min(m, n) and largest max(m, n).
_GRAM_TALL_WIDE_MAX_K = 32
_GRAM_TALL_WIDE_MAX_ROWS = 1024

# _RANK2_BLOCK_R_MAX caps max(m, n) in _can_use_rank2_kernel.
# NOTE(review): _RANK1_BLOCK_R_MAX is presumably the rank-1 counterpart; its
# use is not visible in this part of the file.
_RANK1_BLOCK_R_MAX = 1024
_RANK2_BLOCK_R_MAX = 2048

# NOTE(review): presumably shape limits for the TSQR/Cholesky path; they are
# not referenced in the visible part of the file -- confirm before relying on them.
_TSQR_CHOLESKY_MAX_BATCH = 32
_TSQR_CHOLESKY_MAX_K = 128
_TSQR_CHOLESKY_MAX_ROWS = 1024
def _unsupported_svd(input, some=True, compute_uv=True, reason=None):
    """Raise ``NotImplementedError`` describing why ``input`` is unsupported.

    The message reports the flattened batch count, the matrix shape, the dtype,
    the device and the ``some`` / ``compute_uv`` flags. When ``reason`` is not
    ``None`` it is appended after a single space.

    Raises:
        NotImplementedError: always.
    """
    batch, m, n = _svd_shape(input)
    if reason is None:
        suffix = ""
    else:
        suffix = f" {reason}"
    message = (
        "FlagGems Sunrise native SVD currently supports float32 PTPU matrices with "
        "some=True, compute_uv=True, non-empty inputs, and native Triton "
        f"rank/Jacobi shape coverage; got batch={batch}, m={m}, n={n}, "
        f"dtype={input.dtype}, device={input.device}, some={some}, "
        f"compute_uv={compute_uv}.{suffix}"
    )
    raise NotImplementedError(message)
def _is_iluvatar_backend():
    """Return whether the runtime ``device`` reports the vendor name "iluvatar"."""
    vendor = device.vendor_name
    return vendor == "iluvatar"
def _svd_shape(input):
    """Return ``(batch, m, n)`` for a stack of matrices.

    ``m`` and ``n`` are the last two dimensions and ``batch`` is the product of
    all leading dimensions (1 when there are none). Inputs with fewer than two
    dimensions yield ``(0, 0, 0)``.
    """
    if input.dim() < 2:
        return 0, 0, 0
    *leading, m, n = input.shape
    batch = 1
    for extent in leading:
        batch *= extent
    return batch, m, n
def _should_guard_gram_spectrum(batch, k):
    """Return True when both ``batch`` and ``k`` are within the guard limits."""
    if batch > _GRAM_CONDITION_GUARD_MAX_BATCH:
        return False
    return k <= _GRAM_CONDITION_GUARD_MAX_K
def _is_float32_cuda_matrix(input):
    """Return True for a float32 tensor on a "ptpu" device with at least 2 dims.

    Despite the "cuda" in the name, the device type checked here is "ptpu".
    """
    on_ptpu = input.device.type == "ptpu"
    return on_ptpu and input.dtype == torch.float32 and input.dim() >= 2
def _is_low_precision_cuda_matrix(input):
    """Return True for a float16/bfloat16 "ptpu" tensor with at least 2 dims.

    Despite the "cuda" in the name, the device type checked here is "ptpu".
    """
    low_precision = (torch.float16, torch.bfloat16)
    on_ptpu = input.device.type == "ptpu"
    return on_ptpu and input.dtype in low_precision and input.dim() >= 2
def _can_use_rank1_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the rank-1 kernel's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    and ``min(m, n) == 1``.
    """
    _, m, n = _svd_shape(input)
    flags_ok = _is_float32_cuda_matrix(input) and some and compute_uv
    return flags_ok and min(m, n) == 1
def _can_use_rank2_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the rank-2 kernel's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``min(m, n) == 2`` and ``max(m, n) <= _RANK2_BLOCK_R_MAX``.
    """
    _, m, n = _svd_shape(input)
    shape_ok = min(m, n) == 2 and max(m, n) <= _RANK2_BLOCK_R_MAX
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_2x2_kernel(input):
    """Return a truthy value when ``input`` is a rank-2-eligible 2x2 matrix stack."""
    _, m, n = _svd_shape(input)
    is_2x2 = m == 2 and n == 2
    return _can_use_rank2_kernel(input, True, True) and is_2x2
def _can_use_4x4_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the 4x4 kernel's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    and ``m == n == 4``.
    """
    _, m, n = _svd_shape(input)
    is_4x4 = m == 4 and n == 4
    return _is_float32_cuda_matrix(input) and some and compute_uv and is_4x4
def _can_use_small_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the small Jacobi kernel's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    a non-iluvatar backend, ``min(m, n) <= 16`` and ``max(m, n) <= 1024``.
    """
    _, m, n = _svd_shape(input)
    flags_ok = _is_float32_cuda_matrix(input) and some and compute_uv
    shape_ok = min(m, n) <= 16 and max(m, n) <= 1024
    return flags_ok and not _is_iluvatar_backend() and shape_ok
def _can_use_cyclic_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the cyclic Jacobi path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``16 <= min(m, n) <= 64`` and ``max(m, n) <= 1024``.
    """
    _, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = 16 <= short_side <= 64 and max(m, n) <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the Gram/Jacobi path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``16 <= min(m, n) <= 32`` and ``max(m, n) <= 64``.
    """
    _, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = 16 <= short_side <= 32 and max(m, n) <= 64
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_tall_wide_gram_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the tall/wide Gram-Jacobi preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    at least 128 matrices in the flattened batch,
    ``16 <= min(m, n) <= _GRAM_TALL_WIDE_MAX_K``,
    ``max(m, n) <= _GRAM_TALL_WIDE_MAX_ROWS`` and ``max(m, n) >= 2 * min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    long_side = max(m, n)
    shape_ok = (
        batch >= 128
        and 16 <= short_side <= _GRAM_TALL_WIDE_MAX_K
        and long_side <= _GRAM_TALL_WIDE_MAX_ROWS
        and long_side >= 2 * short_side
    )
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_tsqr_cholesky_kernel(input, some=True, compute_uv=True):
    """Always return False; the arguments are ignored.

    The TSQR/Cholesky path stays disabled: its safety depends on the input
    values, and that needs a native device-side guard before dispatch.
    """
    return False
def _can_use_blocked_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the blocked Jacobi path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``64 < min(m, n) <= 512`` and ``max(m, n) <= 1024``.
    """
    _, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = 64 < short_side <= 512 and max(m, n) <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_blocked_square_project_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the blocked square-projection preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    exactly one matrix in the batch, ``m == n`` and ``128 <= min(m, n) <= 512``.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = batch == 1 and m == n and 128 <= short_side <= 512
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_hier_block_square_project_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the hierarchical square-projection path applies.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    at most two matrices in the batch, ``m == n`` and ``min(m, n)`` equal to
    256 or 512.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = batch <= 2 and m == n and short_side in (256, 512)
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_projected_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the projected Jacobi path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``4 <= batch <= 32``, ``min(m, n) == 64`` and ``max(m, n) <= 128``.
    """
    batch, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = 4 <= batch <= 32 and short_side == 64 and max(m, n) <= 128
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_singular_values_only(input, some=True, compute_uv=False):
    """Return a truthy value when the singular-values-only path applies.

    Requires a float32 "ptpu" matrix, a falsy ``compute_uv``,
    ``min(m, n) <= 512`` and ``max(m, n) <= 1024``. ``some`` is not consulted.
    """
    _, m, n = _svd_shape(input)
    short_side = min(m, n)
    shape_ok = short_side <= 512 and max(m, n) <= 1024
    return _is_float32_cuda_matrix(input) and not compute_uv and shape_ok
@libentry()
@triton.jit
def _small_jacobi_svd_kernel(
    A,
    A_WORK,
    V_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    """One-sided Jacobi SVD; program ``program_id(0)`` handles one M x N matrix.

    K working vectors of length ROWS are staged in A_WORK: the columns of the
    matrix when TALL, its rows otherwise. SWEEPS passes of pairwise rotations
    are applied to them while the same rotations are accumulated in V_WORK,
    which starts as the K x K identity. The final vector norms are written to
    S in descending order (ties keep the lower index first). When TALL the
    normalised vectors go to U and the accumulated rotations to V; otherwise
    the two roles are swapped.
    """
    prog = tl.program_id(0)
    r_offs = tl.arange(0, BLOCK_R)
    k_offs = tl.arange(0, BLOCK_K)
    r_valid = r_offs < ROWS
    k_valid = k_offs < K
    eps = 1.0e-20

    src = A + prog * M * N
    work = A_WORK + prog * K * ROWS
    rot = V_WORK + prog * K * K

    # Stage the working vectors and reset the accumulated rotations to identity.
    for j in tl.static_range(0, K):
        if TALL:
            vec = tl.load(src + r_offs * N + j, mask=r_valid, other=0.0).to(
                tl.float32
            )
        else:
            vec = tl.load(src + j * N + r_offs, mask=r_valid, other=0.0).to(
                tl.float32
            )
        tl.store(work + j * ROWS + r_offs, vec, mask=r_valid)
        unit = tl.where(k_offs == j, 1.0, 0.0)
        tl.store(rot + j * K + k_offs, unit, mask=k_valid)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(work + p * ROWS + r_offs, mask=r_valid, other=0.0)
                aq = tl.load(work + q * ROWS + r_offs, mask=r_valid, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                # Rotate only when the pair is not already numerically orthogonal.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = tl.abs(gamma) > threshold

                # A dummy divisor keeps the inactive case free of division by zero.
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                tan_r = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                cos_r = tl.rsqrt(1.0 + tan_r * tan_r)
                sin_r = tan_r * cos_r
                cos_r = tl.where(active, cos_r, 1.0)
                sin_r = tl.where(active, sin_r, 0.0)

                rotated_p = cos_r * ap - sin_r * aq
                rotated_q = sin_r * ap + cos_r * aq
                tl.store(work + p * ROWS + r_offs, rotated_p, mask=r_valid)
                tl.store(work + q * ROWS + r_offs, rotated_q, mask=r_valid)

                vp = tl.load(rot + p * K + k_offs, mask=k_valid, other=0.0)
                vq = tl.load(rot + q * K + k_offs, mask=k_valid, other=0.0)
                rotated_vp = cos_r * vp - sin_r * vq
                rotated_vq = sin_r * vp + cos_r * vq
                tl.store(rot + p * K + k_offs, rotated_vp, mask=k_valid)
                tl.store(rot + q * K + k_offs, rotated_vq, mask=k_valid)

    # Norms of the rotated vectors, gathered into one BLOCK_K-wide vector.
    norms = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        vec = tl.load(work + j * ROWS + r_offs, mask=r_valid, other=0.0)
        length = tl.sqrt(tl.sum(vec * vec))
        norms = tl.where(k_offs == j, length, norms)

    # order[j] counts the entries that sort before entry j (larger value, or an
    # equal value at a lower index), i.e. j's position in descending order.
    order = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        norm_i = tl.sum(tl.where(k_offs == i, norms, 0.0))
        beats = ((norm_i > norms) | ((norm_i == norms) & (i < k_offs))) & (k_offs < K)
        order = order + beats.to(tl.int32)

    for j in tl.static_range(0, K):
        vec = tl.load(work + j * ROWS + r_offs, mask=r_valid, other=0.0)
        length = tl.sum(tl.where(k_offs == j, norms, 0.0))
        slot = tl.sum(tl.where(k_offs == j, order, 0))
        # Vectors whose norm is not above eps are written out as all zeros.
        inv_len = tl.where(length > eps, 1.0 / length, 0.0)
        tl.store(S + prog * K + slot, length)

        basis = tl.load(rot + j * K + k_offs, mask=k_valid, other=0.0)
        if TALL:
            tl.store(U + prog * M * K + r_offs * K + slot, vec * inv_len, mask=r_valid)
            tl.store(V + prog * N * K + k_offs * K + slot, basis, mask=k_valid)
        else:
            tl.store(U + prog * M * K + k_offs * K + slot, basis, mask=k_valid)
            tl.store(V + prog * N * K + r_offs * K + slot, vec * inv_len, mask=r_valid)
@libentry()
@triton.jit
def _small_jacobi_svals_kernel(
    A,
    A_WORK,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    """One-sided Jacobi sweep that produces singular values only.

    Program ``program_id(0)`` handles one M x N matrix. K working vectors of
    length ROWS (columns when TALL, rows otherwise) are staged in A_WORK and
    rotated pairwise for SWEEPS passes; their final norms are written to S in
    descending order (ties keep the lower index first). No rotation matrix is
    accumulated.
    """
    prog = tl.program_id(0)
    r_offs = tl.arange(0, BLOCK_R)
    k_offs = tl.arange(0, BLOCK_K)
    r_valid = r_offs < ROWS
    eps = 1.0e-20

    src = A + prog * M * N
    work = A_WORK + prog * K * ROWS

    # Stage the working vectors as float32.
    for j in tl.static_range(0, K):
        if TALL:
            vec = tl.load(src + r_offs * N + j, mask=r_valid, other=0.0).to(
                tl.float32
            )
        else:
            vec = tl.load(src + j * N + r_offs, mask=r_valid, other=0.0).to(
                tl.float32
            )
        tl.store(work + j * ROWS + r_offs, vec, mask=r_valid)

    for _ in tl.static_range(0, SWEEPS):
        for p in tl.static_range(0, K):
            for q in tl.static_range(p + 1, K):
                ap = tl.load(work + p * ROWS + r_offs, mask=r_valid, other=0.0)
                aq = tl.load(work + q * ROWS + r_offs, mask=r_valid, other=0.0)
                alpha = tl.sum(ap * ap)
                beta = tl.sum(aq * aq)
                gamma = tl.sum(ap * aq)
                # Rotate only when the pair is not already numerically orthogonal.
                threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                active = tl.abs(gamma) > threshold

                # A dummy divisor keeps the inactive case free of division by zero.
                safe_gamma = tl.where(active, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * safe_gamma)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                tan_r = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                cos_r = tl.rsqrt(1.0 + tan_r * tan_r)
                sin_r = tan_r * cos_r
                cos_r = tl.where(active, cos_r, 1.0)
                sin_r = tl.where(active, sin_r, 0.0)

                rotated_p = cos_r * ap - sin_r * aq
                rotated_q = sin_r * ap + cos_r * aq
                tl.store(work + p * ROWS + r_offs, rotated_p, mask=r_valid)
                tl.store(work + q * ROWS + r_offs, rotated_q, mask=r_valid)

    # Norms of the rotated vectors, gathered into one BLOCK_K-wide vector.
    norms = tl.full((BLOCK_K,), 0.0, dtype=tl.float32)
    for j in tl.static_range(0, K):
        vec = tl.load(work + j * ROWS + r_offs, mask=r_valid, other=0.0)
        length = tl.sqrt(tl.sum(vec * vec))
        norms = tl.where(k_offs == j, length, norms)

    # order[j] counts the entries that sort before entry j in descending order.
    order = tl.zeros((BLOCK_K,), dtype=tl.int32)
    for i in tl.static_range(0, K):
        norm_i = tl.sum(tl.where(k_offs == i, norms, 0.0))
        beats = ((norm_i > norms) | ((norm_i == norms) & (i < k_offs))) & (k_offs < K)
        order = order + beats.to(tl.int32)

    for j in tl.static_range(0, K):
        length = tl.sum(tl.where(k_offs == j, norms, 0.0))
        slot = tl.sum(tl.where(k_offs == j, order, 0))
        tl.store(S + prog * K + slot, length)
def _can_use_streaming_jacobi_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the streaming Jacobi path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    ``16 < min(m, n) <= 64`` and ``max(m, n) <= 1024``.
    """
    _, m, n = _svd_shape(input)
    shape_ok = 16 < min(m, n) <= 64 and max(m, n) <= 1024
    return _is_float32_cuda_matrix(input) and some and compute_uv and shape_ok
def _can_use_gram_kernel(input, some=True, compute_uv=True):
    """Return a truthy value when the generic Gram path's preconditions hold.

    Requires a float32 "ptpu" matrix, ``some`` and ``compute_uv`` both truthy,
    and ``min(m, n) <= 1024``.
    """
    _, m, n = _svd_shape(input)
    flags_ok = _is_float32_cuda_matrix(input) and some and compute_uv
    return flags_ok and min(m, n) <= 1024
@libentry()
@triton.jit
def _triton_bmm_kernel(
    A,
    B,
    C,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Batched matmul: one BLOCK_M x BLOCK_N output tile of ``C = A @ B`` per program.

    Grid axis 0 enumerates output tiles in row-major tile order and axis 1 is
    the batch index. A and B are addressed through the given strides; C is
    written as a dense (batch, M, N) array. Accumulation is float32 with TF32
    disabled.
    """
    tile_id = tl.program_id(0)
    batch_id = tl.program_id(1)
    # Split the flat tile id into (tile row, tile column).
    n_tiles = tl.cdiv(N, BLOCK_N)
    tile_row = tile_id // n_tiles
    tile_col = tile_id - tile_row * n_tiles

    m_offs = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offs = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)
    k_lane = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    lhs_ptr = A + batch_id * stride_ab
    rhs_ptr = B + batch_id * stride_bb
    for k_start in range(0, K, BLOCK_K):
        k_offs = k_start + k_lane
        lhs = tl.load(
            lhs_ptr + m_offs[:, None] * stride_am + k_offs[None, :] * stride_ak,
            mask=(m_offs[:, None] < M) & (k_offs[None, :] < K),
            other=0.0,
        )
        rhs = tl.load(
            rhs_ptr + k_offs[:, None] * stride_bk + n_offs[None, :] * stride_bn,
            mask=(k_offs[:, None] < K) & (n_offs[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        C + batch_id * M * N + m_offs[:, None] * N + n_offs[None, :],
        acc,
        mask=(m_offs[:, None] < M) & (n_offs[None, :] < N),
    )
def _triton_bmm(left, right, out_shape):
    """Multiply two 3-D batches with ``_triton_bmm_kernel``.

    Args:
        left: tensor of shape (batch, m, k); any strides.
        right: tensor of shape (batch, k, n); any strides.
        out_shape: shape the (batch, m, n) product is reshaped to on return.

    Returns:
        The product, allocated with ``left``'s dtype and device and reshaped
        to ``out_shape``.

    Raises:
        AssertionError: if the batch or inner dimensions disagree.
    """
    batch, m, k = left.shape
    right_batch, right_k, n = right.shape
    assert batch == right_batch, "Batch dim mismatch"
    assert k == right_k, "K dim mismatch"

    result = torch.empty((batch, m, n), dtype=left.dtype, device=left.device)
    # 16-wide tiles for small extents, 32-wide otherwise.
    tile_m = 32 if m > 16 else 16
    tile_n = 32 if n > 16 else 16
    tile_k = 32
    warps = 1 if (tile_m, tile_n) == (16, 16) else 4
    launch_grid = (triton.cdiv(m, tile_m) * triton.cdiv(n, tile_n), batch)
    with torch_device_fn.device(left.device):
        _triton_bmm_kernel[launch_grid](
            left,
            right,
            result,
            *left.stride(),
            *right.stride(),
            M=m,
            N=n,
            K=k,
            BLOCK_M=tile_m,
            BLOCK_N=tile_n,
            BLOCK_K=tile_k,
            num_warps=warps,
        )
    return result.reshape(out_shape)
@libentry()
@triton.jit
def _gram_build_tiled_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_I: tl.constexpr,
    BLOCK_J: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Compute one BLOCK_I x BLOCK_J tile of a K x K Gram matrix.

    Grid axes 0 and 1 pick the tile row and column, axis 2 the batch index.
    When TALL the tile accumulates products of columns of the M x N matrix
    (A^T A); otherwise products of its rows (A A^T). The reduction runs over
    ROWS elements in chunks of BLOCK_R, in float32 with TF32 disabled.
    """
    i_offs = tl.program_id(0) * BLOCK_I + tl.arange(0, BLOCK_I)
    j_offs = tl.program_id(1) * BLOCK_J + tl.arange(0, BLOCK_J)
    batch_id = tl.program_id(2)
    r_lane = tl.arange(0, BLOCK_R)
    i_valid = i_offs < K
    j_valid = j_offs < K
    src = A + batch_id * M * N
    acc = tl.zeros((BLOCK_I, BLOCK_J), dtype=tl.float32)

    for r_start in range(0, ROWS, BLOCK_R):
        r_offs = r_start + r_lane
        r_valid = r_offs < ROWS
        if TALL:
            # lhs[i, r] = A[r, i] and rhs[r, j] = A[r, j].
            lhs = tl.load(
                src + r_offs[None, :] * N + i_offs[:, None],
                mask=i_valid[:, None] & r_valid[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                src + r_offs[:, None] * N + j_offs[None, :],
                mask=r_valid[:, None] & j_valid[None, :],
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r] and rhs[r, j] = A[j, r].
            lhs = tl.load(
                src + i_offs[:, None] * N + r_offs[None, :],
                mask=i_valid[:, None] & r_valid[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                src + j_offs[None, :] * N + r_offs[:, None],
                mask=r_valid[:, None] & j_valid[None, :],
                other=0.0,
            ).to(tl.float32)
        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch_id * K * K + i_offs[:, None] * K + j_offs[None, :],
        acc,
        mask=i_valid[:, None] & j_valid[None, :],
    )
@libentry()
@triton.jit
def _cholesky_upper_kernel(
    GRAM,
    R,
    STATUS,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Row-by-row upper-triangular Cholesky factor of one K x K matrix per program.

    Row j of R is built from row j of GRAM minus the contributions of rows
    0..j-1 of R, divided by the pivot. The pivot is
    ``sqrt(max(diag, tol))`` with ``tol = max(1e-8 * max|diag(GRAM)|, 1e-20)``,
    so a square root is never taken of a value below ``tol``. Entries left of
    the diagonal are written as zero. STATUS[batch] is 0 when every pivot was
    finite and above ``tol`` and every produced entry was finite; otherwise 1.
    """
    batch_id = tl.program_id(0)
    k_offs = tl.arange(0, BLOCK_K)
    k_valid = k_offs < K
    gram_ptr = GRAM + batch_id * K * K
    r_ptr = R + batch_id * K * K

    diagonal = tl.load(
        gram_ptr + k_offs * K + k_offs,
        mask=k_valid,
        other=0.0,
    ).to(tl.float32)
    largest_diag = tl.max(tl.abs(diagonal), axis=0)
    tol = tl.maximum(largest_diag * 1.0e-8, 1.0e-20)
    status = tl.full((), 0, dtype=tl.int32)
    # Largest finite float32; magnitudes at or above it are treated as overflow.
    finite_limit = 3.4028234663852886e38

    j = 0
    while j < K:
        # Only the diagonal and the entries to its right belong to row j.
        upper = k_valid & (k_offs >= j)
        residual = tl.load(
            gram_ptr + j * K + k_offs,
            mask=upper,
            other=0.0,
        ).to(tl.float32)
        diag = tl.load(gram_ptr + j * K + j).to(tl.float32)

        # Subtract the outer-product contribution of every finished row.
        p = 0
        while p < j:
            r_pj = tl.load(r_ptr + p * K + j).to(tl.float32)
            r_prow = tl.load(
                r_ptr + p * K + k_offs,
                mask=upper,
                other=0.0,
            ).to(tl.float32)
            residual -= r_pj * r_prow
            diag -= r_pj * r_pj
            p += 1

        # `diag == diag` is False only for NaN.
        diag_ok = (diag == diag) & (tl.abs(diag) < finite_limit) & (diag > tol)
        pivot = tl.sqrt(tl.maximum(diag, tol))
        row_vals = residual / pivot
        row_vals = tl.where(k_offs == j, pivot, row_vals)
        row_vals = tl.where(upper, row_vals, 0.0)
        n_bad = tl.sum(
            (((row_vals != row_vals) | (tl.abs(row_vals) >= finite_limit)) & upper).to(
                tl.int32
            ),
            axis=0,
        )
        status = tl.where(diag_ok & (n_bad == 0), status, 1)
        tl.store(r_ptr + j * K + k_offs, row_vals, mask=k_valid)
        j += 1

    tl.store(STATUS + batch_id, status)
def _tsqr_guard_fallback_svd(input):
    """Dispatch ``input`` to a Jacobi SVD chosen by its shape.

    With ``k = min(m, n)`` and ``max(m, n) <= 1024``: ``16 <= k <= 64`` uses
    ``_cyclic_jacobi_svd`` and ``64 < k <= 512`` uses ``_blocked_jacobi_svd``.
    Any other shape ends in ``_unsupported_svd``, which raises
    ``NotImplementedError``.
    """
    _, m, n = _svd_shape(input)
    k = min(m, n)
    rows_fit = max(m, n) <= 1024
    if rows_fit and 16 <= k <= 64:
        return _cyclic_jacobi_svd(input)
    if rows_fit and 64 < k <= 512:
        return _blocked_jacobi_svd(input)
    return _unsupported_svd(
        input,
        True,
        True,
        "TSQR/Cholesky guard could not find a native Jacobi fallback.",
    )
def _tsqr_cholesky_svd(input):
    """SVD through a Gram matrix, its Cholesky factor, and an SVD of that factor.

    Steps: build the k x k Gram matrix (k = min(m, n)), take its
    upper-triangular Cholesky factor R, call ``svd`` on R to obtain the
    singular values and a k x k basis, project the input onto that basis with
    ``_triton_bmm``, and divide the projected columns by the singular values.

    The Cholesky kernel also fills a per-matrix ``status`` buffer; this
    function does not read it.

    Returns:
        A ``(u, s, v)`` tuple reshaped to the input's leading dimensions, with
        trailing shapes (m, k), (k,) and (n, k).
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    tall = m >= n
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    r = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    status = torch.empty((batch,), dtype=torch.int32, device=dev)
    block_k = triton.next_power_of_2(k)
    block_tile = 32
    block_r = 64
    tiles = triton.cdiv(k, block_tile)

    with torch_device_fn.device(dev):
        _gram_build_tiled_kernel[(tiles, tiles, batch)](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_I=block_tile,
            BLOCK_J=block_tile,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _cholesky_upper_kernel[(batch,)](
            gram,
            r,
            status,
            K=k,
            BLOCK_K=block_k,
            num_warps=4,
        )

    _, s, basis = svd(r, some=True, compute_uv=True)
    basis = basis.reshape(batch, k, k)
    s = s.reshape(batch, k)

    # The basis is the factor on the short side; the long-side factor is the
    # input (or its transpose) projected onto it.
    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        projected, projected_rows = u, m
    else:
        u = basis
        v = _triton_bmm(a.transpose(1, 2).contiguous(), basis, (batch, n, k))
        projected, projected_rows = v, n

    # Scale the projected columns in place by 1 / singular value.
    with torch_device_fn.device(dev):
        _normalize_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=projected_rows,
            K=k,
            BLOCK_R=triton.next_power_of_2(projected_rows),
            num_warps=1 if projected_rows <= 64 else 4,
        )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )
@libentry()
@triton.jit
def _gram_build_kernel(
    A,
    GRAM,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Compute the whole K x K Gram matrix of one M x N matrix per program.

    When TALL the result accumulates products of columns (A^T A); otherwise
    products of rows (A A^T). The reduction runs over ROWS elements in chunks
    of BLOCK_R, in float32 with TF32 disabled.
    """
    batch_id = tl.program_id(0)
    i_offs = tl.arange(0, BLOCK_K)
    j_offs = tl.arange(0, BLOCK_K)
    r_lane = tl.arange(0, BLOCK_R)
    i_valid = i_offs < K
    src = A + batch_id * M * N
    acc = tl.zeros((BLOCK_K, BLOCK_K), dtype=tl.float32)

    for r_start in range(0, ROWS, BLOCK_R):
        r_offs = r_start + r_lane
        r_valid = r_offs < ROWS

        if TALL:
            # lhs[i, r] = A[r, i] and rhs[r, j] = A[r, j].
            lhs = tl.load(
                src + r_offs[None, :] * N + i_offs[:, None],
                mask=i_valid[:, None] & r_valid[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                src + r_offs[:, None] * N + j_offs[None, :],
                mask=r_valid[:, None] & (j_offs[None, :] < K),
                other=0.0,
            ).to(tl.float32)
        else:
            # lhs[i, r] = A[i, r] and rhs[r, j] = A[j, r].
            lhs = tl.load(
                src + i_offs[:, None] * N + r_offs[None, :],
                mask=i_valid[:, None] & r_valid[None, :],
                other=0.0,
            ).to(tl.float32)
            rhs = tl.load(
                src + j_offs[None, :] * N + r_offs[:, None],
                mask=r_valid[:, None] & (j_offs[None, :] < K),
                other=0.0,
            ).to(tl.float32)

        acc += tl.dot(lhs, rhs, out_dtype=tl.float32, allow_tf32=False)

    tl.store(
        GRAM + batch_id * K * K + i_offs[:, None] * K + j_offs[None, :],
        acc,
        mask=i_valid[:, None] & (j_offs[None, :] < K),
    )
@libentry()
@triton.jit
def _gram_jacobi_sym_kernel(
    GRAM,
    EVECS,
    EVALS,
    K,
    SWEEPS,
    BLOCK_K: tl.constexpr,
):
    """Two-sided Jacobi eigen-iteration on one K x K matrix per program.

    The matrix is loaded once into a BLOCK_K x BLOCK_K tile and kept there.
    Each of the SWEEPS passes visits every pair (p, q) with p < q, rotates
    rows and columns p and q so that entry (p, q) becomes zero, and applies
    the same rotation to the columns of an accumulator that starts as the
    identity. The final diagonal is written to EVALS (unsorted) and the
    accumulator to EVECS. K and SWEEPS are runtime values here.
    """
    batch_id = tl.program_id(0)
    r_idx = tl.arange(0, BLOCK_K)
    c_idx = tl.arange(0, BLOCK_K)
    rr = r_idx[:, None]
    cc = c_idx[None, :]
    in_bounds = (rr < K) & (cc < K)
    gram_ptr = GRAM + batch_id * K * K

    g = tl.load(gram_ptr + rr * K + cc, mask=in_bounds, other=0.0).to(tl.float32)
    v = tl.where((rr == cc) & in_bounds, 1.0, 0.0)
    eps = 1.0e-20

    sweep = 0
    while sweep < SWEEPS:
        p = 0
        while p < K - 1:
            q = p + 1
            while q < K:
                # Single entries are extracted by masking and summing the tile.
                g_pp = tl.sum(tl.where((rr == p) & (cc == p), g, 0.0))
                g_qq = tl.sum(tl.where((rr == q) & (cc == q), g, 0.0))
                g_pq = tl.sum(tl.where((rr == p) & (cc == q), g, 0.0))
                threshold = 1.0e-7 * tl.sqrt(tl.abs(g_pp * g_qq) + eps)
                active = tl.abs(g_pq) > threshold
                # A dummy divisor keeps the inactive case free of division by zero.
                safe_pq = tl.where(active, g_pq, 1.0)
                tau = (g_qq - g_pp) / (2.0 * safe_pq)
                sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
                tan_r = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                cos_r = tl.rsqrt(1.0 + tan_r * tan_r)
                sin_r = tan_r * cos_r
                cos_r = tl.where(active, cos_r, 1.0)
                sin_r = tl.where(active, sin_r, 0.0)

                col_p = tl.sum(tl.where(cc == p, g, 0.0), axis=1)
                col_q = tl.sum(tl.where(cc == q, g, 0.0), axis=1)
                row_p = tl.sum(tl.where(rr == p, g, 0.0), axis=0)
                row_q = tl.sum(tl.where(rr == q, g, 0.0), axis=0)

                rot_col_p = cos_r * col_p - sin_r * col_q
                rot_col_q = sin_r * col_p + cos_r * col_q
                rot_row_p = cos_r * row_p - sin_r * row_q
                rot_row_q = sin_r * row_p + cos_r * row_q
                g = tl.where(cc == p, rot_col_p[:, None], g)
                g = tl.where(cc == q, rot_col_q[:, None], g)
                g = tl.where(rr == p, rot_row_p[None, :], g)
                g = tl.where(rr == q, rot_row_q[None, :], g)

                # The 2x2 pivot block is overwritten with its closed-form values.
                pivot_pp = (
                    cos_r * cos_r * g_pp
                    - 2.0 * cos_r * sin_r * g_pq
                    + sin_r * sin_r * g_qq
                )
                pivot_qq = (
                    sin_r * sin_r * g_pp
                    + 2.0 * cos_r * sin_r * g_pq
                    + cos_r * cos_r * g_qq
                )
                g = tl.where((rr == p) & (cc == p), pivot_pp, g)
                g = tl.where((rr == q) & (cc == q), pivot_qq, g)
                g = tl.where(((rr == p) & (cc == q)) | ((rr == q) & (cc == p)), 0.0, g)

                vec_p = tl.sum(tl.where(cc == p, v, 0.0), axis=1)
                vec_q = tl.sum(tl.where(cc == q, v, 0.0), axis=1)
                rot_vec_p = cos_r * vec_p - sin_r * vec_q
                rot_vec_q = sin_r * vec_p + cos_r * vec_q
                v = tl.where(cc == p, rot_vec_p[:, None], v)
                v = tl.where(cc == q, rot_vec_q[:, None], v)
                q += 1
            p += 1
        sweep += 1

    diagonal = tl.sum(tl.where(rr == cc, g, 0.0), axis=1)
    tl.store(EVALS + batch_id * K + r_idx, diagonal, mask=r_idx < K)
    tl.store(EVECS + batch_id * K * K + rr * K + cc, v, mask=in_bounds)
@libentry()
@triton.jit
def _gram_sort_basis_kernel(
    EVALS,
    EVECS,
    BASIS,
    S,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Place one eigenpair at its descending-order position.

    Grid axis 0 is the batch index and axis 1 the source column. Eigenvalues
    are clamped at zero before comparison. The column's rank is the number of
    eigenvalues that are larger, or equal with a lower index. Its square root
    is stored to S at that rank and the eigenvector column is copied to the
    same column of BASIS.
    """
    batch_id = tl.program_id(0)
    src_col = tl.program_id(1)
    r_offs = tl.arange(0, BLOCK_K)
    r_valid = r_offs < K
    own_eval = tl.maximum(tl.load(EVALS + batch_id * K + src_col), 0.0)
    rank = tl.full((), 0, dtype=tl.int32)
    for idx in tl.static_range(0, K):
        peer_eval = tl.maximum(tl.load(EVALS + batch_id * K + idx), 0.0)
        rank += (
            (peer_eval > own_eval) | ((peer_eval == own_eval) & (idx < src_col))
        ).to(tl.int32)

    eigvec = tl.load(
        EVECS + batch_id * K * K + r_offs * K + src_col,
        mask=r_valid,
        other=0.0,
    )
    tl.store(S + batch_id * K + rank, tl.sqrt(own_eval))
    tl.store(
        BASIS + batch_id * K * K + r_offs * K + rank,
        eigvec,
        mask=r_valid,
    )
@libentry()
@triton.jit
def _normalize_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Divide one column of Q in place by its singular value.

    Grid axis 0 is the batch index and axis 1 the column. Q is addressed as a
    dense (batch, ROWS, K) array. The divisor is ``max(S[batch, col], 1e-20)``.
    """
    batch_id = tl.program_id(0)
    col_id = tl.program_id(1)
    r_offs = tl.arange(0, BLOCK_R)
    r_valid = r_offs < ROWS
    eps = 1.0e-20
    sigma = tl.load(S + batch_id * K + col_id)
    col_ptr = Q + batch_id * ROWS * K + r_offs * K + col_id
    column = tl.load(col_ptr, mask=r_valid, other=0.0)
    column = column / tl.maximum(sigma, eps)
    tl.store(col_ptr, column, mask=r_valid)
@libentry()
@triton.jit
def _renorm_projection_update_s_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Normalise one column of Q in place and store its norm as the singular value.

    Grid axis 0 is the batch index and axis 1 the column. The column's
    Euclidean norm (computed in float32) overwrites S[batch, col]. The column
    is scaled to unit length, except when the norm is at most 1e-20, in which
    case it is replaced by the unit vector with a one at row ``col``.
    """
    batch_id = tl.program_id(0)
    col_id = tl.program_id(1)
    r_offs = tl.arange(0, BLOCK_R)
    r_valid = r_offs < ROWS
    col_ptr = Q + batch_id * ROWS * K + r_offs * K + col_id
    column = tl.load(col_ptr, mask=r_valid, other=0.0)
    column_f32 = column.to(tl.float32)
    norm = tl.sqrt(tl.sum(column_f32 * column_f32, axis=0))
    # The clamp keeps rsqrt finite for an all-zero column.
    inv_norm = tl.rsqrt(tl.maximum(norm * norm, 1.0e-40))
    unit = tl.where(r_offs == col_id, 1.0, 0.0)
    column = tl.where(norm <= 1.0e-20, unit, column * inv_norm)
    tl.store(S + batch_id * K + col_id, norm)
    tl.store(col_ptr, column, mask=r_valid)
@libentry()
@triton.jit
def _complete_zero_projection_kernel(
    Q,
    S,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    """Replace columns of Q whose singular value is at most 1e-12.

    Grid axis 0 is the batch index and axis 1 the column. Such a column is
    overwritten with the unit vector that has a one at row ``col``; every
    other column is written back unchanged.
    """
    batch_id = tl.program_id(0)
    col_id = tl.program_id(1)
    r_offs = tl.arange(0, BLOCK_R)
    r_valid = r_offs < ROWS
    eps = 1.0e-12
    sigma = tl.load(S + batch_id * K + col_id)
    unit = tl.where(r_offs == col_id, 1.0, 0.0)
    col_ptr = Q + batch_id * ROWS * K + r_offs * K + col_id
    current = tl.load(col_ptr, mask=r_valid, other=0.0)
    replaced = tl.where(sigma <= eps, unit, current)
    tl.store(col_ptr, replaced, mask=r_valid)
def _gram_jacobi_svd(input):
    """SVD through a Jacobi eigendecomposition of the k x k Gram matrix.

    Steps, with k = min(m, n):
      1. build the Gram matrix and run the symmetric Jacobi kernel on it;
      2. sort the eigenpairs in descending order into ``basis`` and ``s``;
      3. project the input (or its transpose when wide) onto ``basis``;
      4. renormalise the projected columns, which overwrites ``s`` with their
         norms, and replace columns whose value is (near) zero by unit vectors;
      5. when ``k <= _GRAM_TALL_WIDE_MAX_K``, run the thin re-orthogonalisation
         kernel on the projected factor.

    Returns:
        A ``(u, s, v)`` tuple reshaped to the input's leading dimensions, with
        trailing shapes (m, k), (k,) and (n, k).
    """
    batch, m, n = _svd_shape(input)
    k, rows = min(m, n), max(m, n)
    tall = m >= n
    dev = input.device
    a = input.contiguous().reshape(batch, m, n)
    gram = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    eigvecs = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    evals = torch.empty((batch, k), dtype=torch.float32, device=dev)
    basis = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)
    block_k = triton.next_power_of_2(k)
    block_r = min(triton.next_power_of_2(rows), 64 if k > 32 else 128)
    sweeps = 12 if k <= 17 else 10

    with torch_device_fn.device(dev):
        _gram_build_kernel[(batch,)](
            a,
            gram,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=tall,
            BLOCK_K=block_k,
            BLOCK_R=block_r,
            num_warps=4,
        )
        _gram_jacobi_sym_kernel[(batch,)](
            gram,
            eigvecs,
            evals,
            k,
            sweeps,
            BLOCK_K=block_k,
            num_warps=4,
        )
    with torch_device_fn.device(dev):
        _gram_sort_basis_kernel[(batch, k)](
            evals,
            eigvecs,
            basis,
            s,
            K=k,
            BLOCK_K=block_k,
            num_warps=1,
        )

    # The basis is the factor on the short side; the long-side factor is the
    # input (or its transpose) projected onto it.
    if tall:
        u = _triton_bmm(a, basis, (batch, m, k))
        v = basis
        projected, proj_rows = u, m
    else:
        a_t = a.transpose(1, 2).contiguous()
        v = _triton_bmm(a_t, basis, (batch, n, k))
        u = basis
        projected, proj_rows = v, n

    proj_block = triton.next_power_of_2(proj_rows)
    proj_warps = 1 if proj_rows <= 64 else 4
    with torch_device_fn.device(dev):
        _renorm_projection_update_s_kernel[(batch, k)](
            projected,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=proj_block,
            num_warps=proj_warps,
        )
        _complete_zero_projection_kernel[(batch, k)](
            projected,
            s,
            ROWS=proj_rows,
            K=k,
            BLOCK_R=proj_block,
            num_warps=proj_warps,
        )
        if k <= _GRAM_TALL_WIDE_MAX_K:
            _thin_reorthogonalize_kernel[(batch,)](
                projected,
                ROWS=proj_rows,
                K=k,
                BLOCK_R=proj_block,
                num_warps=proj_warps,
            )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )
@triton.jit
def _rotate_pair_4(ap, aq, vp, vq):
    """Apply one Jacobi rotation per row to the pairs (ap, aq) and (vp, vq).

    All reductions run over axis 1, so each row of the 2-D inputs gets its own
    rotation. The angle is chosen from the dot products of ``ap`` and ``aq``;
    rows whose pair is already numerically orthogonal are left unchanged.
    Returns the rotated ``(ap, aq, vp, vq)``.
    """
    eps = 1.0e-20
    alpha = tl.sum(ap * ap, axis=1)
    beta = tl.sum(aq * aq, axis=1)
    gamma = tl.sum(ap * aq, axis=1)
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = tl.abs(gamma) > threshold
    # A dummy divisor keeps the inactive case free of division by zero.
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    tan_r = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_r = tl.rsqrt(1.0 + tan_r * tan_r)
    sin_r = tan_r * cos_r
    cos_r = tl.where(active, cos_r, 1.0)
    sin_r = tl.where(active, sin_r, 0.0)
    rotated_ap = cos_r[:, None] * ap - sin_r[:, None] * aq
    rotated_aq = sin_r[:, None] * ap + cos_r[:, None] * aq
    rotated_vp = cos_r[:, None] * vp - sin_r[:, None] * vq
    rotated_vq = sin_r[:, None] * vp + cos_r[:, None] * vq
    return rotated_ap, rotated_aq, rotated_vp, rotated_vq
@libentry()
@triton.jit
def _small4_square_svd_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    BLOCK_B: tl.constexpr,
    SWEEPS: tl.constexpr,
):
    """One-sided Jacobi SVD of BLOCK_B dense 4 x 4 matrices per program.

    The four columns of every matrix are held as (BLOCK_B, 4) tiles and rotated
    pairwise for SWEEPS passes, with the rotations accumulated into four
    columns that start as the identity. Column norms become the singular
    values, stored in descending order (ties keep the lower column first). The
    normalised columns are written to U and the accumulated rotations to V, at
    the same sorted positions.
    """
    prog = tl.program_id(0)
    mat_ids = prog * BLOCK_B + tl.arange(0, BLOCK_B)
    lane4 = tl.arange(0, 4)
    mat_col = mat_ids[:, None]
    row_idx = lane4[None, :]
    mat_valid = mat_ids < BATCH
    elem_valid = (mat_col < BATCH) & (row_idx < 4)
    # Element (row, 0) of every matrix; adding j selects column j.
    col0_ptr = A + mat_col * 16 + row_idx * 4

    c0 = tl.load(col0_ptr, mask=elem_valid, other=0.0).to(tl.float32)
    c1 = tl.load(col0_ptr + 1, mask=elem_valid, other=0.0).to(tl.float32)
    c2 = tl.load(col0_ptr + 2, mask=elem_valid, other=0.0).to(tl.float32)
    c3 = tl.load(col0_ptr + 3, mask=elem_valid, other=0.0).to(tl.float32)

    v0 = tl.where(row_idx == 0, 1.0, 0.0)
    v1 = tl.where(row_idx == 1, 1.0, 0.0)
    v2 = tl.where(row_idx == 2, 1.0, 0.0)
    v3 = tl.where(row_idx == 3, 1.0, 0.0)

    # Each sweep visits the six column pairs in a fixed cyclic order.
    for _ in tl.static_range(0, SWEEPS):
        c0, c1, v0, v1 = _rotate_pair_4(c0, c1, v0, v1)
        c0, c2, v0, v2 = _rotate_pair_4(c0, c2, v0, v2)
        c0, c3, v0, v3 = _rotate_pair_4(c0, c3, v0, v3)
        c1, c2, v1, v2 = _rotate_pair_4(c1, c2, v1, v2)
        c1, c3, v1, v3 = _rotate_pair_4(c1, c3, v1, v3)
        c2, c3, v2, v3 = _rotate_pair_4(c2, c3, v2, v3)

    s0 = tl.sqrt(tl.sum(c0 * c0, axis=1))
    s1 = tl.sqrt(tl.sum(c1 * c1, axis=1))
    s2 = tl.sqrt(tl.sum(c2 * c2, axis=1))
    s3 = tl.sqrt(tl.sum(c3 * c3, axis=1))
    # Output position of each column: earlier columns win ties (>=), later
    # columns must be strictly larger (>), which keeps the positions distinct.
    pos0 = (s1 > s0).to(tl.int32) + (s2 > s0).to(tl.int32) + (s3 > s0).to(tl.int32)
    pos1 = ((s0 >= s1).to(tl.int32)) + (s2 > s1).to(tl.int32) + (s3 > s1).to(tl.int32)
    pos2 = ((s0 >= s2).to(tl.int32)) + ((s1 >= s2).to(tl.int32)) + (s3 > s2).to(tl.int32)
    pos3 = (
        ((s0 >= s3).to(tl.int32))
        + ((s1 >= s3).to(tl.int32))
        + ((s2 >= s3).to(tl.int32))
    )
    eps = 1.0e-20

    s_ptr = S + mat_ids * 4
    tl.store(s_ptr + pos0, s0, mask=mat_valid)
    tl.store(s_ptr + pos1, s1, mask=mat_valid)
    tl.store(s_ptr + pos2, s2, mask=mat_valid)
    tl.store(s_ptr + pos3, s3, mask=mat_valid)

    tl.store(
        U + mat_col * 16 + row_idx * 4 + pos0[:, None],
        c0 / tl.maximum(s0[:, None], eps),
        mask=elem_valid,
    )
    tl.store(
        U + mat_col * 16 + row_idx * 4 + pos1[:, None],
        c1 / tl.maximum(s1[:, None], eps),
        mask=elem_valid,
    )
    tl.store(
        U + mat_col * 16 + row_idx * 4 + pos2[:, None],
        c2 / tl.maximum(s2[:, None], eps),
        mask=elem_valid,
    )
    tl.store(
        U + mat_col * 16 + row_idx * 4 + pos3[:, None],
        c3 / tl.maximum(s3[:, None], eps),
        mask=elem_valid,
    )

    tl.store(V + mat_col * 16 + row_idx * 4 + pos0[:, None], v0, mask=elem_valid)
    tl.store(V + mat_col * 16 + row_idx * 4 + pos1[:, None], v1, mask=elem_valid)
    tl.store(V + mat_col * 16 + row_idx * 4 + pos2[:, None], v2, mask=elem_valid)
    tl.store(V + mat_col * 16 + row_idx * 4 + pos3[:, None], v3, mask=elem_valid)
# Closed-form SVD for matrices whose short side has length 2, BLOCK_B matrices
# per program.
#
# A is indexed as a contiguous (BATCH, M, N) array.  The vectors x and y are
# the two columns of each matrix when TALL, otherwise its two rows.  The 2x2
# Gram matrix [[x.x, x.y], [x.y, y.y]] is diagonalised analytically: its
# eigenvalues give the squared singular values and its eigenvectors give the
# 2x2 factor (V when TALL, U otherwise).  The long factor is recovered as
# (x * v_x + y * v_y) / sigma.
#
# Outputs are indexed as contiguous arrays: S (BATCH, 2), U (BATCH, M, 2) and
# V (BATCH, N, 2).
#
# NOTE(review): unlike _rank2_svd_kernel, a singular value at or below eps
# produces an all-zero long singular vector here (inv_s is forced to 0) rather
# than an orthonormal completion.
@libentry()
@triton.jit
def _rank2_svd_tiny_kernel(
    A,
    U,
    S,
    V,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    r = tl.arange(0, BLOCK_R)
    bb = b[:, None]
    rr = r[None, :]
    bmask = b < BATCH
    eps = 1.0e-20

    if TALL:
        # x / y are columns 0 and 1 of each (M, N) matrix.
        mask = (bb < BATCH) & (rr < M)
        base = A + bb * M * N + rr * N
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        # x / y are rows 0 and 1 of each (M, N) matrix.
        mask = (bb < BATCH) & (rr < N)
        base = A + bb * M * N + rr
        x = tl.load(base, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N, mask=mask, other=0.0).to(tl.float32)

    # Entries of the symmetric 2x2 Gram matrix.
    aa = tl.sum(x * x, axis=1)
    bbv = tl.sum(y * y, axis=1)
    ab = tl.sum(x * y, axis=1)
    diff = aa - bbv
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    # Larger eigenvalue from the quadratic formula; the smaller one from
    # det / l0, with both clamped so that rounding cannot make them negative.
    l0 = tl.maximum(0.0, 0.5 * (aa + bbv + root))
    det = tl.maximum(0.0, aa * bbv - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0 is (ab, l0 - aa); when the off-diagonal term vanishes
    # the Gram matrix is already diagonal and the axis of the larger diagonal
    # entry is used instead.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bbv
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    # Rescale by the larger component before normalising.  Adding an absolute
    # eps to the squared norm (as was done previously) shrinks the vector below
    # unit length whenever the Gram entries are small, i.e. for small-magnitude
    # inputs.  `scale` is non-zero in both branches above: it is at least |ab|
    # (> eps) in the first and exactly 1 in the second.
    scale = tl.maximum(tl.abs(vx0), tl.abs(vy0))
    vx0 = vx0 / scale
    vy0 = vy0 / scale
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    # Second eigenvector: the first one rotated by 90 degrees.
    vx1 = -vy0
    vy1 = vx0

    tl.store(S + b * 2, s0, mask=bmask)
    tl.store(S + b * 2 + 1, s1, mask=bmask)
    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        # U = A @ V / sigma, one column per singular value.
        u0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
        u1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]
        ubase = U + bb * M * 2 + rr * 2
        tl.store(ubase, u0, mask=mask)
        tl.store(ubase + 1, u1, mask=mask)
        # V is the 2x2 eigenvector matrix [[vx0, vx1], [vy0, vy1]].
        vbase = V + b * 4
        tl.store(vbase, vx0, mask=bmask)
        tl.store(vbase + 1, vx1, mask=bmask)
        tl.store(vbase + 2, vy0, mask=bmask)
        tl.store(vbase + 3, vy1, mask=bmask)
    else:
        # U is the 2x2 eigenvector matrix [[vx0, vx1], [vy0, vy1]].
        ubase = U + b * 4
        tl.store(ubase, vx0, mask=bmask)
        tl.store(ubase + 1, vx1, mask=bmask)
        tl.store(ubase + 2, vy0, mask=bmask)
        tl.store(ubase + 3, vy1, mask=bmask)
        # V = A^T @ U / sigma, one column per singular value.
        v0 = (x * vx0[:, None] + y * vy0[:, None]) * inv_s0[:, None]
        v1 = (x * vx1[:, None] + y * vy1[:, None]) * inv_s1[:, None]
        vbase = V + bb * N * 2 + rr * 2
        tl.store(vbase, v0, mask=mask)
        tl.store(vbase + 1, v1, mask=mask)
# Singular values only, for matrices whose short side has length 2, BLOCK_B
# matrices per program.
#
# A is indexed as a contiguous (BATCH, M, N) array and S as (BATCH, 2).  The
# two vectors x and y are columns 0/1 (TALL) or rows 0/1 (wide) of each matrix;
# the singular values are the square roots of the eigenvalues of their 2x2
# Gram matrix, written largest first.
@libentry()
@triton.jit
def _rank2_svals_tiny_kernel(
    A,
    S,
    BATCH: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    batch_ids = tl.program_id(0) * BLOCK_B + tl.arange(0, BLOCK_B)
    lanes = tl.arange(0, BLOCK_R)
    batch_col = batch_ids[:, None]
    lane_row = lanes[None, :]
    in_batch = batch_ids < BATCH

    if TALL:
        live = (batch_col < BATCH) & (lane_row < M)
        first_ptr = A + batch_col * M * N + lane_row * N
        second_ptr = first_ptr + 1
    else:
        live = (batch_col < BATCH) & (lane_row < N)
        first_ptr = A + batch_col * M * N + lane_row
        second_ptr = first_ptr + N
    x = tl.load(first_ptr, mask=live, other=0.0).to(tl.float32)
    y = tl.load(second_ptr, mask=live, other=0.0).to(tl.float32)

    # Symmetric 2x2 Gram matrix [[gram_xx, gram_xy], [gram_xy, gram_yy]].
    gram_xx = tl.sum(x * x, axis=1)
    gram_yy = tl.sum(y * y, axis=1)
    gram_xy = tl.sum(x * y, axis=1)
    gap = gram_xx - gram_yy
    disc = tl.sqrt(gap * gap + 4.0 * gram_xy * gram_xy)
    # Larger eigenvalue via the quadratic formula, smaller one via det / larger;
    # both are clamped at zero against rounding.
    lam_hi = tl.maximum(0.0, 0.5 * (gram_xx + gram_yy + disc))
    gram_det = tl.maximum(0.0, gram_xx * gram_yy - gram_xy * gram_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    out_ptr = S + batch_ids * 2
    tl.store(out_ptr, tl.sqrt(lam_hi), mask=in_batch)
    tl.store(out_ptr + 1, tl.sqrt(lam_lo), mask=in_batch)
# Singular values only, for matrices whose short side has length 2, one matrix
# per program.
#
# A is indexed as a contiguous (batch, M, N) array and S as (batch, 2), with
# the batch index taken from program_id(0).  BLOCK_R lanes cover the long side
# in a single pass.
@libentry()
@triton.jit
def _rank2_svals_kernel(
    A,
    S,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_R)
    mat_ptr = A + mat_id * M * N

    if TALL:
        # Columns 0 and 1 of the (M, N) matrix.
        live = lanes < M
        x = tl.load(mat_ptr + lanes * N, mask=live, other=0.0).to(tl.float32)
        y = tl.load(mat_ptr + lanes * N + 1, mask=live, other=0.0).to(tl.float32)
    else:
        # Rows 0 and 1 of the (M, N) matrix.
        live = lanes < N
        x = tl.load(mat_ptr + lanes, mask=live, other=0.0).to(tl.float32)
        y = tl.load(mat_ptr + N + lanes, mask=live, other=0.0).to(tl.float32)

    # Symmetric 2x2 Gram matrix [[gram_xx, gram_xy], [gram_xy, gram_yy]].
    gram_xx = tl.sum(x * x)
    gram_yy = tl.sum(y * y)
    gram_xy = tl.sum(x * y)
    gap = gram_xx - gram_yy
    disc = tl.sqrt(gap * gap + 4.0 * gram_xy * gram_xy)
    # Eigenvalues of the Gram matrix, clamped at zero against rounding.
    lam_hi = tl.maximum(0.0, 0.5 * (gram_xx + gram_yy + disc))
    gram_det = tl.maximum(0.0, gram_xx * gram_yy - gram_xy * gram_xy)
    lam_lo = tl.where(lam_hi > 1.0e-20, gram_det / lam_hi, 0.0)

    out_ptr = S + mat_id * 2
    tl.store(out_ptr, tl.sqrt(lam_hi))
    tl.store(out_ptr + 1, tl.sqrt(lam_lo))
# Closed-form SVD for matrices whose short side has length 2, one matrix per
# program.
#
# A is indexed as a contiguous (batch, M, N) array, with the batch index taken
# from program_id(0).  The vectors x and y are columns 0/1 (TALL) or rows 0/1
# (wide).  The 2x2 Gram matrix of x and y is diagonalised analytically; its
# eigenvectors form the 2x2 factor and the long factor is recovered as
# (x * v_x + y * v_y) / sigma.
#
# Degenerate cases are completed to an orthonormal pair: if s0 <= eps the first
# long vector becomes e_0, and if s1 is not larger than 5e-4 * s0 the second
# long vector is built by Gram-Schmidt from e_0 or e_1 against the first.
#
# Outputs are indexed as contiguous arrays: S (batch, 2), U (batch, M, 2) and
# V (batch, N, 2).
@libentry()
@triton.jit
def _rank2_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_R)
    eps = 1.0e-20

    if TALL:
        mask = offs < M
        base = A + pid * M * N
        x = tl.load(base + offs * N, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + offs * N + 1, mask=mask, other=0.0).to(tl.float32)
    else:
        mask = offs < N
        base = A + pid * M * N
        x = tl.load(base + offs, mask=mask, other=0.0).to(tl.float32)
        y = tl.load(base + N + offs, mask=mask, other=0.0).to(tl.float32)

    # Entries of the symmetric 2x2 Gram matrix.
    aa = tl.sum(x * x)
    bb = tl.sum(y * y)
    ab = tl.sum(x * y)
    diff = aa - bb
    root = tl.sqrt(diff * diff + 4.0 * ab * ab)
    # Larger eigenvalue from the quadratic formula; the smaller one from
    # det / l0, with both clamped so that rounding cannot make them negative.
    l0 = tl.maximum(0.0, 0.5 * (aa + bb + root))
    det = tl.maximum(0.0, aa * bb - ab * ab)
    l1 = tl.where(l0 > eps, det / l0, 0.0)
    s0 = tl.sqrt(l0)
    s1 = tl.sqrt(l1)

    # Eigenvector for l0 is (ab, l0 - aa); fall back to a coordinate axis when
    # the off-diagonal term vanishes.
    ab_abs = tl.abs(ab)
    aa_ge_bb = aa >= bb
    vx0 = tl.where(ab_abs > eps, ab, tl.where(aa_ge_bb, 1.0, 0.0))
    vy0 = tl.where(ab_abs > eps, l0 - aa, tl.where(aa_ge_bb, 0.0, 1.0))
    # Rescale by the larger component before normalising.  Adding an absolute
    # eps to the squared norm (as was done previously) shrinks the vector below
    # unit length whenever the Gram entries are small, i.e. for small-magnitude
    # inputs.  `scale` is non-zero in both branches above: it is at least |ab|
    # (> eps) in the first and exactly 1 in the second.
    scale = tl.maximum(tl.abs(vx0), tl.abs(vy0))
    vx0 = vx0 / scale
    vy0 = vy0 / scale
    inv_norm = tl.rsqrt(vx0 * vx0 + vy0 * vy0)
    vx0 = vx0 * inv_norm
    vy0 = vy0 * inv_norm
    # Second eigenvector: the first one rotated by 90 degrees.
    vx1 = -vy0
    vy1 = vx0

    sbase = S + pid * 2
    tl.store(sbase, s0)
    tl.store(sbase + 1, s1)

    inv_s0 = tl.where(s0 > eps, 1.0 / s0, 0.0)
    inv_s1 = tl.where(s1 > eps, 1.0 / s1, 0.0)

    if TALL:
        ubase = U + pid * M * 2
        u0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        # Zero matrix: use e_0 as the first left singular vector.
        u0 = tl.where(s0 > eps, u0, basis0)

        u1 = (x * vx1 + y * vy1) * inv_s1
        # Fallback for a (near-)rank-1 matrix: orthogonalise whichever of
        # e_0 / e_1 is less aligned with u0 against u0 and normalise.
        u0_first = tl.sum(tl.where(offs == 0, u0, 0.0))
        anchor = tl.where(tl.abs(u0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * u0)
        fallback_u1 = anchor - dot * u0
        fallback_norm = tl.sum(fallback_u1 * fallback_u1)
        fallback_u1 = fallback_u1 * tl.rsqrt(fallback_norm + eps)
        u1 = tl.where(s1 > s0 * 5.0e-4, u1, fallback_u1)
        tl.store(ubase + offs * 2, u0, mask=mask)
        tl.store(ubase + offs * 2 + 1, u1, mask=mask)

        # V is the 2x2 eigenvector matrix [[vx0, vx1], [vy0, vy1]].
        vbase = V + pid * 4
        tl.store(vbase, vx0)
        tl.store(vbase + 1, vx1)
        tl.store(vbase + 2, vy0)
        tl.store(vbase + 3, vy1)
    else:
        # U is the 2x2 eigenvector matrix [[vx0, vx1], [vy0, vy1]].
        ubase = U + pid * 4
        tl.store(ubase, vx0)
        tl.store(ubase + 1, vx1)
        tl.store(ubase + 2, vy0)
        tl.store(ubase + 3, vy1)

        vbase = V + pid * N * 2
        v0 = (x * vx0 + y * vy0) * inv_s0
        basis0 = tl.where(offs == 0, 1.0, 0.0)
        basis1 = tl.where(offs == 1, 1.0, 0.0)
        # Zero matrix: use e_0 as the first right singular vector.
        v0 = tl.where(s0 > eps, v0, basis0)

        v1 = (x * vx1 + y * vy1) * inv_s1
        # Same orthonormal completion as in the TALL branch, applied to V.
        v0_first = tl.sum(tl.where(offs == 0, v0, 0.0))
        anchor = tl.where(tl.abs(v0_first) < 0.70710678, basis0, basis1)
        dot = tl.sum(anchor * v0)
        fallback_v1 = anchor - dot * v0
        fallback_norm = tl.sum(fallback_v1 * fallback_v1)
        fallback_v1 = fallback_v1 * tl.rsqrt(fallback_norm + eps)
        v1 = tl.where(s1 > s0 * 5.0e-4, v1, fallback_v1)
        tl.store(vbase + offs * 2, v0, mask=mask)
        tl.store(vbase + offs * 2 + 1, v1, mask=mask)
def _rank2_svd(input):
    """Launch the closed-form rank-2 SVD kernels and return ``(U, S, V)``.

    The input is flattened to ``(batch, m, n)``.  Small matrices in large
    batches (long side <= 16 and batch >= 16) go through the multi-matrix
    "tiny" kernel; everything else uses one program per matrix.

    Assumes ``min(m, n) == 2`` -- presumably guaranteed by the caller; the
    kernels only ever read two columns/rows.

    Returns:
        Tuple of tensors shaped ``(*lead, m, 2)``, ``(*lead, 2)`` and
        ``(*lead, n, 2)``, where ``lead`` is ``input.shape[:-2]``.
    """
    batch, m, n = _svd_shape(input)
    lead_dims = input.shape[:-2]
    is_tall = m >= n
    mats = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 2), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, 2), dtype=input.dtype, device=input.device)
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    use_tiny_kernel = longest <= 16 and batch >= 16

    with torch_device_fn.device(input.device):
        if not use_tiny_kernel:
            _rank2_svd_kernel[(batch,)](
                mats,
                u,
                s,
                v,
                M=m,
                N=n,
                TALL=is_tall,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
        else:
            # Number of matrices handled by one program.
            if longest <= 2:
                block_b = 8
            elif longest == 16:
                block_b = 2 if is_tall else 8
            else:
                block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _rank2_svd_tiny_kernel[grid](
                mats,
                u,
                s,
                v,
                BATCH=batch,
                M=m,
                N=n,
                TALL=is_tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    u_out = u.reshape(*lead_dims, m, 2)
    s_out = s.reshape(*lead_dims, 2)
    v_out = v.reshape(*lead_dims, n, 2)
    return (u_out, s_out, v_out)
def _rank2_singular_values(input):
    """Launch the closed-form rank-2 kernels that compute singular values only.

    The input is flattened to ``(batch, m, n)``.  Small matrices in large
    batches (long side <= 16 and batch >= 16) go through the multi-matrix
    "tiny" kernel; everything else uses one program per matrix.

    Assumes ``min(m, n) == 2`` -- presumably guaranteed by the caller.

    Returns:
        Tensor shaped ``(*input.shape[:-2], 2)`` with the input's dtype.
    """
    batch, m, n = _svd_shape(input)
    lead_dims = input.shape[:-2]
    is_tall = m >= n
    mats = input.contiguous().reshape(batch, m, n)
    s = torch.empty((batch, 2), dtype=input.dtype, device=input.device)
    longest = max(m, n)
    block_r = triton.next_power_of_2(longest)
    use_tiny_kernel = longest <= 16 and batch >= 16

    with torch_device_fn.device(input.device):
        if not use_tiny_kernel:
            _rank2_svals_kernel[(batch,)](
                mats,
                s,
                M=m,
                N=n,
                TALL=is_tall,
                BLOCK_R=block_r,
                num_warps=1 if block_r <= 64 else 4,
            )
        else:
            # Number of matrices handled by one program.
            if longest <= 2:
                block_b = 8
            elif longest == 16:
                block_b = 2 if is_tall else 8
            else:
                block_b = 16
            grid = (triton.cdiv(batch, block_b),)
            _rank2_svals_tiny_kernel[grid](
                mats,
                s,
                BATCH=batch,
                M=m,
                N=n,
                TALL=is_tall,
                BLOCK_B=block_b,
                BLOCK_R=block_r,
                num_warps=1,
            )

    return s.reshape(*lead_dims, 2)
def _small_jacobi_singular_values(input):
    """Launch the single-program small Jacobi kernel for singular values only.

    The input is flattened to ``(batch, m, n)``; one program is launched per
    matrix with a float32 ``(batch, k, rows)`` scratch buffer, where
    ``k = min(m, n)`` and ``rows = max(m, n)``.  Three sweeps are requested
    for ``k <= 4`` and five otherwise.

    Returns:
        Tensor shaped ``(*input.shape[:-2], k)`` with the input's dtype.
    """
    batch, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    mats = input.contiguous().reshape(batch, m, n)
    scratch = torch.empty(
        (batch, short_side, long_side), dtype=torch.float32, device=input.device
    )
    s = torch.empty((batch, short_side), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(long_side)
    block_k = triton.next_power_of_2(short_side)
    if short_side <= 4:
        sweeps = 3
    else:
        sweeps = 5
    warps = 1 if block_r <= 64 else 4
    with torch_device_fn.device(input.device):
        _small_jacobi_svals_kernel[(batch,)](
            mats,
            scratch,
            s,
            M=m,
            N=n,
            K=short_side,
            ROWS=long_side,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=sweeps,
            num_warps=warps,
        )
    return s.reshape(*input.shape[:-2], short_side)
def _small_jacobi_svd(input):
    """Launch the single-program small Jacobi kernel and return ``(U, S, V)``.

    The input is flattened to ``(batch, m, n)``; one program is launched per
    matrix with float32 scratch buffers of shape ``(batch, k, rows)`` and
    ``(batch, k, k)``, where ``k = min(m, n)`` and ``rows = max(m, n)``.
    Three sweeps are requested for ``k <= 4`` and five otherwise.

    Returns:
        Tuple of tensors shaped ``(*lead, m, k)``, ``(*lead, k)`` and
        ``(*lead, n, k)``, where ``lead`` is ``input.shape[:-2]``.
    """
    batch, m, n = _svd_shape(input)
    short_side, long_side = min(m, n), max(m, n)
    lead_dims = input.shape[:-2]
    mats = input.contiguous().reshape(batch, m, n)
    a_scratch = torch.empty(
        (batch, short_side, long_side), dtype=torch.float32, device=input.device
    )
    v_scratch = torch.empty(
        (batch, short_side, short_side), dtype=torch.float32, device=input.device
    )
    u = torch.empty((batch, m, short_side), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, short_side), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, short_side), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(long_side)
    block_k = triton.next_power_of_2(short_side)
    if short_side <= 4:
        sweeps = 3
    else:
        sweeps = 5
    warps = 1 if block_r <= 64 else 4
    with torch_device_fn.device(input.device):
        _small_jacobi_svd_kernel[(batch,)](
            a_scratch_src := mats,
            a_scratch,
            v_scratch,
            u,
            s,
            v,
            M=m,
            N=n,
            K=short_side,
            ROWS=long_side,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=sweeps,
            num_warps=warps,
        ) if False else _small_jacobi_svd_kernel[(batch,)](
            mats,
            a_scratch,
            v_scratch,
            u,
            s,
            v,
            M=m,
            N=n,
            K=short_side,
            ROWS=long_side,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            SWEEPS=sweeps,
            num_warps=warps,
        )
    return (
        u.reshape(*lead_dims, m, short_side),
        s.reshape(*lead_dims, short_side),
        v.reshape(*lead_dims, n, short_side),
    )
# Copies one working column into A_WORK for the Jacobi iterations.
#
# Grid: (batch, column).  A is indexed as a contiguous (batch, M, N) array and
# A_WORK as (batch, K, ROWS).  When TALL the working column is column `col` of
# A, otherwise it is row `col` of A; either way it is stored contiguously as
# float32 in row `col` of A_WORK.
@libentry()
@triton.jit
def _cyclic_jacobi_init_a_kernel(
    A,
    A_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    live = lanes < ROWS
    src_mat = A + mat_id * M * N
    dst_mat = A_WORK + mat_id * K * ROWS

    if TALL:
        # Strided read down column `vec_id`.
        src_ptr = src_mat + lanes * N + vec_id
    else:
        # Contiguous read along row `vec_id`.
        src_ptr = src_mat + vec_id * N + lanes
    column = tl.load(src_ptr, mask=live, other=0.0).to(tl.float32)
    tl.store(dst_mat + vec_id * ROWS + lanes, column, mask=live)
# Initialises both Jacobi work buffers for one working column.
#
# Grid: (batch, column).  A is indexed as a contiguous (batch, M, N) array,
# A_WORK as (batch, K, ROWS) and V_WORK as (batch, K, K).  Row `col` of A_WORK
# receives column `col` of A (TALL) or row `col` of A (wide) as float32, and
# row `col` of V_WORK receives the matching row of the K x K identity.
@libentry()
@triton.jit
def _cyclic_jacobi_init_kernel(
    A,
    A_WORK,
    V_WORK,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    k_lanes = tl.arange(0, BLOCK_K)
    live = lanes < ROWS
    k_live = k_lanes < K
    src_mat = A + mat_id * M * N
    dst_mat = A_WORK + mat_id * K * ROWS
    v_mat = V_WORK + mat_id * K * K

    if TALL:
        # Strided read down column `vec_id`.
        src_ptr = src_mat + lanes * N + vec_id
    else:
        # Contiguous read along row `vec_id`.
        src_ptr = src_mat + vec_id * N + lanes
    column = tl.load(src_ptr, mask=live, other=0.0).to(tl.float32)
    tl.store(dst_mat + vec_id * ROWS + lanes, column, mask=live)

    # One-hot row of the identity matrix.
    unit_row = tl.where(k_lanes == vec_id, 1.0, 0.0)
    tl.store(v_mat + vec_id * K + k_lanes, unit_row, mask=k_live)
# Applies one Jacobi rotation to a pair of working columns and to the matching
# pair of rows of the accumulated rotation matrix.
#
# Grid: (batch, pair).  The pair of column indices is derived from `pair` and
# STEP: schedule position 0 always maps to column 0 while the other ROUND - 1
# positions are rotated by STEP around a ring.  Pairs that touch an index >= K
# (the padding slot when K is odd) are fully masked out.
#
# A_WORK is indexed as (batch, K, ROWS) and V_WORK as (batch, K, K).
@libentry()
@triton.jit
def _cyclic_jacobi_pair_kernel(
    A_WORK,
    V_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    mat_id = tl.program_id(0)
    slot = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    k_lanes = tl.arange(0, BLOCK_K)
    ring = ROUND - 1

    # Schedule positions of the two partners and the columns they map to.
    slot_lo = slot
    slot_hi = ROUND - 1 - slot
    idx_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    idx_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    pair_ok = (idx_a < K) & (idx_b < K)
    # Always rotate with the lower column index first.
    first = tl.minimum(idx_a, idx_b)
    second = tl.maximum(idx_a, idx_b)
    live = (lanes < ROWS) & pair_ok
    k_live = (k_lanes < K) & pair_ok

    a_mat = A_WORK + mat_id * K * ROWS
    v_mat = V_WORK + mat_id * K * K
    first_ptr = a_mat + first * ROWS + lanes
    second_ptr = a_mat + second * ROWS + lanes
    col_p = tl.load(first_ptr, mask=live, other=0.0)
    col_q = tl.load(second_ptr, mask=live, other=0.0)

    # Squared norms and inner product of the two columns.
    alpha = tl.sum(col_p * col_p)
    beta = tl.sum(col_q * col_q)
    gamma = tl.sum(col_p * col_q)
    eps = 1.0e-20
    # Skip the rotation when the columns are already orthogonal relative to
    # their norms; a masked-out pair has gamma == 0 and is therefore skipped.
    cutoff = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    rotate = tl.abs(gamma) > cutoff
    gamma_safe = tl.where(rotate, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * gamma_safe)
    tau_sign = tl.where(tau >= 0.0, 1.0, -1.0)
    # Smaller-magnitude root of t^2 + 2 * tau * t - 1 = 0.
    t = tau_sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_raw = tl.rsqrt(1.0 + t * t)
    sin_raw = t * cos_raw
    c = tl.where(rotate, cos_raw, 1.0)
    s_rot = tl.where(rotate, sin_raw, 0.0)

    rot_p = c * col_p - s_rot * col_q
    rot_q = s_rot * col_p + c * col_q
    tl.store(first_ptr, rot_p, mask=live)
    tl.store(second_ptr, rot_q, mask=live)

    # Apply the same rotation to the accumulated rotation matrix.
    v_first_ptr = v_mat + first * K + k_lanes
    v_second_ptr = v_mat + second * K + k_lanes
    basis_p = tl.load(v_first_ptr, mask=k_live, other=0.0)
    basis_q = tl.load(v_second_ptr, mask=k_live, other=0.0)
    rot_basis_p = c * basis_p - s_rot * basis_q
    rot_basis_q = s_rot * basis_p + c * basis_q
    tl.store(v_first_ptr, rot_basis_p, mask=k_live)
    tl.store(v_second_ptr, rot_basis_q, mask=k_live)
# Applies the Jacobi rotation for one (step, pair) slot of the schedule used by
# _serial_cyclic_jacobi_kernel.
#
# `aw_base` / `vw_base` point at the (K, ROWS) working columns and the (K, K)
# accumulated rotation matrix of the current batch element.  Schedule position
# 0 always maps to column 0 while the other ROUND - 1 positions are rotated by
# `step` around a ring; pairs touching an index >= K are fully masked out.
@triton.jit
def _serial_cyclic_jacobi_rotate_pair(
    aw_base,
    vw_base,
    rows,
    cols,
    row_base_mask,
    col_base_mask,
    pair,
    step,
    K,
    ROUND,
    ROWS: tl.constexpr,
):
    eps = 1.0e-20
    ring = ROUND - 1

    pos_p = pair
    pos_q = ROUND - 1 - pair
    p = tl.where(pos_p == 0, 0, ((pos_p + ring - step - 1) % ring) + 1)
    q = tl.where(pos_q == 0, 0, ((pos_q + ring - step - 1) % ring) + 1)
    valid_pair = (p < K) & (q < K)
    # Always rotate with the lower column index first.
    swap = p > q
    p2 = tl.where(swap, q, p)
    q2 = tl.where(swap, p, q)
    row_mask = row_base_mask & valid_pair
    col_mask = col_base_mask & valid_pair

    ap = tl.load(aw_base + p2 * ROWS + rows, mask=row_mask, other=0.0)
    aq = tl.load(aw_base + q2 * ROWS + rows, mask=row_mask, other=0.0)
    # Squared norms and inner product of the two columns.
    alpha = tl.sum(ap * ap)
    beta = tl.sum(aq * aq)
    gamma = tl.sum(ap * aq)
    abs_gamma = tl.abs(gamma)
    # Skip the rotation when the columns are already orthogonal relative to
    # their norms; a masked-out pair has gamma == 0 and is therefore skipped.
    threshold = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    active = abs_gamma > threshold
    safe_gamma = tl.where(active, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * safe_gamma)
    sign_tau = tl.where(tau >= 0.0, 1.0, -1.0)
    # Smaller-magnitude root of t^2 + 2 * tau * t - 1 = 0.
    t = sign_tau / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    c = tl.rsqrt(1.0 + t * t)
    s_rot = t * c
    c = tl.where(active, c, 1.0)
    s_rot = tl.where(active, s_rot, 0.0)

    new_ap = c * ap - s_rot * aq
    new_aq = s_rot * ap + c * aq
    tl.store(aw_base + p2 * ROWS + rows, new_ap, mask=row_mask)
    tl.store(aw_base + q2 * ROWS + rows, new_aq, mask=row_mask)

    # Apply the same rotation to the accumulated rotation matrix.
    vp = tl.load(vw_base + p2 * K + cols, mask=col_mask, other=0.0)
    vq = tl.load(vw_base + q2 * K + cols, mask=col_mask, other=0.0)
    new_vp = c * vp - s_rot * vq
    new_vq = s_rot * vp + c * vq
    tl.store(vw_base + p2 * K + cols, new_vp, mask=col_mask)
    tl.store(vw_base + q2 * K + cols, new_vq, mask=col_mask)


# Runs the whole cyclic Jacobi iteration for one matrix inside a single
# program.
#
# Grid: (batch,).  Performs SWEEPS full sweeps, each consisting of ROUND - 1
# steps of ROUND // 2 column pairs processed one after another, followed by
# TAIL_STEPS additional steps that restart the schedule at step 0.  A_WORK is
# indexed as (batch, K, ROWS) and V_WORK as (batch, K, K); both are updated in
# place.
@libentry()
@triton.jit
def _serial_cyclic_jacobi_kernel(
    A_WORK,
    V_WORK,
    K,
    ROUND,
    ROWS: tl.constexpr,
    SWEEPS,
    TAIL_STEPS,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_R)
    cols = tl.arange(0, BLOCK_K)
    row_base_mask = rows < ROWS
    col_base_mask = cols < K
    aw_base = A_WORK + batch * K * ROWS
    vw_base = V_WORK + batch * K * K
    half_round = ROUND // 2

    # Full sweeps over every step of the schedule.
    sweep = 0
    while sweep < SWEEPS:
        step = 0
        while step < ROUND - 1:
            pair = 0
            while pair < half_round:
                _serial_cyclic_jacobi_rotate_pair(
                    aw_base,
                    vw_base,
                    rows,
                    cols,
                    row_base_mask,
                    col_base_mask,
                    pair,
                    step,
                    K,
                    ROUND,
                    ROWS,
                )
                pair += 1
            step += 1
        sweep += 1

    # Partial extra sweep: only the first TAIL_STEPS steps of the schedule.
    step = 0
    while step < TAIL_STEPS:
        pair = 0
        while pair < half_round:
            _serial_cyclic_jacobi_rotate_pair(
                aw_base,
                vw_base,
                rows,
                cols,
                row_base_mask,
                col_base_mask,
                pair,
                step,
                K,
                ROUND,
                ROWS,
            )
            pair += 1
        step += 1
# Writes the Euclidean norm of one working column to S_WORK.
#
# Grid: (batch, column).  A_WORK is indexed as (batch, K, ROWS) and S_WORK as
# (batch, K).
@libentry()
@triton.jit
def _cyclic_jacobi_norm_kernel(
    A_WORK,
    S_WORK,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    live = lanes < ROWS
    column_ptr = A_WORK + mat_id * K * ROWS + vec_id * ROWS + lanes
    column = tl.load(column_ptr, mask=live, other=0.0)
    sq_norm = tl.sum(column * column)
    tl.store(S_WORK + mat_id * K + vec_id, tl.sqrt(sq_norm))
# Scatters one converged Jacobi column into the sorted U / S / V outputs.
#
# Grid: (batch, column).  The output slot of a column is the number of columns
# with a strictly larger norm, ties broken in favour of the lower column index,
# so S ends up in descending order.  The normalised working column goes to U
# (TALL) or V (wide) and the matching row of V_WORK goes to the other factor.
# A column whose norm is <= eps is written as all zeros.
#
# A_WORK is indexed as (batch, K, ROWS), V_WORK as (batch, K, K), S_WORK and S
# as (batch, K), U as (batch, M, K) and V as (batch, N, K).
@libentry()
@triton.jit
def _cyclic_jacobi_finalize_kernel(
    A_WORK,
    V_WORK,
    S_WORK,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    k_lanes = tl.arange(0, BLOCK_K)
    live = lanes < ROWS
    k_live = k_lanes < K
    eps = 1.0e-20

    norms_ptr = S_WORK + mat_id * K
    sigma = tl.load(norms_ptr + vec_id)
    # Count the columns that must be placed before this one.
    slot = tl.full((), 0, dtype=tl.int32)
    for other in tl.static_range(0, K):
        sigma_other = tl.load(norms_ptr + other)
        slot += (
            (sigma_other > sigma) | ((sigma_other == sigma) & (other < vec_id))
        ).to(tl.int32)

    column = tl.load(
        A_WORK + mat_id * K * ROWS + vec_id * ROWS + lanes, mask=live, other=0.0
    )
    inv_sigma = tl.where(sigma > eps, 1.0 / sigma, 0.0)
    rotation_row = tl.load(
        V_WORK + mat_id * K * K + vec_id * K + k_lanes, mask=k_live, other=0.0
    )
    tl.store(S + mat_id * K + slot, sigma)

    unit_column = column * inv_sigma
    if TALL:
        tl.store(U + mat_id * M * K + lanes * K + slot, unit_column, mask=live)
        tl.store(V + mat_id * N * K + k_lanes * K + slot, rotation_row, mask=k_live)
    else:
        tl.store(U + mat_id * M * K + k_lanes * K + slot, rotation_row, mask=k_live)
        tl.store(V + mat_id * N * K + lanes * K + slot, unit_column, mask=live)
# Applies one Jacobi rotation to a pair of working columns without tracking
# the accumulated rotation matrix.
#
# Grid: (batch, pair).  Schedule position 0 always maps to column 0 while the
# other ROUND - 1 positions are rotated by STEP around a ring; pairs touching
# an index >= K are fully masked out.  A_WORK is indexed as (batch, K, ROWS).
@libentry()
@triton.jit
def _blocked_jacobi_pair_svals_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    slot = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Schedule positions of the two partners and the columns they map to.
    slot_lo = slot
    slot_hi = ROUND - 1 - slot
    idx_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    idx_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    pair_ok = (idx_a < K) & (idx_b < K)
    # Always rotate with the lower column index first.
    first = tl.minimum(idx_a, idx_b)
    second = tl.maximum(idx_a, idx_b)
    live = (lanes < ROWS) & pair_ok

    a_mat = A_WORK + mat_id * K * ROWS
    first_ptr = a_mat + first * ROWS + lanes
    second_ptr = a_mat + second * ROWS + lanes
    col_p = tl.load(first_ptr, mask=live, other=0.0)
    col_q = tl.load(second_ptr, mask=live, other=0.0)

    # Squared norms and inner product of the two columns.
    alpha = tl.sum(col_p * col_p)
    beta = tl.sum(col_q * col_q)
    gamma = tl.sum(col_p * col_q)
    eps = 1.0e-20
    # Skip the rotation when the columns are already orthogonal relative to
    # their norms.
    cutoff = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    rotate = tl.abs(gamma) > cutoff
    gamma_safe = tl.where(rotate, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * gamma_safe)
    tau_sign = tl.where(tau >= 0.0, 1.0, -1.0)
    # Smaller-magnitude root of t^2 + 2 * tau * t - 1 = 0.
    t = tau_sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_raw = tl.rsqrt(1.0 + t * t)
    sin_raw = t * cos_raw
    apply_rotation = rotate & pair_ok
    c = tl.where(apply_rotation, cos_raw, 1.0)
    s_rot = tl.where(apply_rotation, sin_raw, 0.0)

    rot_p = c * col_p - s_rot * col_q
    rot_q = s_rot * col_p + c * col_q
    tl.store(first_ptr, rot_p, mask=live)
    tl.store(second_ptr, rot_q, mask=live)
# Applies one Jacobi rotation to a pair of working columns and records the
# rotation so it can be applied to the rotation matrix by a separate kernel.
#
# Grid: (batch, pair).  Schedule position 0 always maps to column 0 while the
# other ROUND - 1 positions are rotated by STEP around a ring; pairs touching
# an index >= K are masked out and record the identity rotation (c = 1, s = 0).
# A_WORK is indexed as (batch, K, ROWS); ROT_C and ROT_S as (batch, ROUND // 2).
@libentry()
@triton.jit
def _blocked_jacobi_pair_a_kernel(
    A_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    slot = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    ring = ROUND - 1

    # Schedule positions of the two partners and the columns they map to.
    slot_lo = slot
    slot_hi = ROUND - 1 - slot
    idx_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    idx_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    pair_ok = (idx_a < K) & (idx_b < K)
    # Always rotate with the lower column index first.
    first = tl.minimum(idx_a, idx_b)
    second = tl.maximum(idx_a, idx_b)
    live = (lanes < ROWS) & pair_ok

    a_mat = A_WORK + mat_id * K * ROWS
    first_ptr = a_mat + first * ROWS + lanes
    second_ptr = a_mat + second * ROWS + lanes
    col_p = tl.load(first_ptr, mask=live, other=0.0)
    col_q = tl.load(second_ptr, mask=live, other=0.0)

    # Squared norms and inner product of the two columns.
    alpha = tl.sum(col_p * col_p)
    beta = tl.sum(col_q * col_q)
    gamma = tl.sum(col_p * col_q)
    eps = 1.0e-20
    # Skip the rotation when the columns are already orthogonal relative to
    # their norms.
    cutoff = 1.0e-7 * tl.sqrt(alpha * beta + eps)
    rotate = tl.abs(gamma) > cutoff
    gamma_safe = tl.where(rotate, gamma, 1.0)
    tau = (beta - alpha) / (2.0 * gamma_safe)
    tau_sign = tl.where(tau >= 0.0, 1.0, -1.0)
    # Smaller-magnitude root of t^2 + 2 * tau * t - 1 = 0.
    t = tau_sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
    cos_raw = tl.rsqrt(1.0 + t * t)
    sin_raw = t * cos_raw
    apply_rotation = rotate & pair_ok
    c = tl.where(apply_rotation, cos_raw, 1.0)
    s_rot = tl.where(apply_rotation, sin_raw, 0.0)

    rot_p = c * col_p - s_rot * col_q
    rot_q = s_rot * col_p + c * col_q
    tl.store(first_ptr, rot_p, mask=live)
    tl.store(second_ptr, rot_q, mask=live)

    # Record the rotation for this (batch, pair) slot.
    rot_slot = mat_id * (ROUND // 2) + slot
    tl.store(ROT_C + rot_slot, c)
    tl.store(ROT_S + rot_slot, s_rot)
# Block Jacobi step: orthogonalises all columns of a pair of column tiles
# against each other.
#
# Grid: (batch, tile pair).  The two tile indices come from the same ring
# schedule as the scalar pair kernels, applied to ROUND_BLOCKS tile positions.
# The TILE_COLS local columns are the TILE_B columns of the lower tile followed
# by those of the higher tile; they are loaded once, rotated pairwise for
# LOCAL_SWEEPS sweeps and written back.  Columns with index >= K and tile
# pairs touching a tile >= K_BLOCKS are masked out.
#
# A_WORK is indexed as (batch, K, ROWS).
@libentry()
@triton.jit
def _hier_block_jacobi_pair_a_kernel(
    A_WORK,
    STEP,
    K: tl.constexpr,
    K_BLOCKS: tl.constexpr,
    ROUND_BLOCKS: tl.constexpr,
    ROWS: tl.constexpr,
    TILE_B: tl.constexpr,
    TILE_COLS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    LOCAL_SWEEPS: tl.constexpr,
):
    mat_id = tl.program_id(0)
    slot = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_R)
    tile_lanes = tl.arange(0, TILE_COLS)
    ring = ROUND_BLOCKS - 1

    # Schedule positions of the two tiles and the tile indices they map to.
    slot_lo = slot
    slot_hi = ROUND_BLOCKS - 1 - slot
    tile_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    tile_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    pair_ok = (tile_a < K_BLOCKS) & (tile_b < K_BLOCKS)
    lower_tile = tl.minimum(tile_a, tile_b)
    upper_tile = tl.maximum(tile_a, tile_b)

    # Global column index of every local column.
    global_cols = tl.where(
        tile_lanes < TILE_B,
        lower_tile * TILE_B + tile_lanes,
        upper_tile * TILE_B + tile_lanes - TILE_B,
    )
    live_rows = lanes < ROWS
    live_cols = (global_cols < K) & pair_ok
    tile_mask = live_cols[:, None] & live_rows[None, :]
    tile_ptr = A_WORK + mat_id * K * ROWS + global_cols[:, None] * ROWS + lanes[None, :]
    tile = tl.load(tile_ptr, mask=tile_mask, other=0.0).to(tl.float32)
    local_index = tile_lanes[:, None]
    eps = 1.0e-20

    for _ in tl.static_range(0, LOCAL_SWEEPS):
        for p in tl.static_range(0, TILE_COLS):
            for q in tl.static_range(p + 1, TILE_COLS):
                # Extract local columns p and q from the register tile.
                col_p = tl.sum(tl.where(local_index == p, tile, 0.0), axis=0)
                col_q = tl.sum(tl.where(local_index == q, tile, 0.0), axis=0)
                alpha = tl.sum(col_p * col_p)
                beta = tl.sum(col_q * col_q)
                gamma = tl.sum(col_p * col_q)
                # Skip the rotation when the columns are already orthogonal
                # relative to their norms.
                cutoff = 1.0e-7 * tl.sqrt(alpha * beta + eps)
                rotate = tl.abs(gamma) > cutoff
                gamma_safe = tl.where(rotate, gamma, 1.0)
                tau = (beta - alpha) / (2.0 * gamma_safe)
                tau_sign = tl.where(tau >= 0.0, 1.0, -1.0)
                # Smaller-magnitude root of t^2 + 2 * tau * t - 1 = 0.
                t = tau_sign / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau))
                cos_raw = tl.rsqrt(1.0 + t * t)
                sin_raw = t * cos_raw
                apply_rotation = rotate & pair_ok
                c = tl.where(apply_rotation, cos_raw, 1.0)
                s_rot = tl.where(apply_rotation, sin_raw, 0.0)

                rot_p = c * col_p - s_rot * col_q
                rot_q = s_rot * col_p + c * col_q
                # Write the rotated columns back into the register tile.
                tile = tl.where(local_index == p, rot_p[None, :], tile)
                tile = tl.where(local_index == q, rot_q[None, :], tile)

    tl.store(tile_ptr, tile, mask=tile_mask)
# Applies a previously recorded Jacobi rotation to a pair of rows of the
# accumulated rotation matrix.
#
# Grid: (batch, pair, column block).  The row pair is derived from `pair` and
# STEP with the same ring schedule as the kernel that recorded the rotation;
# (c, s) are read from slot (batch, pair) of ROT_C / ROT_S.  V_WORK is indexed
# as (batch, K, K) and each program updates BLOCK_V entries of both rows.
@libentry()
@triton.jit
def _blocked_jacobi_apply_v_kernel(
    V_WORK,
    ROT_C,
    ROT_S,
    STEP,
    K: tl.constexpr,
    ROUND: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    mat_id = tl.program_id(0)
    slot = tl.program_id(1)
    chunk = tl.program_id(2)
    k_lanes = chunk * BLOCK_V + tl.arange(0, BLOCK_V)
    ring = ROUND - 1

    # Schedule positions of the two partners and the rows they map to.
    slot_lo = slot
    slot_hi = ROUND - 1 - slot
    idx_a = tl.where(slot_lo == 0, 0, ((slot_lo + ring - STEP - 1) % ring) + 1)
    idx_b = tl.where(slot_hi == 0, 0, ((slot_hi + ring - STEP - 1) % ring) + 1)
    pair_ok = (idx_a < K) & (idx_b < K)
    first = tl.minimum(idx_a, idx_b)
    second = tl.maximum(idx_a, idx_b)
    live = (k_lanes < K) & pair_ok

    rot_slot = mat_id * (ROUND // 2) + slot
    c = tl.load(ROT_C + rot_slot)
    s_rot = tl.load(ROT_S + rot_slot)

    v_mat = V_WORK + mat_id * K * K
    first_ptr = v_mat + first * K + k_lanes
    second_ptr = v_mat + second * K + k_lanes
    basis_p = tl.load(first_ptr, mask=live, other=0.0)
    basis_q = tl.load(second_ptr, mask=live, other=0.0)
    rot_basis_p = c * basis_p - s_rot * basis_q
    rot_basis_q = s_rot * basis_p + c * basis_q
    tl.store(first_ptr, rot_basis_p, mask=live)
    tl.store(second_ptr, rot_basis_q, mask=live)
# Computes the descending-order position of one column norm and writes the
# sorted singular value.
#
# Grid: (batch, column).  The position is the number of columns with a strictly
# larger norm, ties broken in favour of the lower column index.  S_WORK, RANKS
# and S are indexed as (batch, K); K is a runtime value here, so the scan is a
# dynamic loop.
@libentry()
@triton.jit
def _blocked_jacobi_rank_kernel(
    S_WORK,
    RANKS,
    S,
    K,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    norms_ptr = S_WORK + mat_id * K
    sigma = tl.load(norms_ptr + vec_id)
    slot = tl.full((), 0, dtype=tl.int32)
    other = 0
    while other < K:
        sigma_other = tl.load(norms_ptr + other)
        comes_first = (sigma_other > sigma) | (
            (sigma_other == sigma) & (other < vec_id)
        )
        slot += comes_first.to(tl.int32)
        other += 1
    tl.store(RANKS + mat_id * K + vec_id, slot)
    tl.store(S + mat_id * K + slot, sigma)
# Normalises one working column and scatters it into its sorted output column.
#
# Grid: (batch, column, row block).  The column is divided by its norm from
# S_WORK (clamped below by eps) and written to column RANKS[batch, col] of
# PROJECTED.  A column whose norm is <= eps is replaced by the unit vector
# with a 1 at row index `rank`.
#
# A_WORK is indexed as (batch, K, ROWS), S_WORK and RANKS as (batch, K) and
# PROJECTED as (batch, OUT_ROWS, K).  Only rows < OUT_ROWS are touched.
@libentry()
@triton.jit
def _blocked_jacobi_store_projected_kernel(
    A_WORK,
    S_WORK,
    RANKS,
    PROJECTED,
    K: tl.constexpr,
    ROWS: tl.constexpr,
    OUT_ROWS: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    chunk = tl.program_id(2)
    lanes = chunk * BLOCK_R + tl.arange(0, BLOCK_R)
    live = lanes < OUT_ROWS
    slot = tl.load(RANKS + mat_id * K + vec_id)
    sigma = tl.load(S_WORK + mat_id * K + vec_id)
    eps = 1.0e-20

    column = tl.load(
        A_WORK + mat_id * K * ROWS + vec_id * ROWS + lanes,
        mask=live,
        other=0.0,
    )
    unit_column = column / tl.maximum(sigma, eps)
    # Substitute a coordinate vector for a numerically zero column.
    placeholder = tl.where(lanes == slot, 1.0, 0.0)
    unit_column = tl.where(sigma <= eps, placeholder, unit_column)

    out_ptr = PROJECTED + mat_id * OUT_ROWS * K + lanes * K + slot
    tl.store(out_ptr, unit_column, mask=live)
# Scatters one row of the accumulated rotation matrix into its sorted output
# column.
#
# Grid: (batch, column, block).  Row `col` of V_WORK is written to column
# RANKS[batch, col] of BASIS, i.e. the matrix is transposed and its columns are
# permuted.  V_WORK and BASIS are indexed as (batch, K, K), RANKS as (batch, K).
@libentry()
@triton.jit
def _blocked_jacobi_store_basis_kernel(
    V_WORK,
    RANKS,
    BASIS,
    K: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    mat_id = tl.program_id(0)
    vec_id = tl.program_id(1)
    chunk = tl.program_id(2)
    k_lanes = chunk * BLOCK_V + tl.arange(0, BLOCK_V)
    live = k_lanes < K
    slot = tl.load(RANKS + mat_id * K + vec_id)

    src_ptr = V_WORK + mat_id * K * K + vec_id * K + k_lanes
    rotation_row = tl.load(src_ptr, mask=live, other=0.0)

    dst_ptr = BASIS + mat_id * K * K + k_lanes * K + slot
    tl.store(dst_ptr, rotation_row, mask=live)
# Re-orthonormalises the K columns of a thin matrix in place, one matrix per
# program.
#
# Q is indexed as (batch, ROWS, K).  Columns are processed left to right: each
# column is projected against all already-processed columns twice, replaced by
# the coordinate vector e_j if its remaining norm is <= eps, projected once
# more, and finally normalised (with the norm clamped below by eps).
@libentry()
@triton.jit
def _thin_reorthogonalize_kernel(
    Q,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    mat_id = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_R)
    live = lanes < ROWS
    q_mat = Q + mat_id * ROWS * K
    eps = 1.0e-20

    for j in tl.static_range(0, K):
        column = tl.load(q_mat + lanes * K + j, mask=live, other=0.0).to(tl.float32)

        # Two Gram-Schmidt passes against the columns finalised so far.
        for _ in tl.static_range(0, 2):
            for prev in tl.static_range(0, j):
                earlier = tl.load(
                    q_mat + lanes * K + prev, mask=live, other=0.0
                ).to(tl.float32)
                overlap = tl.sum(column * earlier)
                column = column - overlap * earlier

        # A column that vanished is restarted from the coordinate vector e_j.
        length = tl.sqrt(tl.sum(column * column))
        unit_j = tl.where(lanes == j, 1.0, 0.0)
        column = tl.where(length > eps, column, unit_j)

        # One more pass so a substituted e_j is orthogonal to earlier columns.
        for prev in tl.static_range(0, j):
            earlier = tl.load(
                q_mat + lanes * K + prev, mask=live, other=0.0
            ).to(tl.float32)
            overlap = tl.sum(column * earlier)
            column = column - overlap * earlier

        length = tl.sqrt(tl.sum(column * column))
        column = column / tl.maximum(length, eps)
        tl.store(q_mat + lanes * K + j, column, mask=live)
def _cyclic_jacobi_svd(input):
    """Thin SVD driver built on the cyclic Jacobi Triton kernels.

    Stages, in launch order: initialize the float32 work buffers, run the
    Jacobi rotations (one fused kernel for "medium" problems, otherwise one
    launch per pairing step), compute column norms, and finalize U, S and V.
    Small problems additionally go through a refinement stage.

    NOTE(review): this block was reconstructed from a listing without
    indentation. The refinement stage (from the first
    ``_thin_reorthogonalize_kernel`` launch down to the last one) is assumed to
    sit entirely inside ``if k <= 17 and rows <= 64`` -- confirm against the
    original file.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    a = input.contiguous().reshape(batch, m, n)
    # float32 scratch buffers used by the Jacobi kernels.
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    v_work = torch.empty((batch, k, k), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    # Outputs in the caller's dtype.
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)
    v = torch.empty((batch, n, k), dtype=input.dtype, device=input.device)
    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    # Sweep count depends on k; k == 32 trades sweeps for extra tail steps.
    sweeps = 6 if k == 32 else 8 if k < 32 else 12
    tail_steps = 20 if k == 32 else 0
    # Pairing rounds need an even number of slots, so odd k is padded by one.
    round_size = k if k % 2 == 0 else k + 1
    # Shapes for which the single-launch serial kernel is used.
    serial_medium = 16 <= k <= 32 and rows <= 64 and batch <= 32
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_kernel[(batch, k)](
            a,
            a_work,
            v_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
        if serial_medium:
            # One program per batch entry runs every sweep inside the kernel.
            _serial_cyclic_jacobi_kernel[(batch,)](
                a_work,
                v_work,
                K=k,
                ROUND=round_size,
                ROWS=rows,
                SWEEPS=sweeps,
                TAIL_STEPS=tail_steps,
                BLOCK_R=block_r,
                BLOCK_K=block_k,
                num_warps=1,
            )
        else:
            # One launch per pairing step; each program handles one pair.
            for _ in range(sweeps):
                for step in range(round_size - 1):
                    _cyclic_jacobi_pair_kernel[(batch, round_size // 2)](
                        a_work,
                        v_work,
                        step,
                        K=k,
                        ROUND=round_size,
                        ROWS=rows,
                        BLOCK_R=block_r,
                        BLOCK_K=block_k,
                        num_warps=1 if block_r <= 64 else 4,
                    )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=1 if block_r <= 64 else 4,
        )
        _cyclic_jacobi_finalize_kernel[(batch, k)](
            a_work,
            v_work,
            s_work,
            u,
            s,
            v,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=m >= n,
            BLOCK_R=block_r,
            BLOCK_K=block_k,
            num_warps=1 if block_r <= 64 else 4,
        )
    if k <= 17 and rows <= 64:
        # Refinement: orthonormalize the k-row factor (V when tall, U when
        # wide) ...
        with torch_device_fn.device(input.device):
            if m >= n:
                _thin_reorthogonalize_kernel[(batch,)](
                    v,
                    ROWS=n,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(n),
                    num_warps=1,
                )
            else:
                _thin_reorthogonalize_kernel[(batch,)](
                    u,
                    ROWS=m,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(m),
                    num_warps=1,
                )

        # ... then rebuild the other factor as a product with the input.
        if m >= n:
            u = _triton_bmm(a, v, (batch, m, k))
            projected = u
            projected_rows = m
        else:
            a_t = a.transpose(1, 2).contiguous()
            v = _triton_bmm(a_t, u, (batch, n, k))
            projected = v
            projected_rows = n
        with torch_device_fn.device(input.device):
            # The rebuilt factor and S are handed to the normalization and
            # zero-column completion kernels (defined elsewhere in the file).
            _normalize_projection_kernel[(batch, k)](
                projected,
                s,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(projected_rows),
                num_warps=1,
            )
            _complete_zero_projection_kernel[(batch, k)](
                projected,
                s,
                ROWS=projected_rows,
                K=k,
                BLOCK_R=triton.next_power_of_2(projected_rows),
                num_warps=1,
            )
            # Batched small problems get a final orthonormalization pass.
            if batch > 1 and k <= 16:
                _thin_reorthogonalize_kernel[(batch,)](
                    projected,
                    ROWS=projected_rows,
                    K=k,
                    BLOCK_R=triton.next_power_of_2(projected_rows),
                    num_warps=1,
                )
    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _projected_jacobi_svd(input):
    """Thin SVD from Jacobi sweeps plus a matmul for the second factor.

    The Jacobi kernels operate on a float32 work buffer and produce S together
    with the factor that has ``max(m, n)`` rows. The remaining factor is a
    batched product with the input that is then passed, together with S, to
    ``_normalize_projection_kernel``.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    is_tall = m >= n
    k = min(m, n)
    rows = max(m, n)
    dev = input.device
    mat = input.contiguous().reshape(batch, m, n)

    col_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    sigma_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    order = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    out_rows = m if is_tall else n
    jacobi_factor = torch.empty((batch, out_rows, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 10 if k >= 128 else 8
    # Pairing rounds need an even slot count; odd k is padded by one.
    round_size = k + (k % 2)
    pairs = round_size // 2
    cos_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    sin_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    per_column = (batch, k)
    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[per_column](
            mat, col_work,
            M=m, N=n, K=k, ROWS=rows, TALL=is_tall,
            BLOCK_R=block_r, num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    col_work, cos_buf, sin_buf, step,
                    K=k, ROUND=round_size, ROWS=rows,
                    BLOCK_R=block_r, num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[per_column](
            col_work, sigma_work,
            K=k, ROWS=rows, BLOCK_R=block_r, num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[per_column](
            sigma_work, order, s, k, num_warps=1
        )
        store_grid = (batch, k, triton.cdiv(out_rows, block_r))
        _blocked_jacobi_store_projected_kernel[store_grid](
            col_work, sigma_work, order, jacobi_factor,
            K=k, ROWS=rows, OUT_ROWS=out_rows,
            BLOCK_R=block_r, num_warps=row_warps,
        )

    # Recover the k-row factor from the input and the Jacobi factor.
    if is_tall:
        u = jacobi_factor
        v = _triton_bmm(mat.transpose(1, 2).contiguous(), u, (batch, n, k))
        other, other_rows = v, n
    else:
        v = jacobi_factor
        u = _triton_bmm(mat, v, (batch, m, k))
        other, other_rows = u, m

    with torch_device_fn.device(dev):
        _normalize_projection_kernel[per_column](
            other, s,
            ROWS=other_rows, K=k,
            BLOCK_R=triton.next_power_of_2(other_rows),
            num_warps=1 if other_rows <= 64 else 4,
        )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )
def _blocked_jacobi_svd(input):
    """Thin SVD where both factors come straight from the Jacobi kernels.

    Each pairing step launches two kernels: one rotating the work columns and
    filling the cosine/sine buffers, one applying those rotations to the
    accumulated ``k x k`` rotation matrix. After the sweeps the columns are
    ranked and written out; the rotation matrix becomes the ``k``-row factor.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    is_tall = m >= n
    k = min(m, n)
    rows = max(m, n)
    dev = input.device
    out_dtype = input.dtype
    mat = input.contiguous().reshape(batch, m, n)

    col_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    rot_work = torch.empty((batch, k, k), dtype=torch.float32, device=dev)
    sigma_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    order = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=out_dtype, device=dev)
    s = torch.empty((batch, k), dtype=out_dtype, device=dev)
    v = torch.empty((batch, n, k), dtype=out_dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    block_k = triton.next_power_of_2(k)
    block_v = 64
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 10 if k <= 256 else 14
    # Pairing rounds need an even slot count; odd k is padded by one.
    round_size = k + (k % 2)
    pairs = round_size // 2
    cos_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    sin_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    per_column = (batch, k)
    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_kernel[per_column](
            mat, col_work, rot_work,
            M=m, N=n, K=k, ROWS=rows, TALL=is_tall,
            BLOCK_R=block_r, BLOCK_K=block_k, num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    col_work, cos_buf, sin_buf, step,
                    K=k, ROUND=round_size, ROWS=rows,
                    BLOCK_R=block_r, num_warps=row_warps,
                )
                _blocked_jacobi_apply_v_kernel[
                    (batch, pairs, triton.cdiv(k, block_v))
                ](
                    rot_work, cos_buf, sin_buf, step,
                    K=k, ROUND=round_size, BLOCK_V=block_v, num_warps=1,
                )
        _cyclic_jacobi_norm_kernel[per_column](
            col_work, sigma_work,
            K=k, ROWS=rows, BLOCK_R=block_r, num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[per_column](
            sigma_work, order, s, k, num_warps=1
        )
        if is_tall:
            # Work columns -> U, accumulated rotations -> V.
            _blocked_jacobi_store_projected_kernel[
                (batch, k, triton.cdiv(m, block_r))
            ](
                col_work, sigma_work, order, u,
                K=k, ROWS=rows, OUT_ROWS=m,
                BLOCK_R=block_r, num_warps=row_warps,
            )
            _blocked_jacobi_store_basis_kernel[
                (batch, k, triton.cdiv(n, block_v))
            ](
                rot_work, order, v, K=k, BLOCK_V=block_v, num_warps=1,
            )
        else:
            # Accumulated rotations -> U, work columns -> V.
            _blocked_jacobi_store_basis_kernel[
                (batch, k, triton.cdiv(m, block_v))
            ](
                rot_work, order, u, K=k, BLOCK_V=block_v, num_warps=1,
            )
            _blocked_jacobi_store_projected_kernel[
                (batch, k, triton.cdiv(n, block_r))
            ](
                col_work, sigma_work, order, v,
                K=k, ROWS=rows, OUT_ROWS=n,
                BLOCK_R=block_r, num_warps=row_warps,
            )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )
def _blocked_jacobi_square_project_svd(input):
    """Thin SVD: Jacobi kernels give U and S, V is projected from the input.

    The work buffer is always initialized with ``TALL=True``. After the sweeps
    U is written from the ranked work columns, V is computed as the batched
    product of the transposed input with U, and V and S are handed to
    ``_renorm_projection_update_s_kernel``.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    dev = input.device
    mat = input.contiguous().reshape(batch, m, n)

    col_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    sigma_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    order = torch.empty((batch, k), dtype=torch.int32, device=dev)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 16 if k > 256 else 12
    # Pairing rounds need an even slot count; odd k is padded by one.
    round_size = k + (k % 2)
    pairs = round_size // 2
    cos_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)
    sin_buf = torch.empty((batch, pairs), dtype=torch.float32, device=dev)

    per_column = (batch, k)
    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[per_column](
            mat, col_work,
            M=m, N=n, K=k, ROWS=rows, TALL=True,
            BLOCK_R=block_r, num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_a_kernel[(batch, pairs)](
                    col_work, cos_buf, sin_buf, step,
                    K=k, ROUND=round_size, ROWS=rows,
                    BLOCK_R=block_r, num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[per_column](
            col_work, sigma_work,
            K=k, ROWS=rows, BLOCK_R=block_r, num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[per_column](
            sigma_work, order, s, k, num_warps=1
        )
        _blocked_jacobi_store_projected_kernel[
            (batch, k, triton.cdiv(m, block_r))
        ](
            col_work, sigma_work, order, u,
            K=k, ROWS=rows, OUT_ROWS=m,
            BLOCK_R=block_r, num_warps=row_warps,
        )

    mat_t = mat.transpose(1, 2).contiguous()
    v = _triton_bmm(mat_t, u, (batch, n, k))
    with torch_device_fn.device(dev):
        _renorm_projection_update_s_kernel[per_column](
            v, s,
            ROWS=n, K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, k),
        s.reshape(*lead, k),
        v.reshape(*lead, n, k),
    )
def _hier_block_jacobi_square_project_svd(input):
    """Thin SVD of square matrices using the hierarchical block Jacobi kernel.

    Columns are grouped into tiles of ``tile_b`` columns (4 when ``k == 512``,
    otherwise 2) and every launch of the pair kernel processes pairs of tiles.
    U and S come from the ranked work columns; V is the batched product of the
    transposed input with U, handed with S to
    ``_renorm_projection_update_s_kernel``.

    The shape precondition is checked before any device buffer is allocated,
    so unsupported inputs fail without touching the allocator.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``.

    Raises:
        NotImplementedError: via ``_unsupported_svd`` when the matrix is not
            square or ``k`` is not a multiple of the tile width.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)

    tile_b = 4 if k == 512 else 2
    if m != n or k % tile_b != 0:
        return _unsupported_svd(
            input,
            True,
            True,
            "Hierarchical block Jacobi supports square matrices with "
            "k divisible by two.",
        )

    a = input.contiguous().reshape(batch, m, n)
    a_work = torch.empty((batch, k, rows), dtype=torch.float32, device=input.device)
    s_work = torch.empty((batch, k), dtype=torch.float32, device=input.device)
    ranks = torch.empty((batch, k), dtype=torch.int32, device=input.device)
    u = torch.empty((batch, m, k), dtype=input.dtype, device=input.device)
    s = torch.empty((batch, k), dtype=input.dtype, device=input.device)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    block_count = k // tile_b
    # Tile pairing rounds need an even number of slots.
    round_blocks = block_count if block_count % 2 == 0 else block_count + 1
    half_round_blocks = round_blocks // 2
    sweep_count = 10 if k <= 256 else 12
    tile_cols = tile_b * 2
    with torch_device_fn.device(input.device):
        _cyclic_jacobi_init_a_kernel[(batch, k)](
            a,
            a_work,
            M=m,
            N=n,
            K=k,
            ROWS=rows,
            TALL=True,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        for _ in range(sweep_count):
            for step in range(round_blocks - 1):
                _hier_block_jacobi_pair_a_kernel[(batch, half_round_blocks)](
                    a_work,
                    step,
                    K=k,
                    K_BLOCKS=block_count,
                    ROUND_BLOCKS=round_blocks,
                    ROWS=rows,
                    TILE_B=tile_b,
                    TILE_COLS=tile_cols,
                    BLOCK_R=block_r,
                    LOCAL_SWEEPS=1,
                    num_warps=4,
                )
        _cyclic_jacobi_norm_kernel[(batch, k)](
            a_work,
            s_work,
            K=k,
            ROWS=rows,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[(batch, k)](
            s_work,
            ranks,
            s,
            k,
            num_warps=1,
        )
        _blocked_jacobi_store_projected_kernel[(batch, k, triton.cdiv(m, block_r))](
            a_work,
            s_work,
            ranks,
            u,
            K=k,
            ROWS=rows,
            OUT_ROWS=m,
            BLOCK_R=block_r,
            num_warps=row_warps,
        )

    a_t = a.transpose(1, 2).contiguous()
    v = _triton_bmm(a_t, u, (batch, n, k))
    with torch_device_fn.device(input.device):
        _renorm_projection_update_s_kernel[(batch, k)](
            v,
            s,
            ROWS=n,
            K=k,
            BLOCK_R=triton.next_power_of_2(n),
            num_warps=1 if n <= 64 else 4,
        )

    return (
        u.reshape(*input.shape[:-2], m, k),
        s.reshape(*input.shape[:-2], k),
        v.reshape(*input.shape[:-2], n, k),
    )
def _blocked_jacobi_singular_values(input):
    """Singular values only, computed with the Jacobi sweep kernels.

    No rotation buffers or factor outputs are allocated: the sweeps update the
    float32 work columns in place, the norm kernel fills the work values and
    the rank kernel writes S.

    Returns:
        S reshaped to ``(*batch, k)`` with ``k = min(m, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    rows = max(m, n)
    dev = input.device
    mat = input.contiguous().reshape(batch, m, n)

    col_work = torch.empty((batch, k, rows), dtype=torch.float32, device=dev)
    sigma_work = torch.empty((batch, k), dtype=torch.float32, device=dev)
    order = torch.empty((batch, k), dtype=torch.int32, device=dev)
    s = torch.empty((batch, k), dtype=input.dtype, device=dev)

    block_r = triton.next_power_of_2(rows)
    row_warps = 1 if block_r <= 64 else 4
    sweeps = 10 if k <= 256 else 14
    # Pairing rounds need an even slot count; odd k is padded by one.
    round_size = k + (k % 2)
    pair_grid = (batch, round_size // 2)
    per_column = (batch, k)

    with torch_device_fn.device(dev):
        _cyclic_jacobi_init_a_kernel[per_column](
            mat, col_work,
            M=m, N=n, K=k, ROWS=rows, TALL=m >= n,
            BLOCK_R=block_r, num_warps=row_warps,
        )
        for _ in range(sweeps):
            for step in range(round_size - 1):
                _blocked_jacobi_pair_svals_kernel[pair_grid](
                    col_work, step,
                    K=k, ROUND=round_size, ROWS=rows,
                    BLOCK_R=block_r, num_warps=row_warps,
                )
        _cyclic_jacobi_norm_kernel[per_column](
            col_work, sigma_work,
            K=k, ROWS=rows, BLOCK_R=block_r, num_warps=row_warps,
        )
        _blocked_jacobi_rank_kernel[per_column](
            sigma_work, order, s, k, num_warps=1
        )

    return s.reshape(*input.shape[:-2], k)
def _small4_square_svd(input):
    """SVD of batched 4x4 matrices with the dedicated small-matrix kernel.

    Each program of ``_small4_square_svd_kernel`` is given a group of 16 batch
    entries and runs with ``SWEEPS=4``.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, 4, 4)``, ``(*batch, 4)`` and
        ``(*batch, 4, 4)``.
    """
    batch, m, n = _svd_shape(input)
    dev, out_dtype = input.device, input.dtype
    mat = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, 4, 4), dtype=out_dtype, device=dev)
    s = torch.empty((batch, 4), dtype=out_dtype, device=dev)
    v = torch.empty((batch, 4, 4), dtype=out_dtype, device=dev)

    block_b = 16
    grid = (triton.cdiv(batch, block_b),)
    with torch_device_fn.device(dev):
        _small4_square_svd_kernel[grid](
            mat, u, s, v, BATCH=batch, BLOCK_B=block_b, SWEEPS=4, num_warps=1
        )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, 4, 4),
        s.reshape(*lead, 4),
        v.reshape(*lead, 4, 4),
    )
@libentry()
@triton.jit
def _rank1_svd_kernel(
    A,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    TALL: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # SVD of a single vector per batch entry, one program per entry.
    # TALL: reads column 0 of an M x N matrix (stride N between rows).
    # Otherwise: reads the first row (N contiguous elements).
    # In both cases S is the Euclidean norm, the long factor is the vector
    # divided by max(norm, eps) and the 1 x 1 factor is set to 1.0.
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_R)
    # 2**-23, the float32 machine epsilon; keeps the division finite for an
    # all-zero vector (whose output factor is then all zeros).
    eps = 1.1920928955078125e-7
    a_base = A + pid * M * N
    norm_sq = tl.full((), 0.0, dtype=tl.float32)

    if TALL:
        # Pass 1: accumulate the squared norm in BLOCK_R-sized chunks.
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        denom = tl.maximum(norm, eps)
        tl.store(S + pid, norm)
        tl.store(V + pid, 1.0)

        # Pass 2: reload the column and write the scaled vector to U.
        u_base = U + pid * M
        for base in range(0, M, BLOCK_R):
            rows = base + offsets
            mask = rows < M
            vals = tl.load(a_base + rows * N, mask=mask, other=0.0).to(tl.float32)
            tl.store(u_base + rows, vals / denom, mask=mask)
    else:
        # Pass 1: accumulate the squared norm in BLOCK_R-sized chunks.
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            norm_sq += tl.sum(vals * vals)

        norm = tl.sqrt(norm_sq)
        denom = tl.maximum(norm, eps)
        tl.store(S + pid, norm)
        tl.store(U + pid, 1.0)

        # Pass 2: reload the row and write the scaled vector to V.
        v_base = V + pid * N
        for base in range(0, N, BLOCK_R):
            cols = base + offsets
            mask = cols < N
            vals = tl.load(a_base + cols, mask=mask, other=0.0).to(tl.float32)
            tl.store(v_base + cols, vals / denom, mask=mask)
def _rank1_svd(input):
    """SVD driver for the vector case handled by ``_rank1_svd_kernel``.

    One program is launched per batch entry; nothing is launched for an empty
    batch. The block width is the next power of two of the longer side, capped
    at ``_RANK1_BLOCK_R_MAX``.

    Returns:
        Tuple ``(U, S, V)`` reshaped to ``(*batch, m, 1)``, ``(*batch, 1)`` and
        ``(*batch, n, 1)``.
    """
    batch, m, n = _svd_shape(input)
    dev, out_dtype = input.device, input.dtype
    mat = input.contiguous().reshape(batch, m, n)
    u = torch.empty((batch, m, 1), dtype=out_dtype, device=dev)
    s = torch.empty((batch, 1), dtype=out_dtype, device=dev)
    v = torch.empty((batch, n, 1), dtype=out_dtype, device=dev)

    if batch != 0:
        longest = max(m, n)
        if longest <= _RANK1_BLOCK_R_MAX:
            block_r = triton.next_power_of_2(longest)
        else:
            block_r = _RANK1_BLOCK_R_MAX
        warps = 1 if block_r <= 64 else 4
        with torch_device_fn.device(dev):
            _rank1_svd_kernel[(batch,)](
                mat, u, s, v, m, n,
                TALL=n == 1, BLOCK_R=block_r, num_warps=warps,
            )

    lead = input.shape[:-2]
    return (
        u.reshape(*lead, m, 1),
        s.reshape(*lead, 1),
        v.reshape(*lead, n, 1),
    )
@libentry()
@triton.jit
def _complex_to_real_embedding_kernel(
    A_RI,
    R,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Expands an M x N complex matrix, stored as interleaved (real, imag)
    # pairs in A_RI, into the 2M x 2N real block matrix
    #     [[ Re, -Im ],
    #      [ Im,  Re ]]
    # written row-major to R. One program per batch entry covers the whole
    # output with a single block, so BLOCK_SIZE must be at least 4 * M * N.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    total = 4 * M * N
    mask = offsets < total
    # Output coordinates in the 2M x 2N matrix.
    row = offsets // (2 * N)
    col = offsets - row * (2 * N)
    # Fold the coordinates back into the source M x N matrix.
    src_row = tl.where(row < M, row, row - M)
    src_col = tl.where(col < N, col, col - N)
    # comp selects the imaginary part (1) in the two off-diagonal quadrants
    # and the real part (0) in the diagonal ones.
    comp = tl.where((row < M) & (col >= N), 1, 0)
    comp = tl.where((row >= M) & (col < N), 1, comp)
    vals = tl.load(
        A_RI + batch * M * N * 2 + (src_row * N + src_col) * 2 + comp,
        mask=mask,
        other=0.0,
    )
    # Only the top-right quadrant is negated.
    sign = tl.where((row < M) & (col >= N), -1.0, 1.0)
    tl.store(R + batch * 4 * M * N + offsets, vals * sign, mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_factor_kernel(
    REAL_FACTOR,
    OUT_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Converts a (2 * ROWS) x REAL_K real factor into a ROWS x K complex
    # factor stored as interleaved (real, imag) pairs. Only every second
    # source column is used; the top ROWS rows supply the real parts and the
    # bottom ROWS rows the imaginary parts. One program per batch entry with
    # a single block, so BLOCK_SIZE must be at least ROWS * K.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < ROWS * K
    row = offsets // K
    col = offsets % K
    # Even source columns only.
    src_col = col * 2
    real = tl.load(
        REAL_FACTOR + batch * (2 * ROWS) * REAL_K + row * REAL_K + src_col,
        mask=mask,
        other=0.0,
    )
    imag = tl.load(
        REAL_FACTOR + batch * (2 * ROWS) * REAL_K + (ROWS + row) * REAL_K + src_col,
        mask=mask,
        other=0.0,
    )
    # Interleaved output: element i occupies slots 2 * i and 2 * i + 1.
    out_base = OUT_RI + batch * ROWS * K * 2 + offsets * 2
    tl.store(out_base, real, mask=mask)
    tl.store(out_base + 1, imag, mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_s_kernel(
    S_REAL,
    S,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Reduces REAL_K values per batch entry to K values by averaging adjacent
    # pairs: S[i] = 0.5 * (S_REAL[2 * i] + S_REAL[2 * i + 1]).
    # One program per batch entry; BLOCK_K must be at least K.
    batch = tl.program_id(0)
    cols = tl.arange(0, BLOCK_K)
    mask = cols < K
    src = cols * 2
    vals_a = tl.load(S_REAL + batch * REAL_K + src, mask=mask, other=0.0)
    vals_b = tl.load(S_REAL + batch * REAL_K + src + 1, mask=mask, other=0.0)
    tl.store(S + batch * K + cols, 0.5 * (vals_a + vals_b), mask=mask)
@libentry()
@triton.jit
def _complex_svd_pick_orthonormal_v_kernel(
    V_REAL,
    V_RI,
    ROWS: tl.constexpr,
    K: tl.constexpr,
    REAL_K: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Builds a ROWS x K complex factor from every second column of the
    # (2 * ROWS) x REAL_K real factor V_REAL (top half = real parts, bottom
    # half = imaginary parts), orthonormalizes its columns with complex
    # Gram-Schmidt and writes it to V_RI as interleaved (real, imag) pairs.
    # One program per batch entry holds the whole factor in registers, so
    # BLOCK_ROWS >= ROWS and BLOCK_K >= K are required.
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_K)
    row_mask = rows < ROWS
    col_mask = cols < K
    src_cols = cols * 2
    base = V_REAL + batch * (2 * ROWS) * REAL_K
    vr = tl.load(
        base + rows[:, None] * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )
    vi = tl.load(
        base + (ROWS + rows[:, None]) * REAL_K + src_cols[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0.0,
    )

    # NOTE(review): the column loop is unrolled for a fixed 16 columns, so
    # columns with index >= 16 are written out without orthonormalization.
    # Confirm that callers keep K <= 16.
    for c in tl.static_range(0, 16):
        cur_mask = c < K
        # Extract column c by masking and summing along the column axis.
        cur_r = tl.sum(tl.where(cols[None, :] == c, vr, 0.0), axis=1)
        cur_i = tl.sum(tl.where(cols[None, :] == c, vi, 0.0), axis=1)
        for p in tl.static_range(0, c):
            prev_r = tl.sum(tl.where(cols[None, :] == p, vr, 0.0), axis=1)
            prev_i = tl.sum(tl.where(cols[None, :] == p, vi, 0.0), axis=1)
            # coeff = <prev, cur> = sum(conj(prev) * cur)
            coeff_r = tl.sum(
                tl.where(row_mask, prev_r * cur_r + prev_i * cur_i, 0.0), axis=0
            )
            coeff_i = tl.sum(
                tl.where(row_mask, prev_r * cur_i - prev_i * cur_r, 0.0), axis=0
            )
            # cur -= prev * coeff (complex multiplication)
            cur_r -= prev_r * coeff_r - prev_i * coeff_i
            cur_i -= prev_r * coeff_i + prev_i * coeff_r
        norm_sq = tl.sum(tl.where(row_mask, cur_r * cur_r + cur_i * cur_i, 0.0), axis=0)
        # The floor keeps rsqrt finite for a zero column.
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur_r *= inv_norm
        cur_i *= inv_norm
        # Write the column back only when c is a real column (c < K).
        vr = tl.where((cols[None, :] == c) & cur_mask, cur_r[:, None], vr)
        vi = tl.where((cols[None, :] == c) & cur_mask, cur_i[:, None], vi)

    out_base = V_RI + batch * ROWS * K * 2
    offsets = rows[:, None] * K + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets * 2, vr, mask=mask)
    tl.store(out_base + offsets * 2 + 1, vi, mask=mask)
@libentry()
@triton.jit
def _complex_svd_project_u_kernel(
    A_RI,
    V_RI,
    S,
    U_RI,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Computes U = (A @ V) / S column-wise for complex A (M x N) and V (N x K),
    # all stored as interleaved (real, imag) pairs. Columns whose singular
    # value is not above 1e-20 are written as zeros. One program per batch
    # entry with one lane per output element, so BLOCK_SIZE >= M * K.
    batch = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M * K
    row = offsets // K
    col = offsets % K

    acc_r = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    acc_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # Fully unrolled reduction over the shared dimension N.
    for j in tl.static_range(0, N):
        a_base = A_RI + batch * M * N * 2 + (row * N + j) * 2
        v_base = V_RI + batch * N * K * 2 + (j * K + col) * 2
        ar = tl.load(a_base, mask=mask, other=0.0)
        ai = tl.load(a_base + 1, mask=mask, other=0.0)
        vr = tl.load(v_base, mask=mask, other=0.0)
        vi = tl.load(v_base + 1, mask=mask, other=0.0)
        # Complex multiply-accumulate: acc += a * v
        acc_r += ar * vr - ai * vi
        acc_i += ar * vi + ai * vr

    s = tl.load(S + batch * K + col, mask=mask, other=1.0)
    inv_s = tl.where(s > 1.0e-20, 1.0 / s, 0.0)
    out_base = U_RI + batch * M * K * 2 + offsets * 2
    tl.store(out_base, acc_r * inv_s, mask=mask)
    tl.store(out_base + 1, acc_i * inv_s, mask=mask)
def _complex_svd_via_real_embedding(input):
    """Thin SVD of a complex input via its real block-matrix embedding.

    Steps: expand the input into a ``(2m, 2n)`` float32 matrix, run the real
    ``svd`` on it, average adjacent pairs of the real singular values into S,
    build an orthonormal complex V from the real V, and project U from the
    input, V and S.

    Returns:
        Tuple ``(U, S, V)`` shaped ``(*batch, m, k)``, ``(*batch, k)`` and
        ``(*batch, n, k)`` with ``k = min(m, n)``; S is float32.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    dev = input.device
    lead = input.shape[:-2]
    as_real_pairs = tensor_wrapper.TypedPtr.reinterpret_tensor

    packed = input.contiguous()
    a_ri = as_real_pairs(packed, packed.dtype.to_real())
    embedded = torch.empty((batch, 2 * m, 2 * n), dtype=torch.float32, device=dev)
    embed_block = triton.next_power_of_2(4 * m * n)
    per_batch = (batch,)
    with torch_device_fn.device(dev):
        _complex_to_real_embedding_kernel[per_batch](
            a_ri, embedded, M=m, N=n, BLOCK_SIZE=embed_block, num_warps=1
        )

    _u_real, s_real, v_real = svd(embedded, some=True, compute_uv=True)

    s = torch.empty((batch, k), dtype=torch.float32, device=dev)
    u = torch.empty((*lead, m, k), dtype=input.dtype, device=dev)
    v = torch.empty((*lead, n, k), dtype=input.dtype, device=dev)
    u_ri = as_real_pairs(u, u.dtype.to_real())
    v_ri = as_real_pairs(v, v.dtype.to_real())
    block_k = triton.next_power_of_2(k)
    with torch_device_fn.device(dev):
        _complex_svd_pick_s_kernel[per_batch](
            s_real, s, K=k, REAL_K=2 * k, BLOCK_K=block_k, num_warps=1
        )
        _complex_svd_pick_orthonormal_v_kernel[per_batch](
            v_real, v_ri,
            ROWS=n, K=k, REAL_K=2 * k,
            BLOCK_ROWS=triton.next_power_of_2(n),
            BLOCK_K=block_k,
            num_warps=1,
        )
        _complex_svd_project_u_kernel[per_batch](
            a_ri, v_ri, s, u_ri,
            M=m, N=n, K=k,
            BLOCK_SIZE=triton.next_power_of_2(m * k),
            num_warps=1,
        )

    return (
        u,
        s.reshape(*lead, k),
        v,
    )
3201def _complex_svd_cpu_fallback(input, some=True, compute_uv=True):
3202 cpu_u, cpu_s, cpu_v = torch.svd(input.cpu(), some=some, compute_uv=compute_uv)
3203 return (
3204 cpu_u.to(input.device),
3205 cpu_s.to(input.device),
3206 cpu_v.to(input.device),
3207 )
def _gram_svd(input):
    """Placeholder for the Gram-matrix path; always reports it as unsupported."""
    return _unsupported_svd(input, some=True, compute_uv=True)
@libentry()
@triton.jit
def _gram16_finalize_kernel(
    A,
    EVALS,
    EVECS,
    U,
    S,
    V,
    M: tl.constexpr,
    N: tl.constexpr,
    ROWS: tl.constexpr,
    TALL: tl.constexpr,
    EVECS_BATCH_STRIDE: tl.constexpr,
    EVECS_ROW_STRIDE: tl.constexpr,
    EVECS_COL_STRIDE: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    # Turns a 16 x 16 eigen-decomposition (EVALS, EVECS) into SVD factors of A
    # for problems whose short side is 16. Launch grid: axis 0 = batch entry,
    # axis 1 = block of BLOCK_R rows of the long factor.
    #   S            = sqrt(max(eigenvalue, 0))
    #   long factor  = (A or A^T) @ eigenvectors / S   (zero where S <= eps)
    #   short factor = the eigenvector matrix itself
    # Eigenpairs are read in reversed index order (15 - col); presumably the
    # eigen-solver returns ascending eigenvalues, making S descending --
    # confirm against the producer of EVALS.
    batch = tl.program_id(0)
    row_block = tl.program_id(1)
    rows = row_block * BLOCK_R + tl.arange(0, BLOCK_R)
    cols = tl.arange(0, 16)
    src_cols = 15 - cols
    row_mask = rows < ROWS
    eps = 1.0e-20

    vals = tl.load(EVALS + batch * 16 + src_cols)
    # Negative eigenvalues are clamped to zero before the square root.
    s_vals = tl.sqrt(tl.maximum(vals, 0.0))
    inv_s = tl.where(s_vals > eps, 1.0 / s_vals, 0.0)

    acc = tl.zeros((BLOCK_R, 16), dtype=tl.float32)
    a_base = A + batch * M * N
    e_base = EVECS + batch * EVECS_BATCH_STRIDE
    # Rank-1 updates over the 16-wide shared dimension.
    for k in tl.static_range(0, 16):
        eig = tl.load(e_base + k * EVECS_ROW_STRIDE + src_cols * EVECS_COL_STRIDE)
        if TALL:
            # Column k of A: element [rows, k].
            a_vals = tl.load(
                a_base + rows * N + k,
                mask=row_mask,
                other=0.0,
            )
        else:
            # Row k of A: element [k, rows], i.e. column k of A^T.
            a_vals = tl.load(
                a_base + k * N + rows,
                mask=row_mask,
                other=0.0,
            )
        acc += a_vals[:, None] * eig[None, :]

    projected = acc * inv_s[None, :]
    if TALL:
        tl.store(
            U + batch * M * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )
    else:
        tl.store(
            V + batch * N * 16 + rows[:, None] * 16 + cols[None, :],
            projected,
            mask=row_mask[:, None],
        )

    # S and the 16 x 16 factor are written once per batch entry, by the
    # program that handles the first row block.
    head_mask = row_block == 0
    tl.store(S + batch * 16 + cols, s_vals, mask=head_mask)

    basis_rows = tl.arange(0, 16)
    basis_cols = tl.arange(0, 16)
    basis_src_cols = 15 - basis_cols
    basis = tl.load(
        e_base
        + basis_rows[:, None] * EVECS_ROW_STRIDE
        + basis_src_cols[None, :] * EVECS_COL_STRIDE
    )
    if TALL:
        tl.store(
            V + batch * N * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )
    else:
        tl.store(
            U + batch * M * 16 + basis_rows[:, None] * 16 + basis_cols[None, :],
            basis,
            mask=head_mask,
        )
def _gram16_svd(input):
    """Placeholder for the 16-wide Gram path; always reports it as unsupported."""
    return _unsupported_svd(input, some=True, compute_uv=True)
def _large_native_svd(input):
    """Dispatch a large input to the first native solver that accepts it.

    Candidates are tried in priority order: hierarchical block Jacobi, blocked
    Jacobi with projection, plain blocked Jacobi. When none of the predicates
    accepts the input, ``_bidiagonal_qr_dqds_svd`` is called.
    """
    candidates = (
        (
            _can_use_hier_block_square_project_kernel,
            _hier_block_jacobi_square_project_svd,
        ),
        (_can_use_blocked_square_project_kernel, _blocked_jacobi_square_project_svd),
        (_can_use_blocked_jacobi_kernel, _blocked_jacobi_svd),
    )
    for accepts, solve in candidates:
        if accepts(input, True, True):
            return solve(input)
    return _bidiagonal_qr_dqds_svd(input)
def _bidiagonal_qr_dqds_svd(input):
    """Placeholder for the bidiagonalization path; always reports unsupported."""
    reason = (
        "The k > 512 blocked-bidiagonalization plus QR/DQDS path is reserved "
        "for the next native large-matrix solver stage."
    )
    return _unsupported_svd(input, some=True, compute_uv=True, reason=reason)
3326def _empty_svd_result(input, some=True, compute_uv=True):
3327 _, m, n = _svd_shape(input)
3328 k = min(m, n)
3329 u_cols = k if compute_uv and some else m
3330 v_cols = k if compute_uv and some else n
3331 u = torch.empty(
3332 (*input.shape[:-2], m, u_cols), dtype=input.dtype, device=input.device
3333 )
3334 s = torch.empty((*input.shape[:-2], k), dtype=input.dtype, device=input.device)
3335 v = torch.empty(
3336 (*input.shape[:-2], n, v_cols), dtype=input.dtype, device=input.device
3337 )
3338 return u, s, v
@libentry()
@triton.jit
def _complete_svd_factor_kernel(
    THIN,
    FULL,
    ROWS: tl.constexpr,
    THIN_COLS: tl.constexpr,
    FULL_COLS: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    # Extends a thin ROWS x THIN_COLS factor to ROWS x FULL_COLS: the extra
    # columns start as identity columns and all columns are then run through
    # Gram-Schmidt in column order. One program per batch entry holds the
    # whole matrix in registers, so BLOCK_ROWS >= ROWS and
    # BLOCK_COLS >= FULL_COLS are required.
    batch = tl.program_id(0)
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_COLS)
    row_mask = rows < ROWS
    col_mask = cols < FULL_COLS
    vals = tl.load(
        THIN + batch * ROWS * THIN_COLS + rows[:, None] * THIN_COLS + cols[None, :],
        mask=row_mask[:, None] & (cols[None, :] < THIN_COLS),
        other=0.0,
    )
    # Columns beyond the thin factor are seeded with unit vectors.
    identity = tl.where(rows[:, None] == cols[None, :], 1.0, 0.0)
    vals = tl.where(cols[None, :] < THIN_COLS, vals, identity)

    # NOTE(review): the column loop is unrolled for a fixed 64 columns, so
    # columns with index >= 64 keep their seed value without being
    # orthonormalized. Confirm that callers keep FULL_COLS <= 64.
    for c in tl.static_range(0, 64):
        cur_mask = c < FULL_COLS
        # Extract column c by masking and summing along the column axis.
        cur = tl.sum(tl.where(cols[None, :] == c, vals, 0.0), axis=1)
        for p in tl.static_range(0, c):
            prev = tl.sum(tl.where(cols[None, :] == p, vals, 0.0), axis=1)
            coeff = tl.sum(tl.where(row_mask, prev * cur, 0.0), axis=0)
            cur -= prev * coeff
        norm_sq = tl.sum(tl.where(row_mask, cur * cur, 0.0), axis=0)
        # The floor keeps rsqrt finite for a zero column.
        inv_norm = tl.rsqrt(tl.maximum(norm_sq, 1.0e-20))
        cur *= inv_norm
        # Write the column back only when c is a real column.
        vals = tl.where((cols[None, :] == c) & cur_mask, cur[:, None], vals)

    out_base = FULL + batch * ROWS * FULL_COLS
    offsets = rows[:, None] * FULL_COLS + cols[None, :]
    mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_base + offsets, vals, mask=mask)
def _low_precision_svd_via_float32(input, some=True, compute_uv=True):
    """Run ``svd`` on a float32 copy and cast the factors back to ``input.dtype``."""
    target = input.dtype
    wide_u, wide_s, wide_v = svd(
        input.to(torch.float32), some=some, compute_uv=compute_uv
    )
    return wide_u.to(target), wide_s.to(target), wide_v.to(target)
def _some_false_svd_via_thin(input):
    """Build the ``some=False`` SVD by completing the thin factors.

    The thin decomposition is computed with ``svd(some=True)``; U is then
    extended to ``m x m`` and V to ``n x n`` by ``_complete_svd_factor_kernel``.

    Args:
        input: Tensor of shape ``(..., m, n)``.

    Returns:
        A ``(U, S, V)`` tuple with U of shape ``(..., m, m)``, S as returned by
        the thin decomposition, and V of shape ``(..., n, n)``.
    """
    batch, m, n = _svd_shape(input)
    k = min(m, n)
    thin_u, s, thin_v = svd(input, some=True, compute_uv=True)

    # The completion kernel reads its source with dense row-major offsets, so
    # force that layout here; this is a no-op when the factor is already dense.
    thin_u = thin_u.contiguous()
    thin_v = thin_v.contiguous()

    batch_dims = input.shape[:-2]
    u = torch.empty((*batch_dims, m, m), dtype=input.dtype, device=input.device)
    v = torch.empty((*batch_dims, n, n), dtype=input.dtype, device=input.device)

    with torch_device_fn.device(input.device):
        # Same launch for both factors: one program per batch matrix, with a
        # square power-of-two tile covering the full rows x rows output.
        for thin, full, rows in ((thin_u, u, m), (thin_v, v, n)):
            block = triton.next_power_of_2(rows)
            _complete_svd_factor_kernel[(batch,)](
                thin,
                full,
                ROWS=rows,
                THIN_COLS=k,
                FULL_COLS=rows,
                BLOCK_ROWS=block,
                BLOCK_COLS=block,
                num_warps=4,
            )
    return u, s, v
def _compute_uv_false_result(input, s):
    """Package singular values with placeholder U and V for ``compute_uv=False``.

    Args:
        input: Tensor of shape ``(..., m, n)``; supplies the batch dimensions,
            dtype and device of the placeholders.
        s: Singular values, returned unchanged.

    Returns:
        A ``(U, S, V)`` tuple where U is ``(..., m, m)`` and V is
        ``(..., n, n)``, both zero-filled.
    """
    _, m, n = _svd_shape(input)
    batch_dims = input.shape[:-2]
    # torch.svd documents zero-filled U and V when compute_uv is False, so the
    # placeholders must be initialized rather than left as raw allocations.
    u = torch.zeros((*batch_dims, m, m), dtype=input.dtype, device=input.device)
    v = torch.zeros((*batch_dims, n, n), dtype=input.dtype, device=input.device)
    return u, s, v
def _singular_values_only(input):
    """Pick the singular-values-only routine that matches the matrix shape.

    With ``k = min(m, n)`` and ``largest = max(m, n)``:

    * ``k == 2`` and ``largest <= _RANK2_BLOCK_R_MAX`` -> rank-2 routine
    * ``k <= 16`` and ``largest <= 1024`` -> small Jacobi routine
    * ``16 < k <= 512`` and ``largest <= 1024`` -> blocked Jacobi routine

    Any other shape is passed to ``_unsupported_svd``.
    """
    _, m, n = _svd_shape(input)
    small_dim = min(m, n)
    large_dim = max(m, n)

    if small_dim == 2 and large_dim <= _RANK2_BLOCK_R_MAX:
        return _rank2_singular_values(input)

    if large_dim <= 1024:
        if small_dim <= 16:
            return _small_jacobi_singular_values(input)
        # Reaching here means small_dim > 16.
        if small_dim <= 512:
            return _blocked_jacobi_singular_values(input)

    return _unsupported_svd(input, True, False)
3438def _should_use_gram16(batch, m, n):
3439 return batch >= 16 and min(m, n) == 16 and max(m, n) <= 1024
3442def _should_use_gram(batch, m, n):
3443 k = min(m, n)
3444 largest = max(m, n)
3445 if k <= 32:
3446 return True
3447 if batch <= 4 and m == n and m <= 256:
3448 return True
3449 if (m, n) == (1024, 1024):
3450 return True
3451 if batch >= 128 and k <= 64 and largest <= 1024:
3452 return False
3453 return False
def svd(input, some=True, compute_uv=True):
    """Compute the singular value decomposition of ``input``.

    The call is routed, in order, through these checks:

    1. complex64 ``ptpu`` matrices (non-empty, sides <= 16, ``some`` and
       ``compute_uv`` both true) go to ``_complex_svd_cpu_fallback``;
    2. inputs accepted by ``_is_low_precision_cuda_matrix`` are computed in
       float32 and cast back;
    3. float32 matrices with an empty trailing dimension get
       ``_empty_svd_result``;
    4. inputs accepted by ``_can_use_singular_values_only`` compute S only;
    5. float32 ``some=False`` / ``compute_uv=True`` matrices with sides <= 64
       are built by completing the thin factors;
    6. everything else must be a float32 matrix with ``some=True`` and is
       dispatched to one of the native kernels by shape.

    Args:
        input: Tensor of shape ``(..., m, n)``.
        some: When true, request the reduced (thin) factors.
        compute_uv: When false, only singular values are requested.

    Returns:
        ``SVDResult(U, S, V)``.

    Raises:
        NotImplementedError: For unsupported dtype/device/flag/shape
            combinations, and when a native kernel path raises a
            ``RuntimeError`` (the original message is appended as the reason).
    """
    logger.debug("GEMS SVD")
    if (
        input.device.type == "ptpu"
        and input.dtype == torch.complex64
        and some
        and compute_uv
        and input.dim() >= 2
        and 0 not in input.shape[-2:]
        and max(input.shape[-2:]) <= 16
    ):
        return SVDResult(*_complex_svd_cpu_fallback(input, some, compute_uv))
    if _is_low_precision_cuda_matrix(input):
        return SVDResult(*_low_precision_svd_via_float32(input, some, compute_uv))
    if _is_float32_cuda_matrix(input) and 0 in input.shape[-2:]:
        return SVDResult(*_empty_svd_result(input, some, compute_uv))
    if _can_use_singular_values_only(input, some, compute_uv):
        return SVDResult(*_compute_uv_false_result(input, _singular_values_only(input)))
    if (
        _is_float32_cuda_matrix(input)
        and not some
        and compute_uv
        and max(input.shape[-2:]) <= 64
    ):
        return SVDResult(*_some_false_svd_via_thin(input))
    if not _is_float32_cuda_matrix(input) or not some:
        return SVDResult(*_unsupported_svd(input, some, compute_uv))

    batch, m, n = _svd_shape(input)
    k = min(m, n)
    try:
        if k == 1:
            return SVDResult(*_rank1_svd(input))
        if k == 2 and max(m, n) <= _RANK2_BLOCK_R_MAX:
            return SVDResult(*_rank2_svd(input))
        if k == 4 and m == 4 and n == 4 and batch >= 16:
            return SVDResult(*_small4_square_svd(input))
        if _can_use_tall_wide_gram_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_gram_jacobi_svd(input))
        # Batched k == 16 problems with short sides skip the small Jacobi
        # kernel and fall through to the later candidates.
        use_batched_cyclic16 = k == 16 and batch >= 8 and max(m, n) <= 64
        if (
            _can_use_small_jacobi_kernel(input, some, compute_uv)
            and not use_batched_cyclic16
        ):
            return SVDResult(*_small_jacobi_svd(input))
        if _can_use_tsqr_cholesky_kernel(input, some, compute_uv):
            return SVDResult(*_tsqr_cholesky_svd(input))
        if _can_use_projected_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_projected_jacobi_svd(input))
        if _can_use_cyclic_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_cyclic_jacobi_svd(input))
        if _can_use_hier_block_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_hier_block_jacobi_square_project_svd(input))
        if _can_use_blocked_square_project_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_square_project_svd(input))
        if _can_use_blocked_jacobi_kernel(input, some, compute_uv):
            return SVDResult(*_blocked_jacobi_svd(input))
        if _should_use_gram16(batch, m, n):
            return SVDResult(*_gram16_svd(input))
        if _should_use_gram(batch, m, n):
            return SVDResult(*_gram_svd(input))
        return SVDResult(*_large_native_svd(input))
    except NotImplementedError:
        # NotImplementedError subclasses RuntimeError; re-raise it untouched so
        # a specific "unsupported" reason raised by a kernel path is not
        # overwritten by the generic handler below.
        raise
    except RuntimeError as err:
        # Keep reporting kernel failures as NotImplementedError, but carry the
        # underlying message so the failure remains diagnosable.
        return SVDResult(
            *_unsupported_svd(
                input,
                some,
                compute_uv,
                reason=f"Native kernel path failed: {err}",
            )
        )