# Source code for optim.sgd_modified

import torch
from math import sqrt
from functools import reduce
from torch.optim.sgd import SGD
import logging
log = logging.getLogger(__name__)

# from https://github.com/scipy/scipy/blob/master/scipy/optimize/linesearch.py
def _scalar_search_armijo(phi, phi0, derphi0, args=(), c1=1e-4, alpha0=1, amin=1.0e-8):
    """Backtracking (Armijo) search for a step ``alpha`` of ``phi(alpha)``.

    Interpolation scheme suggested by Wright and Nocedal in
    'Numerical Optimization', 1999, pp. 56-57: the initial step ``alpha0`` is
    tried first, then the minimizer of a quadratic interpolant, then minimizers
    of successive cubic interpolants. A step is accepted as soon as
    ``phi(alpha) <= phi0 + c1*alpha*derphi0``.

    :param phi: objective, evaluated as ``phi(alpha, *args)``
    :param phi0: reference value of the objective (expected to be ``phi(0)``)
    :param derphi0: reference slope (expected to be the derivative of ``phi``
                    at 0); ``alpha > 0`` is assumed to be a descent direction
    :param args: extra positional arguments forwarded to ``phi``
    :param c1: coefficient of the sufficient-decrease condition
    :param alpha0: first step tried
    :param amin: the cubic stage stops once the current step is not above this
    :returns: ``(alpha, phi1)`` with the accepted step and the value of ``phi``
              there, or ``(None, phi1)`` when no step was accepted, where
              ``phi1`` is the most recently evaluated value of ``phi``
    """
    log.info(f"LS expected phi: {phi0+c1*alpha0*derphi0} (derphi0: {derphi0})")

    def sufficient_decrease(step, value):
        # First Wolfe (Armijo) condition.
        return value <= phi0 + c1*step*derphi0

    # Stage 1: the caller-supplied initial step.
    step_prev = alpha0
    val_prev = phi(step_prev, *args)
    if sufficient_decrease(step_prev, val_prev):
        return step_prev, val_prev

    # Stage 2: minimizer of the quadratic through phi0, derphi0 and val_prev.
    step_cur = -(derphi0) * step_prev**2 / 2.0 / (val_prev - phi0 - derphi0 * step_prev)
    val_cur = phi(step_cur, *args)
    if sufficient_decrease(step_cur, val_cur):
        return step_cur, val_cur

    # Stage 3: cubic interpolation through the two most recent trial points.
    # Only the first Wolfe condition is checked; since we are backtracking the
    # step is assumed not to be too small for the second one.
    while step_cur > amin:       # we are assuming alpha>0 is a descent direction
        factor = step_prev**2 * step_cur**2 * (step_cur-step_prev)
        resid_prev = val_prev - phi0 - derphi0*step_prev
        resid_cur = val_cur - phi0 - derphi0*step_cur
        a = (step_prev**2 * resid_cur - step_cur**2 * resid_prev) / factor
        b = (-step_prev**3 * resid_cur + step_cur**3 * resid_prev) / factor

        step_trial = (-b + sqrt(abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        val_trial = phi(step_trial, *args)
        if sufficient_decrease(step_trial, val_trial):
            return step_trial, val_trial

        # Safeguard: fall back to halving when the trial step is judged too far
        # from / too close to the current one. The objective is not
        # re-evaluated at the halved step.
        too_far = (step_cur - step_trial) > step_cur / 2.0
        too_close = (1 - step_trial/step_cur) < 0.96
        if too_far or too_close:
            step_trial = step_cur / 2.0

        step_prev, step_cur = step_cur, step_trial
        val_prev, val_cur = val_cur, val_trial

    # Failed to find a suitable step length
    return None, val_cur

class SGD_MOD(SGD):
    r"""
    Extends the original steepest gradient descent of PyTorch
    [`torch.optim.SGD <https://pytorch.org/docs/stable/optim.html#torch.optim.SGD>`_]
    with optional backtracking linesearch. The linesearch implementation
    is adapted from scipy
    [`scipy.optimize.linesearch <https://github.com/scipy/scipy/blob/master/scipy/optimize/linesearch.py>`_]
    and relies only on the value of the loss function, not derivatives.

    .. warning::
        This optimizer doesn't support per-parameter options and
        parameter groups (there can be only one).

    .. note::
        The momentum buffer is kept on the optimizer instance
        (``self._momentum_buffer``) as a single flat tensor covering all
        parameters.
    """

    def __init__(self, params, lr=1, momentum=0, dampening=0, weight_decay=0,
                 nesterov=False, line_search_fn=None, line_search_eps=1.0e-4):
        """
        :param params: iterable of parameters to optimize or dicts defining
                       parameter groups
        :type params: iterable
        :param lr: learning rate (default: 1)
        :type lr: float
        :param momentum: momentum factor (default: 0)
        :type momentum: float, optional
        :param dampening: dampening for momentum (default: 0)
        :type dampening: float, optional
        :param weight_decay: weight decay (L2 penalty) (default: 0)
        :type weight_decay: float, optional
        :param nesterov: enables Nesterov momentum (default: False)
        :type nesterov: bool, optional
        :param line_search_fn: line search algorithm. Supported options
                               ``"backtracking"`` (default: None)
        :type line_search_fn: str, optional
        :param line_search_eps: minimal step size (default: 1.0e-4)
        :type line_search_eps: float, optional
        :raises ValueError: if more than one parameter group is given
        """
        super(SGD_MOD, self).__init__(
            params, lr=lr, momentum=momentum, dampening=dampening,
            weight_decay=weight_decay, nesterov=nesterov)

        if len(self.param_groups) != 1:
            raise ValueError("SGD_MOD doesn't support per-parameter options "
                             "(parameter groups)")

        group = self.param_groups[0]
        group["line_search_fn"] = line_search_fn
        group["line_search_eps"] = line_search_eps

        self._params = group['params']
        self._numel_cache = None
        self._momentum_buffer = None

    def __setstate__(self, state):
        super(SGD_MOD, self).__setstate__(state)

    def _directional_evaluate_derivative_free(self, closure, t, x, d):
        """Return the loss at ``params + t*d`` and restore the parameters to ``x``.

        :param closure: callable invoked as ``closure(True)`` inside
                        ``torch.no_grad()``; must return the loss
        :param t: step size
        :param x: parameter values to restore afterwards, one tensor per
                  parameter (as produced by ``_clone_param``)
        :param d: flat update direction
        :returns: the loss as a Python float
        """
        self._add_grad(t, d)
        try:
            with torch.no_grad():
                loss = float(closure(True))
        finally:
            # Restore even when the closure raises, so a failed trial
            # evaluation never leaves the model at a displaced point.
            self._set_param(x)
        return loss

    def _numel(self):
        """Total number of elements over all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all gradients into one flat tensor.

        Parameters without a gradient contribute zeros; sparse gradients
        are densified.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_params(self):
        """Concatenate all parameters into one flat tensor (sparse ones densified)."""
        views = []
        for p in self._params:
            if p.is_sparse:
                view = p.to_dense().view(-1)
            else:
                view = p.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """In-place ``p += step_size * update_slice`` for every parameter.

        :param step_size: scalar multiplier
        :param update: flat tensor holding ``self._numel()`` elements
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous clones of all parameters."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Copy ``params_data`` (one tensor per parameter) into the parameters."""
        for p, pdata in zip(self._params, params_data):
            p.data.copy_(pdata)

    @torch.no_grad()
    def step_2c(self, closure, closure_linesearch=None):
        r"""
        Performs a single optimization step.

        :param closure: A closure that reevaluates the model and returns the loss.
        :type closure: callable
        :param closure_linesearch: A closure that reevaluates the model and
                                   returns the loss in torch.no_grad context.
                                   It is invoked as ``closure_linesearch(True)``
                                   and is required when line search is enabled.
        :type closure_linesearch: callable, optional
        :returns: the loss returned by ``closure``
        :raises RuntimeError: if ``line_search_fn`` is not supported, or if the
                              line search fails to find an acceptable step
        :raises ValueError: if line search is enabled but
                            ``closure_linesearch`` is not given
        """
        group = self.param_groups[0]
        lr = group['lr']
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        line_search_fn = group['line_search_fn']
        line_search_eps = group['line_search_eps']

        # Validate the line-search configuration before touching any state
        # (gradients, momentum buffer, parameters).
        use_line_search = line_search_fn is not None and line_search_fn != "default"
        if use_line_search:
            if line_search_fn != "backtracking":
                raise RuntimeError("unsupported line search")
            if closure_linesearch is None:
                raise ValueError("closure_linesearch is required when "
                                 "line_search_fn is set")

        with torch.enable_grad():
            loss = closure()

        flat_grad = self._gather_flat_grad()
        d_p = flat_grad.clone()

        if weight_decay != 0:
            d_p.add_(self._gather_flat_params(), alpha=weight_decay)

        if momentum != 0:
            if self._momentum_buffer is None:
                buf = self._momentum_buffer = torch.clone(d_p).detach()
            else:
                buf = self._momentum_buffer
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                # NOTE: d_p now aliases the momentum buffer; it must not be
                # modified in place below.
                d_p = buf

        if not use_line_search:
            # no line search, simply move with fixed-step
            self._add_grad(-lr, d_p)
            return loss

        # Backtracking line search along the negated update. The negation is
        # out-of-place so the stored momentum buffer keeps its sign.
        direction = d_p.neg()
        x_init = self._clone_param()
        # Plain float keeps the scalar search in Python arithmetic and makes
        # the returned step a float on every exit path.
        gtd = float(flat_grad.dot(direction))

        def obj_func(t, x, d):
            return self._directional_evaluate_derivative_free(closure_linesearch, t, x, d)

        t, f_loss = _scalar_search_armijo(obj_func, float(loss), gtd,
                                          args=(x_init, direction),
                                          alpha0=lr, amin=line_search_eps)
        if t is None:
            raise RuntimeError("minimize_scalar failed")
        log.info(f"LS final step: {t}")

        self._add_grad(t, direction)
        return loss