from math import sqrt
import torch
from functools import reduce
from torch.optim.lbfgs import LBFGS, _strong_wolfe
import logging
log = logging.getLogger(__name__)
# from https://github.com/scipy/scipy/blob/master/scipy/optimize/linesearch.py
def _scalar_search_armijo(phi, phi0, derphi0, args=(), c1=1e-4, alpha0=1, amin=1.0e-8):
"""Minimize over alpha, the function ``phi(alpha)``.
Uses the interpolation algorithm (Armijo backtracking) as suggested by
Wright and Nocedal in 'Numerical Optimization', 1999, pp. 56-57
alpha > 0 is assumed to be a descent direction.
Returns
-------
alpha
phi1
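
    Example
    -------
    An illustrative call on a simple 1-d quadratic (hypothetical ``phi``, not
    part of this module); the first quadratic interpolation already lands, up
    to rounding, at the minimizer 0.3 and satisfies the Armijo condition::

        phi = lambda a: (a - 0.3) ** 2
        alpha, phi1 = _scalar_search_armijo(phi, phi0=phi(0.), derphi0=-0.6)
        # alpha ~ 0.3, phi1 ~ 0.0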
"""
log.info(f"LS expected phi: {phi0+c1*alpha0*derphi0} (derphi0: {derphi0})")
phi_a0 = phi(alpha0, *args)
if phi_a0 <= phi0 + c1*alpha0*derphi0:
return alpha0, phi_a0
# Otherwise, compute the minimizer of a quadratic interpolant:
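    # The quadratic q(a) = phi0 + derphi0*a + c*a**2, with
    # c = (phi_a0 - phi0 - derphi0*alpha0) / alpha0**2, interpolates phi0,
    # derphi0 and phi(alpha0); its minimizer -derphi0 / (2*c) is the step below.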
alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
phi_a1 = phi(alpha1, *args)
if (phi_a1 <= phi0 + c1*alpha1*derphi0):
return alpha1, phi_a1
    # Otherwise, loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition).
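    # A cubic p(x) = A*x**3 + B*x**2 + derphi0*x + phi0 is fitted to the two
    # most recent points (alpha0, phi_a0) and (alpha1, phi_a1); the coefficients
    # A, B are the a, b computed below, and the cubic's minimizer is
    # alpha2 = (-b + sqrt(b**2 - 3*a*derphi0)) / (3*a), with abs() below
    # guarding against a slightly negative discriminant.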
    while alpha1 > amin:       # derphi0 < 0 (a descent direction) is assumed
factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
a = a / factor
b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
b = b / factor
alpha2 = (-b + sqrt(abs(b**2 - 3 * a * derphi0))) / (3.0*a)
phi_a2 = phi(alpha2, *args)
if (phi_a2 <= phi0 + c1*alpha2*derphi0):
return alpha2, phi_a2
if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
alpha2 = alpha1 / 2.0
alpha0 = alpha1
alpha1 = alpha2
phi_a0 = phi_a1
phi_a1 = phi_a2
# Failed to find a suitable step length
return None, phi_a1
class LBFGS_MOD(LBFGS):
r"""
Extends the original steepest gradient descent of PyTorch
[`torch.optim.LBFGS <https://pytorch.org/docs/stable/optim.html#torch.optim.LBFGS>`_]
with optional backtracking linesearch. The linesearch implementation
is adapted from scipy
[`scipy.optimize.linesearch <https://github.com/scipy/scipy/blob/master/scipy/optimize/linesearch.py>`_]
and relies only on the value of the loss function, not derivatives.
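
    A minimal usage sketch (hypothetical ``model``, ``loss_fn`` and ``data``;
    the two-closure interface of :meth:`step_2c` is shown with the
    backtracking line search)::

        optimizer = LBFGS_MOD(model.parameters(), lr=1.0, max_iter=20,
                              line_search_fn="backtracking")

        def closure():
            # full forward/backward pass, called with gradients enabled
            optimizer.zero_grad()
            loss = loss_fn(model(data))
            loss.backward()
            return loss

        def closure_linesearch(linesearching=False):
            # loss-only evaluation used during the backtracking line search
            with torch.no_grad():
                return loss_fn(model(data))

        optimizer.step_2c(closure, closure_linesearch)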
"""
def __init__(self,
params,
lr=1,
max_iter=20,
max_eval=None,
tolerance_grad=1e-7,
tolerance_change=1e-9,
history_size=100,
line_search_fn=None,
line_search_eps=1.0e-4):
r"""
Args:
lr : float
learning rate
max_iter : int
maximal number of iterations per optimization step
max_eval : int
maximal number of function evaluations per optimization step
(default: max_iter * 1.25).
tolerance_grad : float
termination tolerance on first order optimality
tolerance_change : float
termination tolerance on function value/parameter changes.
history_size : int
update history size.
            line_search_fn : str
                one of ``'backtracking'``, ``'strong_wolfe'``, ``'default'``, or ``None``.
line_search_eps : float
minimal step size
"""
super(LBFGS_MOD, self).__init__(
params,
lr=lr,
max_iter=max_iter,
max_eval=max_eval,
tolerance_grad=tolerance_grad,
tolerance_change=tolerance_change,
history_size=history_size,
line_search_fn=line_search_fn)
# TODO for each param group ?
assert len(self.param_groups) == 1
group = self.param_groups[0]
group["line_search_eps"]= line_search_eps
    def _directional_evaluate_derivative_free(self, closure, t, x, d):
        # evaluate the loss at x + t*d without computing gradients, then
        # restore the original parameters x
        self._add_grad(t, d)
        with torch.no_grad():
            orig_loss = closure(True)
        loss = float(orig_loss)
        self._set_param(x)
        return loss
    def _directional_evaluate(self, closure, x, t, d):
        # evaluate the loss and the flat gradient at x + t*d, then restore
        # the original parameters x
        self._add_grad(t, d)
        with torch.enable_grad():
            loss = float(closure(linesearching=True))
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
    @torch.no_grad()
def step_2c(self, closure, closure_linesearch):
"""Performs a single optimization step.
Args:
closure (callable): A closure that reevaluates the model
and returns the loss.
closure_linesearch (callable): A closure that reevaluates the model and returns
the loss in torch.no_grad context
"""
assert len(self.param_groups) == 1
# Make sure the closure is always called with grad enabled
closure = torch.enable_grad()(closure)
group = self.param_groups[0]
lr = group['lr']
max_iter = group['max_iter']
max_eval = group['max_eval']
tolerance_grad = group['tolerance_grad']
tolerance_change = group['tolerance_change']
line_search_fn = group['line_search_fn']
        line_search_eps = group['line_search_eps']
history_size = group['history_size']
# NOTE: LBFGS has only global state, but we register it as state for
# the first param, because this helps with casting in load_state_dict
state = self.state[self._params[0]]
state.setdefault('func_evals', 0)
state.setdefault('n_iter', 0)
# evaluate initial f(x) and df/dx
with torch.enable_grad():
orig_loss = closure()
loss = float(orig_loss)
current_evals = 1
state['func_evals'] += 1
flat_grad = self._gather_flat_grad()
        is_complex = flat_grad.is_complex()
opt_cond = flat_grad.abs().max() <= tolerance_grad
# optimal condition
if opt_cond:
return orig_loss
# tensors cached in state (for tracing)
d = state.get('d')
t = state.get('t')
old_dirs = state.get('old_dirs')
old_stps = state.get('old_stps')
ro = state.get('ro')
H_diag = state.get('H_diag')
prev_flat_grad = state.get('prev_flat_grad')
prev_loss = state.get('prev_loss')
n_iter = 0
# optimize for a max of max_iter iterations
while n_iter < max_iter:
# keep track of nb of iterations
n_iter += 1
state['n_iter'] += 1
############################################################
# compute gradient descent direction
############################################################
if state['n_iter'] == 1:
d = flat_grad.neg()
old_dirs = []
old_stps = []
ro = []
H_diag = 1
else:
# do lbfgs update (update memory)
y = flat_grad.sub(prev_flat_grad)
s = d.mul(t)
ys = torch.real(y.conj().dot(s)) if is_complex else y.dot(s)
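                # accept the curvature pair (s, y) only if y^T s is sufficiently
                # positive, which keeps the inverse-Hessian approximation
                # positive definite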
if ys > 1e-10:
# updating memory
if len(old_dirs) == history_size:
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)
ro.pop(0)
# store new direction/step
old_dirs.append(y)
old_stps.append(s)
ro.append(1. / ys)
# update scale of initial Hessian approximation
H_diag = ys / y.conj().dot(y) if is_complex else ys / y.dot(y)
# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
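            # Standard two-loop recursion: the backward loop applies the stored
            # curvature pairs to q = -grad, the result is scaled by H_diag, and
            # the forward loop yields d ~ -H*grad, with H the inverse-Hessian
            # approximation. For complex parameters the inner products are taken
            # as Re(<x, y>) so that the coefficients stay real.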
num_old = len(old_dirs)
if 'al' not in state:
state['al'] = [None] * history_size
al = state['al']
# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
if is_complex:
for i in range(num_old - 1, -1, -1):
al[i] = torch.real(old_stps[i].conj().dot(q)) * ro[i]
q.add_(old_dirs[i], alpha=-al[i])
else:
for i in range(num_old - 1, -1, -1):
al[i] = old_stps[i].dot(q) * ro[i]
q.add_(old_dirs[i], alpha=-al[i])
# multiply by initial Hessian
# r/d is the final direction
d = r = torch.mul(q, H_diag)
if is_complex:
for i in range(num_old):
be_i = torch.real(old_dirs[i].conj().dot(r)) * ro[i]
r.add_(old_stps[i], alpha=al[i] - be_i)
else:
for i in range(num_old):
be_i = old_dirs[i].dot(r) * ro[i]
r.add_(old_stps[i], alpha=al[i] - be_i)
if prev_flat_grad is None:
prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
else:
prev_flat_grad.copy_(flat_grad)
prev_loss = loss
############################################################
# compute step length
############################################################
# reset initial guess for step size
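            # on the first iteration the step is clamped so that
            # t * ||grad||_1 <= lr, guarding against an overly large first step;
            # afterwards the raw learning rate is used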
if state['n_iter'] == 1:
t = min(1., 1. / flat_grad.abs().sum()) * lr
else:
t = lr
# directional derivative
gtd = torch.real(flat_grad.conj().dot(d)) if is_complex else flat_grad.dot(d)
# directional derivative is below tolerance
if gtd > -tolerance_change:
break
# optional line search: user function
ls_func_evals = 0
if line_search_fn is not None and line_search_fn != "default":
# perform line search, using user function
if line_search_fn == "backtracking":
x_init = self._clone_param()
def obj_func(t, x, d):
return self._directional_evaluate_derivative_free(closure_linesearch, t, x, d)
                    # returns (alpha, phi(alpha)), or (None, phi) if no suitable step was found
                    t, loss = _scalar_search_armijo(obj_func, loss, gtd, args=(x_init, d), alpha0=t)
if t is None:
raise RuntimeError("minimize_scalar failed")
elif line_search_fn == "strong_wolfe":
x_init = self._clone_param()
def obj_func(x, t, d):
return self._directional_evaluate(closure, x, t, d)
loss, flat_grad, t, ls_func_evals = _strong_wolfe(
obj_func, x_init, t, d, loss, flat_grad, gtd)
else:
raise RuntimeError("unsupported line search")
log.info(f"LS final step: {t}")
self._add_grad(t, d)
opt_cond = flat_grad.abs().max() <= tolerance_grad
else:
# no line search, simply move with fixed-step
self._add_grad(t, d)
if n_iter != max_iter:
# re-evaluate function only if not in last iteration
# the reason we do this: in a stochastic setting,
# no use to re-evaluate that function here
with torch.enable_grad():
loss = float(closure())
flat_grad = self._gather_flat_grad()
opt_cond = flat_grad.abs().max() <= tolerance_grad
ls_func_evals = 1
# update func eval
current_evals += ls_func_evals
state['func_evals'] += ls_func_evals
############################################################
# check conditions
############################################################
if n_iter == max_iter:
break
if current_evals >= max_eval:
break
# optimal condition
if opt_cond:
break
# lack of progress
if d.mul(t).abs().max() <= tolerance_change:
break
if abs(loss - prev_loss) < tolerance_change:
break
state['d'] = d
state['t'] = t
state['old_dirs'] = old_dirs
state['old_stps'] = old_stps
state['ro'] = ro
state['H_diag'] = H_diag
state['prev_flat_grad'] = prev_flat_grad
state['prev_loss'] = prev_loss
return orig_loss