Spectral decompositions

The implementation of the adjoint (backward) functions is based on https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

SVD

class linalg.svd_gesdd.SVDGESDD(*args, **kwargs)
static backward(self, gu, gsigma, gv)
Parameters
  • gu (torch.Tensor) – gradient on U

  • gsigma (torch.Tensor) – gradient on S

  • gv (torch.Tensor) – gradient on V

Returns

gradient

Return type

torch.Tensor

Computes the backward gradient for SVD, adapted from https://github.com/pytorch/pytorch/blob/v1.10.2/torch/csrc/autograd/FunctionsManual.cpp

For complex-valued input there is an additional term, see

The backward is regularized following

using

\[S_i/(S^2_i-S^2_j) = (F_{ij}+G_{ij})/2\ \ \textrm{and}\ \ S_j/(S^2_i-S^2_j) = (F_{ij}-G_{ij})/2\]

where

\[F_{ij}=1/(S_i-S_j),\ G_{ij}=1/(S_i+S_j)\]
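The two identities can be checked numerically for \(i \neq j\). The NumPy sketch below uses arbitrary example singular values and is not taken from the library. It also hints at why the split is useful: as \(S_i \to S_j\) only \(F_{ij}\) diverges, while \(G_{ij}\) stays bounded for positive singular values, so presumably only \(F\) needs regularizing.

```python
import numpy as np

S = np.array([3.0, 2.0, 0.5])                 # example singular values, distinct and positive
i, j = np.where(~np.eye(len(S), dtype=bool))  # all index pairs with i != j

F = 1.0 / (S[i] - S[j])                       # F_ij = 1/(S_i - S_j)
G = 1.0 / (S[i] + S[j])                       # G_ij = 1/(S_i + S_j)
lhs_i = S[i] / (S[i]**2 - S[j]**2)            # S_i/(S_i^2 - S_j^2)
lhs_j = S[j] / (S[i]**2 - S[j]**2)            # S_j/(S_i^2 - S_j^2)

print(np.allclose(lhs_i, (F + G) / 2))        # True
print(np.allclose(lhs_j, (F - G) / 2))        # True

# Near-degenerate pair: F blows up, while G stays close to 1/(2 S_i).
S_i, S_j = 1.0, 1.0 + 1e-9
print(1.0 / (S_i - S_j), 1.0 / (S_i + S_j))
```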
static forward(self, A, cutoff, diagnostics)
Parameters
  • A (torch.Tensor) – rank-2 tensor

  • cutoff (torch.Tensor) – cutoff used by the backward function

  • diagnostics (dict) – optional dictionary for debugging purposes

Returns

U, S, V

Return type

torch.Tensor, torch.Tensor, torch.Tensor

Computes the singular value decomposition of the matrix \(A = USV^\dagger\).
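A minimal NumPy sketch of the same factorization (not the torch.autograd.Function itself): numpy.linalg.svd is backed by LAPACK's gesdd driver and returns \(V^\dagger\), so the hypothetical helper svd_gesdd below converts it to V to match the U, S, V return convention.

```python
import numpy as np

def svd_gesdd(A):
    """Hypothetical helper: thin SVD A = U diag(S) V^dagger, returning V rather than V^dagger."""
    # numpy.linalg.svd calls LAPACK's *gesdd routine under the hood.
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    V = Vh.conj().T  # convert V^dagger to V
    return U, S, V

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3)) + 1j * rng.standard_normal((5, 3))
U, S, V = svd_gesdd(A)
print(np.allclose(A, U @ np.diag(S) @ V.conj().T))  # True
```

The singular values S come back real, non-negative and in descending order.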

Partial SVD

class linalg.svd_arnoldi.SVDARNOLDI(*args, **kwargs)
static backward(self, dU, dS, dV)

The backward is not implemented.

static forward(self, M, k)
Parameters
  • M (torch.Tensor) – square matrix \(N \times N\)

  • k (int) – desired rank (must be smaller than \(N\))

Returns

leading k left singular vectors U, singular values S, and right singular vectors V

Return type

torch.Tensor, torch.Tensor, torch.Tensor

Note: depends on scipy

Returns the leading k singular triples of a matrix M by computing the rank-k partial eigendecomposition \(H = UDU^\dagger\) of the Hermitian matrix \(H = MM^\dagger\), whose eigenvalues are the squared singular values, \(D = S^2\). The partial eigendecomposition is done with the Arnoldi method.
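The approach can be sketched with SciPy's scipy.sparse.linalg.eigsh, which calls ARPACK's implicitly restarted Lanczos method (the Hermitian specialization of Arnoldi). This is an illustrative NumPy/SciPy sketch, not the SVDARNOLDI code: the helper name partial_svd_arnoldi is hypothetical, and recovering V from \(M^\dagger U = VS\) assumes all k singular values are nonzero.

```python
import numpy as np
from scipy.sparse.linalg import eigsh

def partial_svd_arnoldi(M, k):
    """Hypothetical helper: leading k singular triples of square M via a partial ED of H = M M^dagger."""
    H = M @ M.conj().T                  # Hermitian; its eigenvalues are D = S^2 >= 0
    D, U = eigsh(H, k=k, which="LM")    # ARPACK (Lanczos); requires k < N
    order = np.argsort(D)[::-1]         # sort explicitly into descending order
    D, U = D[order], U[:, order]
    S = np.sqrt(np.clip(D, 0.0, None))  # clip tiny negative round-off before sqrt
    V = (M.conj().T @ U) / S            # M^dagger U = V S, so V = M^dagger U S^-1 (assumes S > 0)
    return U, S, V

# Usage: a 30 x 30 matrix with known singular values.
rng = np.random.default_rng(0)
N = 30
s_true = np.concatenate([[10.0, 8.0, 6.0, 4.0], np.linspace(1.0, 0.1, N - 4)])
U0, _ = np.linalg.qr(rng.standard_normal((N, N)))
V0, _ = np.linalg.qr(rng.standard_normal((N, N)))
M = U0 @ np.diag(s_true) @ V0.T
U, S, V = partial_svd_arnoldi(M, k=4)
print(S)  # approximately 10, 8, 6, 4
```

Forming \(MM^\dagger\) squares the condition number, so small singular values lose relative accuracy; the leading ones, which are what this routine targets, are much less affected.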

Randomized SVD

class linalg.svd_rsvd.RSVD(*args, **kwargs)
static backward(self, dU, dS, dV)
Parameters
  • dU (torch.Tensor) – gradient on U

  • dS (torch.Tensor) – gradient on S

  • dV (torch.Tensor) – gradient on V

Returns

gradient

Return type

torch.Tensor

The backward is evaluated as in linalg.svd_gesdd.SVDGESDD.backward() for a real input matrix.
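For reference, the real-input backward can be sketched with the standard closed-form adjoint of the thin SVD \(A = USV^T\), where the off-diagonal factor is \(1/(S_j^2 - S_i^2)\). This NumPy sketch is not the SVDGESDD implementation: the function names and the eps broadening are hypothetical. The gradient is checked against central finite differences of a test loss that is invariant under the sign ambiguity of singular vectors.

```python
import numpy as np

def svd_backward_real(U, S, V, gu, gsigma, gv, eps=0.0):
    """Gradient on real A for the thin SVD A = U diag(S) V^T.

    eps is a hypothetical broadening of 1/(S_j^2 - S_i^2); eps=0 is the exact
    formula and then assumes distinct, nonzero singular values.
    """
    diff = S[None, :]**2 - S[:, None]**2     # entry (i, j) = S_j^2 - S_i^2
    denom = diff**2 + eps
    np.fill_diagonal(denom, 1.0)             # diagonal of diff is 0; avoid 0/0
    F = diff / denom                         # 1/(S_j^2 - S_i^2) off-diagonal, 0 on it
    J = U.T @ gu
    K = V.T @ gv
    Sd, Sinv = np.diag(S), np.diag(1.0 / S)
    inner = (F * (J - J.T)) @ Sd + np.diag(gsigma) + Sd @ (F * (K - K.T))
    gA = U @ inner @ V.T
    # Contributions of gu and gv orthogonal to span(U) and span(V).
    gA += (np.eye(U.shape[0]) - U @ U.T) @ gu @ Sinv @ V.T
    gA += U @ Sinv @ gv.T @ (np.eye(V.shape[0]) - V @ V.T)
    return gA

def fd_check(m, n, seed=0, h=1e-6):
    """Hypothetical test: compare with central finite differences of a sign-invariant loss."""
    rng = np.random.default_rng(seed)
    k = min(m, n)
    A = rng.standard_normal((m, n))
    w, c, e = rng.standard_normal((3, k))
    B = rng.standard_normal((m, m))
    B = B + B.T                              # symmetric weights acting on U
    C = rng.standard_normal((n, n))
    C = C + C.T                              # symmetric weights acting on V

    def loss(X):
        U, S, Vh = np.linalg.svd(X, full_matrices=False)
        V = Vh.T
        # u_j^T B u_j and v_j^T C v_j do not change when u_j, v_j flip sign
        return (w @ S + c @ np.einsum("ij,ik,kj->j", U, B, U)
                + e @ np.einsum("ij,ik,kj->j", V, C, V))

    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    V = Vh.T
    gA = svd_backward_real(U, S, V, 2 * (B @ U) * c, w, 2 * (C @ V) * e)
    dA = rng.standard_normal((m, n))
    fd = (loss(A + h * dA) - loss(A - h * dA)) / (2 * h)
    return fd, np.sum(gA * dA)

print(fd_check(6, 4))  # tall matrix: exercises the (I - U U^T) term
print(fd_check(4, 6))  # wide matrix: exercises the (I - V V^T) term
```

Each call prints a pair of numbers that should agree to many digits.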

static forward(self, M, k, p=20, q=2, s=1, vnum=1)
Parameters
  • M (torch.Tensor) – real matrix

  • k (int) – desired rank

  • p (int) – oversampling rank; the total sampled rank is k+p

  • q (int) – number of matrix-vector multiplications in the power scheme

  • s (int) – re-orthogonalization

Returns

approximate leading k left singular vectors U, singular values S, and right singular vectors V

Return type

torch.Tensor, torch.Tensor, torch.Tensor

Performs an approximate truncated SVD \(M \approx USV^T\) of a real matrix M using randomized sampling. Based on the implementation in https://arxiv.org/abs/1502.05366
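The randomized scheme can be sketched in NumPy as a Gaussian range finder, followed by power iterations and an exact SVD of the small projected matrix. This is a sketch under assumptions, not the RSVD code: the name rsvd is hypothetical, it re-orthogonalizes after every product, and it omits the s and vnum parameters, whose exact behavior is not documented here.

```python
import numpy as np

def rsvd(M, k, p=20, q=2, seed=0):
    """Hypothetical sketch: approximate rank-k SVD M ~ U diag(S) V^T of a real matrix."""
    rng = np.random.default_rng(seed)
    m, n = M.shape
    l = min(k + p, m, n)                                  # total sampled rank, capped by the shape
    Q, _ = np.linalg.qr(M @ rng.standard_normal((n, l)))  # orthonormal basis of a range sample
    for _ in range(q):                                    # power scheme sharpens the spectrum
        Z, _ = np.linalg.qr(M.T @ Q)                      # re-orthogonalize after every product
        Q, _ = np.linalg.qr(M @ Z)
    B = Q.T @ M                                           # small l x n projection of M
    Ub, S, Vh = np.linalg.svd(B, full_matrices=False)
    U = Q @ Ub                                            # lift left vectors back to R^m
    return U[:, :k], S[:k], Vh[:k].T

# Usage: 60 x 40 matrix with five dominant singular values and a small tail.
rng = np.random.default_rng(1)
m, n = 60, 40
s_true = np.concatenate([[10.0, 8.0, 6.0, 4.0, 2.0], np.full(n - 5, 1e-3)])
U0, _ = np.linalg.qr(rng.standard_normal((m, n)))
V0, _ = np.linalg.qr(rng.standard_normal((n, n)))
M = U0 @ np.diag(s_true) @ V0.T
U, S, V = rsvd(M, k=5)
print(S)  # approximately 10, 8, 6, 4, 2
```

Oversampling (p) and power iterations (q) trade extra matrix products for accuracy when the spectrum decays slowly.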

Symmetric Eigendecomposition

class linalg.eig_sym.SYMEIG(*args, **kwargs)
static backward(self, dD, dU)
Parameters
  • dD (torch.Tensor) – gradient on D

  • dU (torch.Tensor) – gradient on U

Returns

gradient

Return type

torch.Tensor

Computes the backward gradient for the eigendecomposition (ED) of a symmetric matrix, with regularization of \(F_{ij}=1/(D_i - D_j)\)
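A NumPy sketch of this gradient for a real symmetric \(A = UDU^T\), using \(F_{ij} = 1/(D_i - D_j)\) as above. With this sign convention for F, the off-diagonal term enters with a minus sign. The optional eps broadening \(F_{ij} = x/(x^2+\epsilon)\), \(x = D_i - D_j\), is a hypothetical stand-in for the library's regularization. The result is checked against finite differences along a symmetric perturbation.

```python
import numpy as np

def symeig_backward(D, U, dD, dU, eps=0.0):
    """Hypothetical sketch: gradient on symmetric A = U diag(D) U^T from gradients dD, dU.

    F_ij = (D_i - D_j) / ((D_i - D_j)^2 + eps); eps=0 gives the exact 1/(D_i - D_j).
    """
    x = D[:, None] - D[None, :]       # x_ij = D_i - D_j
    denom = x**2 + eps
    np.fill_diagonal(denom, 1.0)      # x_ii = 0, so the diagonal of F stays 0
    F = x / denom
    gA = U @ (np.diag(dD) - F * (U.T @ dU)) @ U.T
    return 0.5 * (gA + gA.T)          # project onto symmetric matrices

# Finite-difference check with a loss that is invariant to eigenvector signs.
rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
A = A + A.T
w = rng.standard_normal(n)
c = rng.standard_normal(n)
B = rng.standard_normal((n, n))
B = B + B.T

def loss(X):
    D, U = np.linalg.eigh(X)
    return w @ D + c @ np.einsum("ij,ik,kj->j", U, B, U)  # sum_j c_j u_j^T B u_j

D, U = np.linalg.eigh(A)
gA = symeig_backward(D, U, w, 2 * (B @ U) * c)           # dL/du_j = 2 c_j B u_j
dA = rng.standard_normal((n, n))
dA = dA + dA.T                                           # symmetric perturbation
h = 1e-6
fd = (loss(A + h * dA) - loss(A - h * dA)) / (2 * h)
print(fd, np.sum(gA * dA))                               # the two numbers should agree closely
```

As \(D_i \to D_j\) the entries of F diverge, which is what a regularization of F is meant to tame.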

static forward(self, A, ad_decomp_reg)
Parameters

A (torch.Tensor) – square symmetric matrix

Returns

eigenvalues D, eigenvectors U

Return type

torch.Tensor, torch.Tensor

Computes the symmetric decomposition \(A = UDU^\dagger\).
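For comparison, the same decomposition in NumPy: numpy.linalg.eigh returns real eigenvalues in ascending order together with a unitary U. Whether SYMEIG uses the same ordering is not stated here, so treat the ordering as an assumption of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
A = X + X.conj().T                                  # Hermitian test matrix, A = A^dagger
D, U = np.linalg.eigh(A)                            # real eigenvalues (ascending), unitary U
print(np.allclose(A, U @ np.diag(D) @ U.conj().T))  # True
```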

Partial diagonalization

class linalg.eig_arnoldi.SYMARNOLDI(*args, **kwargs)
static backward(self, dD, dU)

The backward is not implemented.

static forward(self, M, k)
Parameters
  • M (torch.Tensor) – square symmetric matrix \(N \times N\)

  • k (int) – desired rank (must be smaller than \(N\))

Returns

leading k eigenvalues D and the corresponding eigenvectors U

Return type

torch.Tensor, torch.Tensor

Note: depends on scipy

Returns the leading k eigenpairs of a symmetric (Hermitian) matrix \(M = M^\dagger\) by computing the symmetric decomposition \(M = UDU^\dagger\) up to rank k. The partial eigendecomposition is done with the Arnoldi method.
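A minimal sketch with scipy.sparse.linalg.eigsh, assuming "leading" means largest in magnitude (which="LM"); the helper name partial_symeig is hypothetical. Like the k < N restriction stated for this function, eigsh requires k to be smaller than the matrix dimension.

```python
import numpy as np
from scipy.sparse.linalg import eigsh

def partial_symeig(M, k):
    """Hypothetical helper: leading k eigenpairs of symmetric M, assuming 'leading' = largest |D|."""
    D, U = eigsh(M, k=k, which="LM")     # ARPACK Lanczos, the Hermitian case of Arnoldi; k < N
    order = np.argsort(np.abs(D))[::-1]  # sort by decreasing magnitude
    return D[order], U[:, order]

# Usage: symmetric 30 x 30 matrix with a known spectrum.
rng = np.random.default_rng(2)
N = 30
lam = np.concatenate([[-9.0, 7.0, 5.0, 3.0], np.linspace(-1.0, 1.0, N - 4)])
Q, _ = np.linalg.qr(rng.standard_normal((N, N)))
M = Q @ np.diag(lam) @ Q.T
M = 0.5 * (M + M.T)                      # remove round-off asymmetry
D, U = partial_symeig(M, k=4)
print(D)  # approximately -9, 7, 5, 3
```

Selecting by magnitude keeps large negative eigenvalues; which="LA" would instead select the largest algebraic ones.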