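"""Source code for linalg.svd_symeig.

SVD of a real symmetric matrix computed through its eigendecomposition, with a
custom backward pass. The implementation is taken from a reference whose link was
lost in this copy (presumably Liao et al., arXiv:1903.09650 -- confirm). That
reference follows the derivation given in a second source whose link was also
lost (presumably M. Giles, Oxford NA report 08/01 -- confirm).
"""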
import torch

def safe_inverse(x, epsilon=1E-12):
    r"""Regularised reciprocal :math:`x / (x^2 + \epsilon)`.

    Behaves like ``1/x`` when ``x**2`` is much larger than ``epsilon``, but
    stays finite (and goes to 0) as ``x -> 0`` instead of dividing by zero.

    :param x: value(s) to invert (Python number or tensor)
    :param epsilon: regulariser added to ``x**2`` in the denominator
    :return: ``x / (x**2 + epsilon)``
    """
    denominator = x**2 + epsilon
    return x / denominator

class SVDSYMEIG(torch.autograd.Function):
    r"""SVD of a real symmetric matrix computed from its eigendecomposition.

    For symmetric :math:`M = M^T = UDU^T`, the SVD is :math:`M = USV^T` with
    :math:`S = |D|` (sorted in descending order) and :math:`V = U\,\mathrm{sgn}(D)`.
    Use via ``SVDSYMEIG.apply(A)``.
    """

    @staticmethod
    def forward(ctx, A):
        r"""
        :param A: square symmetric matrix
        :type A: torch.tensor
        :return: left singular vectors U, singular values S, and right singular vectors V
        :rtype: torch.tensor, torch.tensor, torch.tensor

        Computes SVD of a matrix M, where M is symmetric :math:`M=M^T`,
        through symmetric decomposition :math:`M= UDU^T`.
        """
        # The eigensolver checks that A is square, but it does NOT check
        # symmetry: only the upper triangle is read (as torch.symeig's default
        # did), so symmetry of A is assumed.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            D, U = torch.linalg.eigh(A, UPLO="U")
        else:
            # Fallback for PyTorch < 1.8, which lacks torch.linalg.eigh.
            D, U = torch.symeig(A, eigenvectors=True)

        # Eigenpairs come back in ascending order of eigenvalue; reorder them
        # by descending |eigenvalue| so that S holds the singular values.
        S, p = torch.sort(torch.abs(D), descending=True)
        U = U[:, p]

        # M = U D U^T = U |D| sgn(D) U^T = U S V^T with V = U sgn(D).
        # Obtaining V by sign flips avoids computing V = M U S^{-1}, which would
        # divide by vanishingly small singular values.
        # sgn(0) is taken as +1 so V stays orthogonal for exactly-zero
        # eigenvalues (torch.sign would zero out those columns of V).
        D_sorted = D[p]
        ones = torch.ones_like(D_sorted)
        signs = torch.where(D_sorted < 0, -ones, ones)
        V = U * signs  # scales column j of U by signs[j], i.e. U @ diag(signs)

        ctx.save_for_backward(U, S, V)
        return U, S, V

    @staticmethod
    def backward(ctx, dU, dS, dV):
        r"""Back-propagate gradients of (U, S, V) to the input matrix A.

        :param dU: gradient w.r.t. the left singular vectors U
        :param dS: gradient w.r.t. the singular values S
        :param dV: gradient w.r.t. the right singular vectors V
        :return: gradient w.r.t. A
        """
        U, S, V = ctx.saved_tensors
        Vt = V.t()
        Ut = U.t()
        M = U.size(0)
        N = V.size(0)
        NS = len(S)

        # F[i, j] ~ 1/(S_j - S_i) and G[i, j] ~ 1/(S_j + S_i), regularised by
        # safe_inverse so (near-)degenerate pairs stay finite; the diagonal
        # (i == j) terms are excluded.
        F = safe_inverse(S - S[:, None])
        F.diagonal().fill_(0)
        G = safe_inverse(S + S[:, None])
        G.diagonal().fill_(0)

        UdU = Ut @ dU
        VdV = Vt @ dV
        Su = (F + G) * (UdU - UdU.t()) / 2
        Sv = (F - G) * (VdV - VdV.t()) / 2

        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
        # Extra terms for non-square (truncated) factors. For the square input
        # handled by forward, M == N == NS and these branches do not run.
        if M > NS:
            eye_m = torch.eye(M, dtype=dU.dtype, device=dU.device)
            dA = dA + (eye_m - U @ Ut) @ (dU / S) @ Vt
        if N > NS:
            eye_n = torch.eye(N, dtype=dU.dtype, device=dU.device)
            dA = dA + (U / S) @ dV.t() @ (eye_n - V @ Vt)
        return dA
def test_SVDSYMEIG_random():
    """Reconstruction and gradcheck of SVDSYMEIG on a random symmetric matrix."""
    m = 50
    M = torch.rand(m, m, dtype=torch.float64)
    M = 0.5 * (M + M.t())

    U, S, V = SVDSYMEIG.apply(M)
    assert torch.norm(M - U @ torch.diag(S) @ V.t()) < S[0] * (m**2) * 1e-14

    # since we always assume matrix M to be symmetric, the finite difference
    # perturbations should be symmetric as well
    M.requires_grad_(True)

    def force_sym_SVD(M):
        M = 0.5 * (M + M.t())
        return SVDSYMEIG.apply(M)

    assert torch.autograd.gradcheck(force_sym_SVD, M, eps=1e-6, atol=1e-4)


def test_SVDSYMEIG_su2sym():
    """Reconstruction and gradcheck on matrices built from SU(2)-symmetric tensors.

    Requires the project's ``su2sym`` package to be importable.
    """
    import su2sym.sym_ten_parser as tenSU2

    # Available D: [3,5,7,9]
    for D in [3, 5, 7]:
        su2sym_t = tenSU2.import_sym_tensors(2, D, "A_1", dtype=torch.float64)
        c = torch.rand(len(su2sym_t), dtype=torch.float64)
        ts = torch.stack([tensor for meta, tensor in su2sym_t])
        a = torch.einsum('i,ipuldr->puldr', c, ts)
        D2 = D**2
        M = torch.einsum('mijef,mijab->eafb', (a, a)).contiguous().view(D2, D2)

        U, S, V = SVDSYMEIG.apply(M)
        assert torch.norm(M - U @ torch.diag(S) @ V.t()) < S[0] * (M.size()[0]**2) * 1e-14

        M.requires_grad_(True)

        def force_sym_SVD(M):
            M = 0.5 * (M + M.t())
            return SVDSYMEIG.apply(M)

        assert torch.autograd.gradcheck(force_sym_SVD, M, eps=1e-6, atol=1e-4)


def test_SVDSYMEIG_3x3degenerate():
    """Gradcheck of U for a 3x3 matrix whose eigenvalues are 1, -1/2, -1/2.

    The two equal |eigenvalues| make the singular vectors non-unique.
    """
    M = torch.zeros((3, 3), dtype=torch.float64)
    M[0, 1] = M[0, 2] = M[1, 2] = 1.
    M = 0.5 * (M + M.t())

    U, S, V = SVDSYMEIG.apply(M)
    assert torch.norm(M - U @ torch.diag(S) @ V.t()) < S[0] * (M.size()[0]**2) * 1e-14

    M.requires_grad_(True)

    def force_sym_SVD(M):
        M = 0.5 * (M + M.t())
        U, S, V = SVDSYMEIG.apply(M)
        return U

    assert torch.autograd.gradcheck(force_sym_SVD, M, eps=1e-6, atol=1e-4)


def test_SVDSYMEIG_rank_deficient():
    """Reconstruction and gradcheck of U for a symmetric matrix of rank m - r."""
    m = 50
    r = 10
    M = torch.rand((m, m), dtype=torch.float64)
    M = M + M.t()

    # torch.symeig was removed from recent PyTorch; prefer torch.linalg.eigh.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        D, U = torch.linalg.eigh(M)
    else:
        D, U = torch.symeig(M, eigenvectors=True)
    # Eigenvalues are ascending, so this zeroes the r largest ones.
    D[-r:] = 0
    M = U @ torch.diag(D) @ U.t()

    U, S, V = SVDSYMEIG.apply(M)
    assert torch.norm(M - U @ torch.diag(S) @ V.t()) < S[0] * (M.size()[0]**2) * 1e-14

    M.requires_grad_(True)

    def force_sym_SVD(M):
        M = 0.5 * (M + M.t())
        U, S, V = SVDSYMEIG.apply(M)
        return U

    assert torch.autograd.gradcheck(force_sym_SVD, M, eps=1e-6, atol=1e-4)


if __name__ == '__main__':
    import os
    import sys
    # Make the parent directory importable (e.g. for the su2sym package).
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    test_SVDSYMEIG_random()
    test_SVDSYMEIG_rank_deficient()
    # NOTE(review): presumably disabled because degenerate singular values make
    # the gradient of U ill-defined (3x3degenerate) and because su2sym needs the
    # project's su2sym package -- confirm before re-enabling.
    # test_SVDSYMEIG_3x3degenerate()
    # test_SVDSYMEIG_su2sym()