# Source code for linalg.svd_gesdd

import torch
import numpy as np
import warnings
from config import _torch_version_check

def safe_inverse(x, epsilon=1E-12):
    # smooth (Lorentzian-broadened) inverse: x/(x^2+eps) ~ 1/x for |x| >> sqrt(eps),
    # and -> 0 smoothly as x -> 0
    return x/(x**2 + epsilon)

def safe_inverse_2(x, epsilon):
    # hard-cutoff inverse, operating in-place: entries with |x| < eps are set
    # to +inf so that their reciprocal becomes exactly 0
    x[abs(x) < epsilon] = float('inf')
    return x.pow(-1)
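
# Illustration (not part of the original module): how the two regularized
# inverses differ. safe_inverse damps small values smoothly, while
# safe_inverse_2 zeroes out everything below the cutoff. Note safe_inverse_2
# modifies its argument in-place, hence the .clone() at its call sites below.
# The helper name _demo_safe_inverse is hypothetical.
def _demo_safe_inverse():
    x = torch.tensor([1.0, 1e-6, 1e-14], dtype=torch.float64)
    print(safe_inverse(x))                   # ~[1.0, 5.0e5, 1.0e-2]: smooth crossover near sqrt(eps)
    print(safe_inverse_2(x.clone(), 1e-12))  # [1.0, 1.0e6, 0.0]: hard cutoff below eps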

class SVDGESDD_legacy(torch.autograd.Function):
    @staticmethod
    def forward(self, A, cutoff, diagnostics):
        U, S, V = torch.svd(A)
        cutoff= torch.as_tensor(cutoff, dtype=S.dtype, device=S.device)
        self.save_for_backward(U, S, V, cutoff)
        return U, S, V
    
    @staticmethod
    def backward(self, dU, dS, dV):
        U, S, V, cutoff = self.saved_tensors
        
        Vt = V.t()
        Ut = U.t()
        M = U.size(0)
        N = V.size(0)
        NS = S.size(0)

        F = (S - S[:, None])
        # mask0= F==0
        F = safe_inverse(F, cutoff)
        F.diagonal().fill_(0)
        # F[mask0]= 0

        G = (S + S[:, None])
        G = safe_inverse(G)
        G.diagonal().fill_(0)

        UdU = Ut @ dU
        VdV = Vt @ dV

        # F_ij= 1/(S_i - S_j)
        # G_ij= 1/(S_i + S_j)
        # (F+G)_ij= 1/(S_i - S_j) + 1/(S_i + S_j) = ((S_i + S_j) + (S_i - S_j))/((S_i - S_j)*(S_i + S_j))
        #         = 2S_i / (S^2_i - S^2_j)
        # (F-G)_ij= 1/(S_i - S_j) - 1/(S_i + S_j) = ((S_i + S_j) - (S_i - S_j))/((S_i - S_j)*(S_i + S_j))
        #         = 2S_j / (S^2_i - S^2_j)
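        #
        # worked check of the identities above (illustrative, not part of the
        # original source): for S_i=2, S_j=1 one has F_ij=1/(2-1)=1 and
        # G_ij=1/(2+1)=1/3, so (F+G)/2 = 2/3 = S_i/(S^2_i - S^2_j) and
        # (F-G)/2 = 1/3 = S_j/(S^2_i - S^2_j), matching the regularized
        # combinations used below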
        Su = (F+G)*(UdU-UdU.t())/2
        Sv = (F-G)*(VdV-VdV.t())/2

        dA = U @ (Su + Sv + torch.diag(dS)) @ Vt
        if (M>NS):
            dA = dA + (torch.eye(M, dtype=dU.dtype, device=dU.device) - U@Ut) @ (dU*safe_inverse(S)) @ Vt 
        if (N>NS):
            dA = dA + (U*safe_inverse(S)) @ dV.t() @ (torch.eye(N, dtype=dU.dtype, device=dU.device) - V@Vt)

        return dA, None, None
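
# Usage sketch (illustrative, not part of the original module), assuming a
# generic float64 matrix with a non-degenerate spectrum; the helper name
# _demo_svdgesdd_legacy is hypothetical.
def _demo_svdgesdd_legacy():
    eps= torch.as_tensor(1.0e-12, dtype=torch.float64)
    A= torch.rand(6, 4, dtype=torch.float64, requires_grad=True)
    U, S, V= SVDGESDD_legacy.apply(A, eps, None)
    # any scalar function of U, S, V backpropagates through the custom backward
    S.sum().backward()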

class SVDGESDD(torch.autograd.Function):
    if _torch_version_check("1.8.1"):
        @staticmethod
        def forward(self, A, cutoff, diagnostics):
            r"""
            :param A: rank-2 tensor
            :type A: torch.Tensor
            :param cutoff: cutoff for backward function
            :type cutoff: torch.Tensor
            :param diagnostics: optional dictionary for debugging purposes
            :type diagnostics: dict
            :return: U, S, V
            :rtype: torch.Tensor, torch.Tensor, torch.Tensor

            Computes SVD decomposition of matrix :math:`A = USV^\dagger`.
            """
            # A = U @ diag(S) @ Vh
            U, S, Vh = torch.linalg.svd(A)
            V= Vh.transpose(-2,-1).conj()
            self.save_for_backward(U, S, V, cutoff)
            self.diagnostics= diagnostics
            return U, S, V
    else:
        @staticmethod
        def forward(self, A, cutoff, diagnostics):
            r"""
            :param A: rank-2 tensor
            :type A: torch.Tensor
            :param cutoff: cutoff for backward function
            :type cutoff: torch.Tensor
            :param diagnostics: optional dictionary for debugging purposes
            :type diagnostics: dict
            :return: U, S, V
            :rtype: torch.Tensor, torch.Tensor, torch.Tensor

            Computes SVD decomposition of matrix :math:`A = USV^\dagger`.
            """
            U, S, V = torch.svd(A)
            self.diagnostics= diagnostics
            self.save_for_backward(U, S, V, cutoff)
            return U, S, V
    @staticmethod
    def v1_10_Fonly_backward(self, gu, gsigma, gv):
        # Adapted from
        # https://github.com/pytorch/pytorch/blob/v1.10.2/torch/csrc/autograd/FunctionsManual.cpp
        #
        # TORCH_CHECK(compute_uv,
        #   "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
        #   "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
        diagnostics= self.diagnostics
        u, sigma, v, eps = self.saved_tensors
        m= u.size(-2)  # first dim of original tensor A = u sigma v^\dag
        n= v.size(-2)  # second dim of A
        k= sigma.size(0)
        sigma_scale= sigma[0]  # scale the regularization by the largest singular value

        if (u.size(-2)!=u.size(-1)) or (v.size(-2)!=v.size(-1)):
            # We ignore the free subspace here because possible base vectors cancel
            # each other, e.g., both -v and +v are valid base for a dimension.
            # Don't assume behavior of any particular implementation of svd.
            u = u.narrow(-1, 0, k)
            v = v.narrow(-1, 0, k)
            if gu is not None:
                gu = gu.narrow(-1, 0, k)
            if gv is not None:
                gv = gv.narrow(-1, 0, k)
        vh= v.conj().transpose(-2,-1)

        if gsigma is not None:
            # computes u @ diag(gsigma) @ vh
            sigma_term = (u * gsigma.unsqueeze(-2)) @ vh
        else:
            sigma_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        # in case that there are no gu and gv, we can avoid the series of kernel
        # calls below
        if (gu is None) and (gv is None):
            if diagnostics is not None:
                print(f"{diagnostics} {sigma_term.abs().max()} {sigma.max()}")
            return sigma_term, None, None

        sigma_inv= safe_inverse_2(sigma.clone(), sigma_scale*eps)
        sigma_sq= sigma.pow(2)
        F= sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1)
        # F_ij = 1/(S^2_i - S^2_j)
        # // The following two lines invert values of F, and fill the diagonal with 0s.
        # // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
        # // first to prevent nan from appearing in backward of this function.
        F.diagonal(0,-2,-1).fill_(float('inf'))
        F= safe_inverse_2(F, sigma_scale*eps)

        uh= u.conj().transpose(-2,-1)
        if gu is not None:
            guh = gu.conj().transpose(-2, -1)
            u_term = u @ (F.mul( uh @ gu - guh @ u) * sigma.unsqueeze(-2))
            if m > k:
                # projection operator onto subspace orthogonal to span(U) defined as I - UU^H
                proj_on_ortho_u = -u @ uh
                proj_on_ortho_u.diagonal(0, -2, -1).add_(1)
                u_term = u_term + proj_on_ortho_u @ (gu * sigma_inv.unsqueeze(-2))
            u_term = u_term @ vh
        else:
            u_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        if gv is not None:
            gvh = gv.conj().transpose(-2, -1)
            v_term = sigma.unsqueeze(-1) * (F.mul(vh @ gv - gvh @ v) @ vh)
            if n > k:
                # projection operator onto subspace orthogonal to span(V) defined as I - VV^H
                proj_on_v_ortho = -v @ vh
                proj_on_v_ortho.diagonal(0, -2, -1).add_(1)
                v_term = v_term + sigma_inv.unsqueeze(-1) * (gvh @ proj_on_v_ortho)
            v_term = u @ v_term
        else:
            v_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        # // for complex-valued input there is an additional term
        # // https://giggleliu.github.io/2019/04/02/einsumbp.html
        # // https://arxiv.org/abs/1909.02659
        dA= u_term + sigma_term + v_term
        if u.is_complex() or v.is_complex():
            L= (uh @ gu).diagonal(0,-2,-1)
            L.real.zero_()
            L.imag.mul_(sigma_inv)
            imag_term= (u * L.unsqueeze(-2)) @ vh
            dA= dA + imag_term

        if diagnostics is not None:
            print(f"{diagnostics} {dA.abs().max()} {sigma.max()}")
        return dA, None, None
    @staticmethod
    def backward(self, gu, gsigma, gv):
        r"""
        :param gu: gradient on U
        :type gu: torch.Tensor
        :param gsigma: gradient on S
        :type gsigma: torch.Tensor
        :param gv: gradient on V
        :type gv: torch.Tensor
        :return: gradient
        :rtype: torch.Tensor

        Computes backward gradient for SVD, adapted from
        https://github.com/pytorch/pytorch/blob/v1.10.2/torch/csrc/autograd/FunctionsManual.cpp

        For complex-valued input there is an additional term, see

            * https://giggleliu.github.io/2019/04/02/einsumbp.html
            * https://arxiv.org/abs/1909.02659

        The backward is regularized following

            * https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/adlib/svd.py
            * https://arxiv.org/abs/1903.09650

        using

        .. math::
            S_i/(S^2_i-S^2_j) = (F_{ij}+G_{ij})/2\ \ \textrm{and}\ \ S_j/(S^2_i-S^2_j) = (F_{ij}-G_{ij})/2

        where

        .. math::
            F_{ij}=1/(S_i-S_j),\ G_{ij}=1/(S_i+S_j)
        """
        # TORCH_CHECK(compute_uv,
        #   "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
        #   "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
        diagnostics= self.diagnostics
        u, sigma, v, eps = self.saved_tensors
        m= u.size(0)  # first dim of original tensor A = u sigma v^\dag
        n= v.size(0)  # second dim of A
        k= sigma.size(0)
        sigma_scale= sigma[0]  # scale the regularization by the largest singular value

        if (u.size(-2)!=u.size(-1)) or (v.size(-2)!=v.size(-1)):
            # We ignore the free subspace here because possible base vectors cancel
            # each other, e.g., both -v and +v are valid base for a dimension.
            # Don't assume behavior of any particular implementation of svd.
            u = u.narrow(-1, 0, k)
            v = v.narrow(-1, 0, k)
            if gu is not None:
                gu = gu.narrow(-1, 0, k)
            if gv is not None:
                gv = gv.narrow(-1, 0, k)
        vh= v.conj().transpose(-2,-1)

        if gsigma is not None:
            # computes u @ diag(gsigma) @ vh
            sigma_term = (u * gsigma.unsqueeze(-2)) @ vh
        else:
            sigma_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        # in case that there are no gu and gv, we can avoid the series of kernel
        # calls below
        if (gu is None) and (gv is None):
            if diagnostics is not None:
                print(f"{diagnostics} {sigma_term.abs().max()} {sigma.max()}")
            return sigma_term, None, None

        sigma_inv= safe_inverse_2(sigma.clone(), sigma_scale*eps)
        F = sigma.unsqueeze(-2) - sigma.unsqueeze(-1)
        F = safe_inverse(F, sigma_scale*eps)
        F.diagonal(0,-2,-1).fill_(0)
        G = sigma.unsqueeze(-2) + sigma.unsqueeze(-1)
        G = safe_inverse(G, sigma_scale*eps)
        G.diagonal(0,-2,-1).fill_(0)

        uh= u.conj().transpose(-2,-1)
        if gu is not None:
            guh = gu.conj().transpose(-2, -1)
            u_term = u @ ( (F+G).mul( uh @ gu - guh @ u) ) * 0.5
            if m > k:
                # projection operator onto subspace orthogonal to span(U) defined as I - UU^H
                proj_on_ortho_u = -u @ uh
                proj_on_ortho_u.diagonal(0, -2, -1).add_(1)
                u_term = u_term + proj_on_ortho_u @ (gu * sigma_inv.unsqueeze(-2))
            u_term = u_term @ vh
        else:
            u_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        if gv is not None:
            gvh = gv.conj().transpose(-2, -1)
            v_term = ( (F-G).mul(vh @ gv - gvh @ v) ) @ vh * 0.5
            if n > k:
                # projection operator onto subspace orthogonal to span(V) defined as I - VV^H
                proj_on_v_ortho = -v @ vh
                proj_on_v_ortho.diagonal(0, -2, -1).add_(1)
                v_term = v_term + sigma_inv.unsqueeze(-1) * (gvh @ proj_on_v_ortho)
            v_term = u @ v_term
        else:
            v_term = torch.zeros(m,n,dtype=u.dtype,device=u.device)

        # // for complex-valued input there is an additional term
        # // https://giggleliu.github.io/2019/04/02/einsumbp.html
        # // https://arxiv.org/abs/1909.02659
        dA= u_term + sigma_term + v_term
        if u.is_complex() or v.is_complex():
            L= (uh @ gu).diagonal(0,-2,-1)
            L.real.zero_()
            L.imag.mul_(sigma_inv)
            imag_term= (u * L.unsqueeze(-2)) @ vh
            dA= dA + imag_term

        if diagnostics is not None:
            print(f"{diagnostics} {dA.abs().max()} {sigma.max()}")
        return dA, None, None
    # From https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/FunctionsManual.cpp
    # commit 5375b2e (v1.11.+)
    #
    # Tensor svd_backward(const Tensor& gU,
    #                     const Tensor& gS,
    #                     const Tensor& gVh,
    #                     const Tensor& U,
    #                     const Tensor& S,
    #                     const Tensor& Vh) {
    #   at::NoTF32Guard disable_tf32;
    @staticmethod
    def v1_11_backward(self, gU, gS, gVh):
        U, S, Vh, ad_decomp_reg = self.saved_tensors
        diagnostics = self.diagnostics

        # // Throughout both the real and complex case we assume A has distinct singular values.
        # // Furthermore, if A is rectangular or complex, we assume it's full-rank.
        # //
        # // The real case (A \in R)
        # // See e.g. https://j-towns.github.io/papers/svd-derivative.pdf
        # //
        # // Denote by skew(X) = X - X^T, and by A o B the coordinatewise product, then
        # // if m == n
        # //   gA = U [(skew(U^T gU) / E)S + S(skew(V^T gV) / E) + I o gS ]V^T
        # // where E_{jk} = S_k^2 - S_j^2 if j != k and 1 otherwise
        # //
        # // if m > n
        # //   gA = [term in m == n] + (I_m - UU^T)gU S^{-1} V^T
        # // if m < n
        # //   gA = [term in m == n] + U S^{-1} (gV)^T (I_n - VV^T)
        # //
        # // The complex case (A \in C)
        # // This one is trickier because the svd is not locally unique.
        # // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL, S, VL) is
        # // another valid SVD decomposition of A as
        # //   A = ULS(VL)^H = ULSL^{-1}V^H = USV^H,
        # // since L, S and L^{-1} commute, since they are all diagonal.
        # //
        # // Assume wlog that n >= k in what follows, as otherwise we could reason about A^H.
        # // Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex Stiefel manifold.
        # // What this invariance means is that the svd decomposition is not a map
        # //   svd: C^{n x k} -> St_k(C^n) x R^n x St_k(C^k)
        # // (where St_k(C^k) is simply the unitary group U(k)) but a map
        # //   svd: C^{n x k} -> M x R^n
        # // where M is the manifold given by quotienting St_k(C^n) x U(n) by the action (U, V) -> (UL, VL)
        # // with L as above.
        # // Note that M is a manifold, because the action is free and proper (as U(1)^k \iso (S^1)^k is compact).
        # // For this reason, pi : St_k(C^n) x U(n) -> M forms a principal bundle.
        # //
        # // To think about M, consider the case k = 1. Then, we have the bundle
        # //   pi : St_1(C^n) x U(1) -> M
        # // Now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere of dimension 2n-1 in C^n \iso R^{2n}
        # //   S^{2n-1} = { z \in C^n | z^H z = 1}.
        # // Then, in this case, we're quotienting out U(1) completely, so we get that
        # //   pi : S^{2n-1} x U(1) -> CP(n-1)
        # // where CP(n-1) is the complex projective space of dimension n-1.
        # // In other words, M is just the complex projective space, and pi is (pretty similar to)
        # // the usual principal bundle from S^{2n-1} to CP(n-1).
        # // The case k > 1 is the same, but requiring a linear independence condition between the
        # // vectors from the different S^{2n-1} or CP(n-1).
        # //
        # // Note that this is a U(1)^k-bundle. In plain words, this means that the fibres of this bundle,
        # // i.e. pi^{-1}(x) for x \in M, are isomorphic to U(1) x ... x U(1).
        # // This is obvious as, if pi(U, V) = x,
        # //   pi^{-1}(x) = {(U diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k}
        # //              = {(U diag(z), V diag(z)) | z \in U(1)^k}
        # // since U(1) = {z \in C | |z| = 1}.
        # //
        # // The big issue here is that M with its induced metric is not locally isometric to St_k(C^n) x U(k).
        # // [The why is rather technical, but you can see that the horizontal distribution is not involutive,
        # // and hence not integrable due to Frobenius' theorem]
        # // What this means in plain words is that, no matter how we choose to return the U and V from the
        # // SVD, we won't be able to simply differentiate wrt. U and V and call it a day.
        # // An example of a case where we can do this is when performing an eigendecomposition on a real
        # // matrix that happens to have real eigendecomposition. In this case, even though you can rescale
        # // the eigenvectors by any real number, you can choose them of norm 1 and call it a day.
        # // In the eigenvector case, we are using that you can isometrically embed S^{n-1} into R^n.
        # // In the svd case, we need to work with the "quotient manifold" M explicitly, which is
        # // slightly more technically challenging.
        # //
        # // Since the columns of U and V are not uniquely defined, but are representatives of certain
        # // classes of equivalence which represent elements of M, the user may not depend on the particular
        # // representative that we return from the SVD. In particular, if the loss function depends on U
        # // or V, it must be invariant under the transformation (U, V) -> (UL, VL) with
        # // L = diag(e^{i\theta}), for every \theta \in R^k.
        # // In more geometrical terms, this means that the loss function should be constant on the fibres,
        # // or, in other words, the gradient along the fibres should be zero.
        # // We may see this by checking that the gradients as elements in the tangent space
        # // T_{(U, V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the map
        # // (U, V) -> (UL, VL), we see that the space tangent to the fibres is given by
        # //   Vert_{(U, V)}(St(n,k) x U(k)) = { i[U, V]diag(\theta) | \theta in R^k}
        # // where [U, V] denotes the vertical concatenation of U and V to form an (n+k, k) matrix.
        # // Then, solving
        # //   <i[U,V]diag(\theta), [S, T]> = 0 for two matrices S, T \in T_{(U, V)}(St(n,k) x U(k))
        # // where <A, B> = Re tr(A^H B) is the canonical (real) inner product in C^{n x k},
        # // we get that the function is invariant under the action of U(1)^k iff
        # //   Im(diag(U^H gU + V^H gV)) = 0
        # //
        # // Using this in the derivation for the forward AD, one sees that, with the notation from those notes,
        # // and writing sym(X) = X + X^H, the forward AD for SVD in the complex case is given by
        # //   dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
        # //   if m > n
        # //     dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
        # //   dS = Re(diag(dP))
        # //   dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
        # //   if m < n
        # //     dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
        # //   dVh = dV^H
        # // with
        # //   dP = U^H dA V
        # //   dX = dP - dS
        # //   E_{jk} = S_k^2 - S_j^2 if j != k
        # //            1             otherwise
        # //
        # // Similarly, writing skew(X) = X - X^H,
        # // the adjoint wrt. the canonical metric is given by
        # // if m == n
        # //   gA = U [((skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S + S ((skew(V^H gV) / E)) + I o gS] V^H
        # // if m > n
        # //   gA = [term in m == n] + (I_m - UU^H)gU S^{-1} V^H
        # // if m < n
        # //   gA = [term in m == n] + U S^{-1} (gV)^H (I_n - VV^H)
        # // where we have used that Im(diag(U^H gU)) = - Im(diag(V^H gV)) to group the diagonal imaginary
        # // terms into one that just depends on U^H gU.

        # // Checks compute_uv=true
        # TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2);
        # // Trivial case
        # if (!gS.defined() && !gU.defined() && !gVh.defined()) {
        #   return {};
        # }
        m = U.size(-2)
        n = Vh.size(-1)

        # // Optimisation for svdvals: gA = U @ diag(gS) @ Vh
        if (gU is None) and (gVh is None):
            gA = U @ (gS.unsqueeze(-1) * Vh) if m >= n else (U * gS.unsqueeze(-2)) @ Vh
            if diagnostics is not None:
                print(f"{diagnostics} {gA.size()} {gA.abs().max()} {S.max()}")
            return gA, None, None, None

        # // At this point, at least one of gU, gVh is defined
        is_complex = U.is_complex()

        def skew(A):
            return A - A.transpose(-2, -1).conj()

        # const auto UhgU = gU.defined() ? skew(at::matmul(U.mH(), gU)) : Tensor{};
        # const auto VhgV = gVh.defined() ? skew(at::matmul(Vh, gVh.mH())) : Tensor{};
        UhgU= skew( U.transpose(-2, -1).conj()@gU )
        VhgV= skew( Vh@gVh.transpose(-2, -1).conj() )

        # // Check for the invariance of the loss function, i.e.
        # //   Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
        # if (is_complex) {
        #   const auto imdiag_UhgU = gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1))
        #                                         : at::zeros_like(S);
        #   const auto imdiag_VhgV = gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1))
        #                                          : at::zeros_like(S);
        #   // Rather lax atol and rtol, as we don't want false positives
        #   TORCH_CHECK(at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2),
        #     "svd_backward: The singular vectors in the complex case are specified up to multiplication ",
        #     "by e^{i phi}. The specified loss function depends on this phase term, making ",
        #     "it ill-defined.");
        # }
        if is_complex:
            imdiag_UhgU= UhgU.diagonal(0, -2, -1).imag
            imdiag_VhgV= VhgV.diagonal(0, -2, -1).imag
            if not torch.allclose( imdiag_UhgU, -imdiag_VhgV, 1e-2, 1e-2 ):
                warnings.warn("svd_backward: The singular vectors in the complex case are "
                    + "specified up to multiplication by e^{i phi}. The specified loss function depends on "
                    + "this phase term, making it ill-defined.", RuntimeWarning)
            # import pdb; pdb.set_trace()

        # // gA = ((U^H gU) / E) S + S (((V^H gV) / E) + I o (gS + diag(U^H gU) / (2 * S))
        # Tensor gA = [&] {
        #   // ret holds everything but the diagonal of gA
        #   auto ret = [&] {
        #     const auto E = [&S]{
        #       const auto S2 = S * S;
        #       auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
        #       // Any number a != 0 would, as we are just going to use it to compute 0 / a later on
        #       ret.diagonal(0, -2, -1).fill_(1);
        #       return ret;
        #     }();
        #     if (gU.defined()) {
        #       if (gVh.defined()) {
        #         return (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) / E;
        #       } else {
        #         return (UhgU / E) * S.unsqueeze(-2);
        #       }
        #     } else { // gVh.defined();
        #       return S.unsqueeze(-1) * (VhgV / E);
        #     }
        #   }();
        #   // Fill the diagonal
        #   if (gS.defined()) {
        #     ret = ret + gS.diag_embed();
        #   }
        #   if (is_complex && gU.defined() && gVh.defined()) {
        #     ret = ret + (UhgU.diagonal(0, -2, -1) / (2. * S)).diag_embed();
        #   }
        #   return ret;
        # }();
        def reg_preinv(x):
            # regularized element-wise pre-inverse: entries whose relative
            # magnitude falls below ad_decomp_reg are set to +inf so that the
            # subsequent 1/x maps them to 0
            x_reg= x.clone()
            x_scale= x.abs().max()
            if x_scale < ad_decomp_reg:
                x_reg= float('inf')
            else:
                x_reg[abs(x_reg/x_scale) < ad_decomp_reg] = float('inf')
            return x_reg

        S2= S*S
        E= S2.unsqueeze(-2) - S2.unsqueeze(-1)  # E_ij = S^2_i - S^2_j
        E.diagonal(0,-2,-1).fill_(1)
        gA= (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) * (1./reg_preinv(E)) + gS.diag_embed()
        if is_complex:
            gA = gA + (UhgU.diagonal(0, -2, -1) * (1./(2. * reg_preinv(S)))).diag_embed()

        # if (m > n && gU.defined()) {
        #   // gA = [UgA + (I_m - UU^H)gU S^{-1}]V^H
        #   gA = at::matmul(U, gA);
        #   const auto gUSinv = gU / S.unsqueeze(-2);
        #   gA = gA + gUSinv - at::matmul(U, at::matmul(U.mH(), gUSinv));
        #   gA = at::matmul(gA, Vh);
        # } else if (m < n && gVh.defined()) {
        #   // gA = U[gA V^H + S^{-1} (gV)^H (I_n - VV^H)]
        #   gA = at::matmul(gA, Vh);
        #   const auto SinvgVh = gVh / S.unsqueeze(-1);
        #   gA = gA + SinvgVh - at::matmul(at::matmul(SinvgVh, Vh.mH()), Vh);
        #   gA = at::matmul(U, gA);
        # } else {
        #   // gA = U gA V^H
        #   gA = m >= n ? at::matmul(U, at::matmul(gA, Vh))
        #               : at::matmul(at::matmul(U, gA), Vh);
        # }
        # TODO regularize 1/S
        if m > n:
            gA = U @ gA
            gUSinv = gU / S.unsqueeze(-2)
            gA = gA + gUSinv - U @ (U.transpose(-2, -1).conj() @ gUSinv)
            gA = gA @ Vh
        elif m < n:
            gA = gA @ Vh
            SinvgVh = gVh / S.unsqueeze(-1)
            gA = gA + SinvgVh - (SinvgVh @ Vh.transpose(-2, -1).conj()) @ Vh
            gA = U @ gA
        else:
            gA = U @ (gA @ Vh) if m >= n else (U @ gA) @ Vh
        # return gA;
        # }

        if diagnostics is not None:
            print(f"{diagnostics} {gA.size()} {gA.abs().max()} {S.max()}")
        return gA, None, None, None
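
# Usage sketch (illustrative, not part of the original module): for complex
# input the loss has to be gauge-invariant, i.e. unchanged under
# (U, V) -> (U L, V L) with L = diag(e^{i phi}), otherwise the gradient is
# ill-defined (see the invariance check above). Sums of singular values and
# norms of U V^H are examples of such losses. The helper name
# _demo_gauge_invariant_loss is hypothetical.
def _demo_gauge_invariant_loss():
    eps= torch.as_tensor(1.0e-12, dtype=torch.float64)
    A= torch.rand(8, 8, dtype=torch.complex128, requires_grad=True)
    U, S, V= SVDGESDD.apply(A, eps, None)
    chi= 4  # retain the chi leading singular triples
    loss= S[:chi].sum() + (U[:, :chi] @ V[:, :chi].conj().transpose(-2, -1)).norm()
    loss.backward()
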
def test_SVDGESDD_legacy_random():
    eps= 1.0e-12
    eps= torch.as_tensor(eps, dtype=torch.float64)

    # M, N = 50, 40
    # A = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    # assert(torch.autograd.gradcheck(SVDGESDD_legacy.apply, (A, eps, None,), eps=1e-6, atol=1e-5))

    # M, N = 40, 40
    # A = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    # assert(torch.autograd.gradcheck(SVDGESDD_legacy.apply, (A, eps, None,), eps=1e-6, atol=1e-5))

    M= 50
    A= torch.rand(M, M, dtype=torch.float64)
    A= 0.5*(A+A.t())
    D, U = torch.symeig(A, eigenvectors=True)

    # make random spectrum with almost degenerate pairs
    for split_scale in [0.]:  # 10.0, 1.0, 0.1, 0.01,
        tot_scale= 1000
        d0= torch.rand(M//2, dtype=torch.float64)
        splits= torch.rand(M//2, dtype=torch.float64)
        for i in range(M//2):
            D[2*i]= tot_scale*d0[i]
            D[2*i+1]= tot_scale*d0[i]+split_scale*splits[i]
        A= U.t() @ torch.diag(D) @ U
        print(f"split_scale {split_scale} {D}")
        try:
            A.requires_grad_()
            assert(torch.autograd.gradcheck(SVDGESDD_legacy.apply, (A, eps, None,), eps=1e-6, atol=1.0, rtol=1.0e-3))
        except Exception as e:
            print(f"FAILED for splits: {split_scale}")
            print(e)

def test_SVDGESDD_random():
    eps= 1.0e-12
    eps= torch.as_tensor(eps, dtype=torch.float64)

    M, N = 50, 40
    A = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(SVDGESDD.apply, (A, eps, None,), eps=1e-6, atol=1e-5))

    M, N = 40, 40
    A = torch.rand(M, N, dtype=torch.float64, requires_grad=True)
    assert(torch.autograd.gradcheck(SVDGESDD.apply, (A, eps, None,), eps=1e-6, atol=1e-5))

    M= 50
    A= torch.rand(M, M, dtype=torch.float64)
    A= 0.5*(A+A.t())
    D, U = torch.symeig(A, eigenvectors=True)

    # make random spectrum with almost degenerate pairs
    for split_scale in [10.0, 1.0, 0.1, 0.01, 0.]:
        tot_scale= 1000
        d0= torch.rand(M//2, dtype=torch.float64)
        splits= torch.rand(M//2, dtype=torch.float64)
        for i in range(M//2):
            D[2*i]= tot_scale*d0[i]
            D[2*i+1]= tot_scale*d0[i]+split_scale*splits[i]
        A= U.t() @ torch.diag(D) @ U
        print(f"split_scale {split_scale} {D}")
        try:
            A.requires_grad_()
            assert(torch.autograd.gradcheck(SVDGESDD.apply, (A, eps, None,), eps=1e-6, atol=1.0, rtol=1.0e-3))
        except Exception as e:
            print(f"FAILED for splits: {split_scale}")
            print(e)

def test_SVDGESDD_COMPLEX_random():
    eps= torch.as_tensor(1.0e-12, dtype=torch.float64)

    def test_f_1(M):
        U,S,V= SVDGESDD.apply(M, eps, None)
        return torch.sum(S[0:1])

    def test_f_2(M):
        U,S,V= SVDGESDD.apply(M, eps, None)
        T= U @ V.conj().transpose(-2,-1)
        return T.norm()

    M= 25
    A= torch.rand((M, M), dtype=torch.complex128)
    U,S,V= torch.svd(A)
    print(S)

    for split_scale in [10.0, 1.0, 0.1, 0.01, 0.]:
        tot_scale= 1000
        d0= torch.rand(M//2, dtype=torch.float64)
        splits= torch.rand(M//2, dtype=torch.float64)
        for i in range(M//2):
            S[2*i]= tot_scale*d0[i]
            S[2*i+1]= tot_scale*d0[i]+split_scale*splits[i]
        A= U @ torch.diag(S).to(U.dtype) @ V.conj().transpose(-2,-1)
        A.requires_grad_()
        print(f"split_scale {split_scale}")
        print(S)
        assert(torch.autograd.gradcheck(test_f_1, A, eps=1e-6, atol=1e-4))
        assert(torch.autograd.gradcheck(test_f_2, A, eps=1e-6, atol=1e-4))

if __name__=='__main__':
    test_SVDGESDD_legacy_random()
    test_SVDGESDD_random()
    # test_SVDGESDD_COMPLEX_random()