add_docs(_BaseOptimizer,
         all_params="List of param_groups, parameters, and hypers",
         freeze_to="Freeze parameter groups up to `n`",
         freeze="Freeze up to last parameter group",
         unfreeze="Unfreeze the entire model",
         set_hypers="`set_hyper` for all `kwargs`",
         set_hyper="Set the value(s) in `v` for hyper-parameter `k`")
Optimizers
Optimizer
Base optimizer class for the fastai library, updating `params` with `cbs`
add_docs(Optimizer,
         zero_grad="Standard PyTorch API: Zero all the grad attributes of the parameters",
         step="Standard PyTorch API: Update the stats and execute the steppers on all parameters that have a grad",
         state_dict="Return the state of the optimizer in a dictionary",
         load_state_dict="Load the content of `sd`",
         clear_state="Reset the state of the optimizer")
Initializing an Optimizer
`params` will be used to create the `param_groups` of the optimizer. If it's a collection (or a generator) of parameters, it will be a `L` containing one `L` with all the parameters. To define multiple parameter groups, `params` should be passed as a collection (or a generator) of `L`s.

In PyTorch, `model.parameters()` returns a generator with all the parameters, which you can directly pass to `Optimizer`.
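For instance, here is a minimal sketch (using a toy `nn.Module`; not part of the original tests):

import torch.nn as nn

# model.parameters() is a generator, so Optimizer wraps it in a single group
model = nn.Linear(3, 2)
opt = Optimizer(model.parameters(), noop)
test_eq(len(opt.param_lists), 1)  # one L containing the weight and the bias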
opt = Optimizer([1,2,3], noop)
test_eq(opt.param_lists, [[1,2,3]])
opt = Optimizer(range(3), noop)
test_eq(opt.param_lists, [[0,1,2]])
opt = Optimizer([[1,2],[3]], noop)
test_eq(opt.param_lists, [[1,2],[3]])
opt = Optimizer(([o,o+1] for o in range(0,4,2)), noop)
test_eq(opt.param_lists, [[0,1],[2,3]])
`cbs` is a list of functions that will be composed when applying the step. For instance, you can compose a function making the SGD step with another one applying weight decay. Additionally, each `cb` can have a `defaults` attribute that contains hyper-parameters and their default value. Those are all gathered at initialization, and new values can be passed to override those defaults with the `defaults` kwargs. The steppers will be called by `Optimizer.step` (which is the standard PyTorch name), and gradients can be cleared with `Optimizer.zero_grad` (also a standard PyTorch name).

Once the defaults have all been pulled off, they are copied as many times as there are `param_groups` and stored in `hypers`. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers, for instance), you will need to adjust those values after the init, as sketched below.
def tst_arg(p, lr=0, **kwargs): return p
tst_arg.defaults = dict(lr=1e-2)

def tst_arg2(p, lr2=0, **kwargs): return p
tst_arg2.defaults = dict(lr2=1e-3)

def tst_arg3(p, mom=0, **kwargs): return p
tst_arg3.defaults = dict(mom=0.9)

def tst_arg4(p, **kwargs): return p

opt = Optimizer([1,2,3], [tst_arg,tst_arg2,tst_arg3])
test_eq(opt.hypers, [{'lr2': 1e-3, 'mom': 0.9, 'lr': 1e-2}])
opt = Optimizer([1,2,3], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}])
opt = Optimizer([[1,2],[3]], tst_arg)
test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])
opt = Optimizer([[1,2],[3]], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])
For each hyper-parameter, you can pass a slice or a collection to set them, if there are multiple parameter groups. A slice will be converted to a log-uniform collection from its beginning to its end, or, if it only has an end `e`, to a collection of as many values as there are parameter groups that are `...,e/10,e/10,e`.

Setting a hyper-parameter with a collection that has a different number of elements than the optimizer has parameter groups will raise an error.
opt = Optimizer([[1,2],[3]], tst_arg, lr=[0.1,0.2])
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.2}])
opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-2))
test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-3}, {'lr': 1e-2}])
opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-4,1e-2))
test_eq(opt.hypers, [{'lr': 1e-4}, {'lr': 1e-3}, {'lr': 1e-2}])
test_eq(opt.param_groups, [{'params': [1,2], 'lr': 1e-4}, {'params': [3], 'lr': 1e-3}, {'params': [4], 'lr': 1e-2}])
test_fail(lambda: Optimizer([[1,2],[3],[4]], tst_arg, lr=np.array([0.1,0.2])))
Basic steppers
To be able to give examples of optimizer steps, we will need some steppers, like the following:
sgd_step
sgd_step (p, lr, **kwargs)
def tst_param(val, grad=None):
    "Create a tensor with `val` and a gradient of `grad` for testing"
    res = tensor([val]).float()
    res.grad = tensor([val/10 if grad is None else grad]).float()
    return res

p = tst_param(1., 0.1)
sgd_step(p, 1.)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))
weight_decay
weight_decay (p, lr, wd, do_wd=True, **kwargs)
Weight decay as decaying `p` with `lr*wd`

p = tst_param(1., 0.1)
weight_decay(p, 1., 0.1)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))
l2_reg
l2_reg (p, lr, wd, do_wd=True, **kwargs)
L2 regularization as adding `wd*p` to `p.grad`

p = tst_param(1., 0.1)
l2_reg(p, 1., 0.1)
test_eq(p, tensor([1.]))
test_eq(p.grad, tensor([0.2]))
Weight decay and L2 regularization are the same thing for basic SGD, but for more complex optimizers they are very different, as sketched below.
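A minimal sketch (not from the original docs) of why they coincide for plain SGD: decaying `p` by `lr*wd` before the step is the same as folding `wd*p` into the gradient, since `p - lr*(grad + wd*p) = p*(1 - lr*wd) - lr*grad`.

p1,p2 = tst_param(1., 0.1),tst_param(1., 0.1)
weight_decay(p1, 1., 0.1); sgd_step(p1, 1.)  # true weight decay, then SGD step
l2_reg(p2, 1., 0.1);       sgd_step(p2, 1.)  # L2 penalty folded into the grad
test_close(p1, p2)  # identical for vanilla SGD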
Making the step
Optimizer.step
Optimizer.step (closure=None)
This method will loop over all param groups, then all parameters for which `grad` is not None, and call each function in `stepper`, passing it the parameter `p` with the hyper-parameters in the corresponding dict in `hypers`.
#test basic step
r = L.range(4)
def tst_params(): return r.map(tst_param)

params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
opt.step()
test_close([p.item() for p in params], r.map(mul(0.99)))

#test two steps
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.step()
test_close([p.item() for p in params], r.map(mul(0.98)))

#test None gradients are ignored
params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
params[-1].grad = None
opt.step()
test_close([p.item() for p in params], [0., 0.99, 1.98, 3.])

#test discriminative lrs
params = tst_params()
opt = Optimizer([params[:2], params[2:]], sgd_step, lr=0.1)
opt.hypers[0]['lr'] = 0.01
opt.step()
test_close([p.item() for p in params], [0., 0.999, 1.98, 2.97])
Optimizer.zero_grad
Optimizer.zero_grad ()
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.zero_grad()
[test_eq(p.grad, tensor([0.])) for p in params];
Some of the `Optimizer` `cbs` can be functions updating the state associated with a parameter. That state can then be used by any stepper. The best example is a momentum calculation.
def tst_stat(p, **kwargs):
    s = kwargs.get('sum', torch.zeros_like(p)) + p.data
    return {'sum': s}
tst_stat.defaults = {'mom': 0.9}

#Test Optimizer init
opt = Optimizer([1,2,3], tst_stat)
test_eq(opt.hypers, [{'mom': 0.9}])
opt = Optimizer([1,2,3], tst_stat, mom=0.99)
test_eq(opt.hypers, [{'mom': 0.99}])

#Test stat
x = torch.randn(4,5)
state = tst_stat(x)
assert 'sum' in state
test_eq(x, state['sum'])
state = tst_stat(x, **state)
test_eq(state['sum'], 2*x)
Statistics
average_grad
average_grad (p, mom, dampening=False, grad_avg=None, **kwargs)
Keeps track of the avg grads of `p` in `state` with `mom`.

`dampening=False` gives the classical formula for momentum in SGD:

new_val = old_val * mom + grad

whereas `dampening=True` makes it an exponential moving average:

new_val = old_val * mom + grad * (1-mom)
p = tst_param([1,2,3], [4,5,6])
state = {}
state = average_grad(p, mom=0.9, **state)
test_eq(state['grad_avg'], p.grad)
state = average_grad(p, mom=0.9, **state)
test_eq(state['grad_avg'], p.grad * 1.9)

#Test dampening
state = {}
state = average_grad(p, mom=0.9, dampening=True, **state)
test_eq(state['grad_avg'], 0.1*p.grad)
state = average_grad(p, mom=0.9, dampening=True, **state)
test_close(state['grad_avg'], (0.1*0.9+0.1)*p.grad)
average_sqr_grad
average_sqr_grad (p, sqr_mom, dampening=True, sqr_avg=None, **kwargs)
`dampening=False` gives the classical momentum formula applied to the squared gradients:

new_val = old_val * mom + grad**2

whereas `dampening=True` makes it an exponential moving average:

new_val = old_val * mom + (grad**2) * (1-mom)
p = tst_param([1,2,3], [4,5,6])
state = {}
state = average_sqr_grad(p, sqr_mom=0.99, dampening=False, **state)
test_eq(state['sqr_avg'], p.grad.pow(2))
state = average_sqr_grad(p, sqr_mom=0.99, dampening=False, **state)
test_eq(state['sqr_avg'], p.grad.pow(2) * 1.99)

#Test dampening
state = {}
state = average_sqr_grad(p, sqr_mom=0.99, **state)
test_close(state['sqr_avg'], 0.01*p.grad.pow(2))
state = average_sqr_grad(p, sqr_mom=0.99, **state)
test_close(state['sqr_avg'], (0.01*0.99+0.01)*p.grad.pow(2))
Freezing part of the model
Optimizer.freeze
Optimizer.freeze ()
Optimizer.freeze_to
Optimizer.freeze_to (n:int)
| | Type | Details |
|---|---|---|
| n | int | Freeze up to `n` layers |
Optimizer.unfreeze
Optimizer.unfreeze ()
#Freezing the first layer
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
opt.freeze_to(1)
req_grad = Self.requires_grad()
test_eq(L(params[0]).map(req_grad), [False]*4)
for i in {1,2}: test_eq(L(params[i]).map(req_grad), [True]*4)

#Unfreezing
opt.unfreeze()
for i in range(2): test_eq(L(params[i]).map(req_grad), [True]*4)
#TODO: test warning
# opt.freeze_to(3)
Parameters such as batchnorm weights/bias can be marked to always be in training mode: just put `force_train=True` in their state.
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
for p in L(params[1])[[1,3]]: opt.state[p] = {'force_train': True}
opt.freeze()
test_eq(L(params[0]).map(req_grad), [False]*4)
test_eq(L(params[1]).map(req_grad), [False, True, False, True])
test_eq(L(params[2]).map(req_grad), [True]*4)
Serializing
Optimizer.state_dict
Optimizer.state_dict ()
Optimizer.load_state_dict
Optimizer.load_state_dict (sd:dict)
| | Type | Details |
|---|---|---|
| sd | dict | State dict with hypers and state to load on the optimizer |
p = tst_param([1,2,3], [4,5,6])
opt = Optimizer(p, average_grad)
opt.step()
test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))

sd = opt.state_dict()
p1 = tst_param([10,20,30], [40,50,60])
opt = Optimizer(p1, average_grad, mom=0.99)
test_eq(opt.hypers[0]['mom'], 0.99)
test_eq(opt.state, {})

opt.load_state_dict(sd)
test_eq(opt.hypers[0]['mom'], 0.9)
test_eq(opt.state[p1]['grad_avg'], tensor([[4., 5., 6.]]))
Optimizer.clear_state
Optimizer.clear_state ()
p = tst_param([1,2,3], [4,5,6])
opt = Optimizer(p, average_grad)
opt.state[p] = {'force_train': True}
opt.step()
test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))

opt.clear_state()
test_eq(opt.state[p], {'force_train': True})
Optimizers
SGD with momentum
momentum_step
momentum_step (p, lr, grad_avg, **kwargs)
Step for SGD with momentum with `lr`
SGD
SGD (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.0, wd:numbers.Real=0.0, decouple_wd:bool=True)
A SGD Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization (SGD) |
| **Returns** | **Optimizer** | | |
Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
#Vanilla SGD
params = tst_params()
opt = SGD(params, lr=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

#SGD with momentum
params = tst_params()
opt = SGD(params, lr=0.1, mom=0.9)
assert isinstance(opt, Optimizer)
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
opt.step()
test_close([p.item() for p in params], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])
for i,p in enumerate(params): test_close(opt.state[p]['grad_avg'].item(), i*0.19)
Test weight decay, and notice how L2 regularization differs from weight decay even for simple SGD with momentum.
params = tst_params()
#Weight decay
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

#L2 reg
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False)
opt.step()
#TODO: fix cause this formula was wrong
#test_close([p.item() for p in params], [i*0.97 for i in range(4)])
RMSProp
rms_prop_step
rms_prop_step (p, lr, sqr_avg, eps, grad_avg=None, **kwargs)
Step for RMSProp with momentum with `lr`
RMSProp
RMSProp (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:numbers.Real=0.0, decouple_wd:bool=True)
A RMSProp Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization (RMSProp) |
| **Returns** | **Optimizer** | | |
RMSProp was introduced by Geoffrey Hinton in his course. What is named `sqr_mom` here is the `alpha` in the course. Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
#Without momentum
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1)
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
opt.step()
step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))

#With momentum
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1, mom=0.9)
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
opt.step()
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))
Adam
step_stat
step_stat (p, step=0, **kwargs)
Register the number of steps done in `state` for `p`

p = tst_param(1,0.1)
state = {}
state = step_stat(p, **state)
test_eq(state['step'], 1)
for _ in range(5): state = step_stat(p, **state)
test_eq(state['step'], 6)
debias
debias (mom, damp, step)
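`debias` has no docstring; as a hedged note (based on the fastai source, worth double-checking), it computes `damp * (1 - mom**step) / (1 - mom)`, which with `damp = 1-mom` (the dampened averages used by Adam) reduces to the familiar bias-correction factor `1 - mom**step`:

# with damp = 1-mom, debias is the standard Adam bias correction
test_close(debias(0.9, 1-0.9, 1), 1 - 0.9**1)
test_close(debias(0.9, 1-0.9, 10), 1 - 0.9**10)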
adam_step
adam_step (p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs)
Step for Adam with `lr` on `p`
Adam
Adam (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:numbers.Real=0.01, decouple_wd:bool=True)
An Adam/AdamW Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | Real | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (AdamW) or L2 regularization (Adam) |
| **Returns** | **Optimizer** | | |
Adam was introduced by Diederik P. Kingma and Jimmy Ba in Adam: A Method for Stochastic Optimization. For consistency across optimizers, we renamed `beta1` and `beta2` in the paper to `mom` and `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).

Don't forget that `eps` is a hyper-parameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer we are to normal SGD). The usual default of 1e-8 is often too extreme in the sense that we don't manage to get as good results as with SGD.
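A hedged illustration of that intuition (hypothetical values, not from the original docs): with a very large `eps` the denominator of the Adam update is dominated by `eps`, so the first step is approximately `-lr*grad/eps`, i.e. (momentum) SGD with learning rate `lr/eps`.

params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Adam(params, lr=0.1, eps=1e3, wd=0)  # absurdly large eps, for illustration
opt.step()
# the first debiased step is ~ -lr*grad/(|grad| + eps) ~ -lr*grad/eps here
test_close(params[0], tensor([1.,2.,3.]) - 0.1*tensor([0.1,0.2,0.3])/1e3)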
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Adam(params, lr=0.1, wd=0)
opt.step()
step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params[0], tensor([1+step, 2+step, 3+step]))
opt.step()
test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
RAdam
RAdam (for rectified Adam) was introduced by Zhang et al. in On the Variance of the Adaptive Learning Rate and Beyond to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.

This version also incorporates SAdam; set `beta` to enable this (definition same as in the paper).
radam_step
radam_step (p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, beta, **kwargs)
Step for RAdam with `lr` on `p`
RAdam
RAdam (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:numbers.Real=0.0, beta:float=0.0, decouple_wd:bool=True)
A RAdam/RAdamW Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| beta | float | 0.0 | Set to enable SAdam |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| **Returns** | **Optimizer** | | |
This is the effective correction applied to the Adam step for 500 iterations in RAdam. We can see how it goes from 0 to 1, mimicking the effect of a warm-up.

beta = 0.99
r_inf = 2/(1-beta) - 1
rs = np.array([r_inf - 2*s*beta**s/(1-beta**s) for s in range(5,500)])
v = np.sqrt(((rs-4) * (rs-2) * r_inf)/((r_inf-4)*(r_inf-2)*rs))
plt.plot(v);
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RAdam(params, lr=0.1)
#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)
r_inf = 2/(1-0.99) - 1
for i in range(5):
    r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))
    assert r <= 5
    opt.step()
p = tensor([0.95, 1.9, 2.85])
test_close(params[0], p)

#The r factor is greater than 5 for the sixth step so we update with RAdam
r = r_inf - 2*6*0.99**6/(1-0.99**6)
assert r > 5
opt.step()
v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))
step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)
test_close(params[0], p+step)
QHAdam
QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in Quasi-Hyperbolic Momentum and Adam for Deep Learning as a "computationally cheap, intuitive to interpret, and simple to implement" optimizer. Additional code can be found in their qhoptim repo. QHAdam is based on QH-Momentum, which introduces the immediate discount factor `nu`, encapsulating plain SGD (`nu = 0`) and momentum (`nu = 1`). QH-Momentum is defined below, where g_t+1 is the update of the moment. An interpretation of QHM is as a nu-weighted average of the momentum update step and the plain SGD update step.

θ_t+1 ← θ_t − lr * [(1 − nu) · ∇L_t(θ_t) + nu · g_t+1]
QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms.
The paper's suggested default parameters are `mom = 0.999`, `sqr_mom = 0.999`, `nu_1 = 0.7` and `nu_2 = 1.0`. When training is not stable, it is possible that setting `nu_2 < 1` can improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when `nu_1 = nu_2 = 1.0`. QHAdam recovers RMSProp (Hinton et al., 2012) when `nu_1 = 0` and `nu_2 = 1`, and NAdam (Dozat, 2016) when `nu_1 = mom` and `nu_2 = 1`.
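A hedged sanity check of the first recovery claim (hyper-parameters matched by hand; not part of the original tests): with `nu_1 = nu_2 = 1.0`, a QHAdam step coincides with the corresponding Adam step.

p1,p2 = tst_param([1,2,3], [0.1,0.2,0.3]),tst_param([1,2,3], [0.1,0.2,0.3])
QHAdam(p1, lr=0.1, mom=0.9, sqr_mom=0.99, nu_1=1., nu_2=1., eps=1e-5).step()
Adam(p2, lr=0.1, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0).step()  # wd=0 to match
test_close(p1[0], p2[0])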
Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
qhadam_step
qhadam_step (p, lr, mom, sqr_mom, sqr_avg, nu_1, nu_2, step, grad_avg, eps, **kwargs)
QHAdam
QHAdam (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.999, sqr_mom:float=0.999, nu_1:float=0.7, nu_2:float=1.0, eps:float=1e-08, wd:numbers.Real=0.0, decouple_wd:bool=True)
A QHAdam/QHAdamW Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.999 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.999 | Gradient squared moving average (β2) coefficient |
| nu_1 | float | 0.7 | QH immediate discount factor |
| nu_2 | float | 1.0 | QH momentum discount factor |
| eps | float | 1e-08 | Added for numerical stability |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (QHAdamW) or L2 regularization (QHAdam) |
| **Returns** | **Optimizer** | | |
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = QHAdam(params, lr=0.1)
opt.step()
step = -0.1 * (((1-0.7) * 0.1) + (0.7 * 0.1)) / (
    math.sqrt(((1-1.0) * 0.1**2) + (1.0 * 0.1**2)) + 1e-8)
test_close(params[0], tensor([1+step, 2+step, 3+step]))
opt.step()
test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)
LARS/LARC
larc_layer_lr
larc_layer_lr (p, lr, trust_coeff, wd, eps, clip=True, **kwargs)
Computes the local lr before weight decay is applied
larc_step
larc_step (p, local_lr, grad_avg=None, **kwargs)
Step for LARC `local_lr` on `p`
Larc
Larc (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.9, clip:bool=True, trust_coeff:float=0.02, eps:float=1e-08, wd:numbers.Real=0.0, decouple_wd:bool=True)
A LARC/LARS Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| clip | bool | True | LARC if clip=True, LARS if clip=False |
| trust_coeff | float | 0.02 | Trust coefficient for calculating layerwise LR |
| eps | float | 1e-08 | Added for numerical stability |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization |
| **Returns** | **Optimizer** | | |
The LARS optimizer was first introduced in Large Batch Training of Convolutional Networks then refined in its LARC variant (original LARS is with `clip=False`). A learning rate is computed for each individual layer with a certain `trust_coeff`, then clipped to be always less than `lr`.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
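As a hedged by-hand check (assuming, consistently with the test values below, that with `wd=0` the local lr is `lr * trust_coeff * ‖p‖ / (‖grad‖ + eps)` before clipping):

p = tst_param([1,2,3], [0.1,0.2,0.3])
local_lr = 0.1 * 0.02 * p.norm() / (p.grad.norm() + 1e-8)
test_close(local_lr.item(), 0.02)  # the ratio of norms is 10 here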
params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1)
opt.step()
#First param local lr is 0.02 < lr so it's not clipped
test_close(opt.state[params[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt.state[params[1]]['local_lr'], 0.1)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.999,1.998,2.997]))

params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1, clip=False)
opt.step()
#No clipping
test_close(opt.state[params[0]]['local_lr'], 0.02)
test_close(opt.state[params[1]]['local_lr'], 0.2)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.998,1.996,2.994]))
LAMB
lamb_step
lamb_step (p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs)
Step for LAMB with `lr` on `p`
Lamb
Lamb (params:Union[torch.Tensor,Iterable], lr:float|slice, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:numbers.Real=0.0, decouple_wd:bool=True)
A LAMB Optimizer
| | Type | Default | Details |
|---|---|---|---|
| params | Union | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | Real | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization |
| **Returns** | **Optimizer** | | |
LAMB was introduced in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. Intuitively, it's LARC applied to Adam. As in `Adam`, we renamed `beta1` and `beta2` in the paper to `mom` and `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Lamb(params, lr=0.1)
opt.step()
test_close(params[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)
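A hedged by-hand check of the 0.7840 above (assuming, per the LAMB paper, that the Adam-style step is rescaled by the ratio of the parameter norm to the step norm): on the first step the debiased update is approximately `sign(grad) = [1,1,1]`, whose RMS norm is ~1, while the parameter RMS norm is `sqrt((1+4+9)/3)`, so the first weight moves by roughly `lr` times that ratio.

r1 = math.sqrt((1+4+9)/3)  # RMS norm of the parameter [1,2,3]
r2 = 1.                    # RMS norm of the ~sign(grad) first update
test_close(1 - 0.1*r1/r2, 0.7840, eps=1e-3)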
Lookahead

Lookahead was introduced by Zhang et al. in Lookahead Optimizer: k steps forward, 1 step back. It can be run on top of any optimizer and consists in having the final weights of the model be a moving average. In practice, we update our model using the internal optimizer but keep a copy of the old weights, and every `k` steps, we change the weights by a moving average of the fast weights (the ones updated by the inner optimizer) with the slow weights (the copy of old weights), i.e. `slow = slow + alpha*(fast - slow)`. Those slow weights act like a stability mechanism.
Lookahead
Lookahead (opt:__main__.Optimizer, k:int=6, alpha:float=0.5)
Wrap opt
in a lookahead optimizer
| | Type | Default | Details |
|---|---|---|---|
| opt | Optimizer | | Optimizer to wrap with Lookahead |
| k | int | 6 | How often to conduct Lookahead step |
| alpha | float | 0.5 | Slow weight moving average coefficient |
params = tst_param([1,2,3], [0.1,0.2,0.3])
p,g = params[0].data.clone(),tensor([0.1,0.2,0.3])
opt = Lookahead(SGD(params, lr=0.1))
for k in range(5): opt.step()
#first 5 steps are normal SGD steps
test_close(params[0], p - 0.5*g)

#Since k=6, sixth step is a moving average of the 6 SGD steps with the initial weight
opt.step()
test_close(params[0], p * 0.5 + (p-0.6*g) * 0.5)
ranger
ranger (params:Tensor|Iterable, lr:float|slice, mom:float=0.95, wd:Real=0.01, eps:float=1e-06, k:int=6, alpha:float=0.5, sqr_mom:float=0.99, beta:float=0.0, decouple_wd:bool=True)
Convenience method for Lookahead
with RAdam
| | Type | Default | Details |
|---|---|---|---|
| params | Tensor \| Iterable | | Model parameters |
| lr | float \| slice | | Default learning rate |
| mom | float | 0.95 | Gradient moving average (β1) coefficient |
| wd | Real | 0.01 | Optional weight decay (true or L2) |
| eps | float | 1e-06 | Added for numerical stability |
| k | int | 6 | How often to conduct Lookahead step |
| alpha | float | 0.5 | Slow weight moving average coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| beta | float | 0.0 | Set to enable SAdam |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| **Returns** | **Lookahead** | | |
`OptimWrapper` provides simple functionality to use existing optimizers constructed with `torch.optim.Optimizer`.
detuplify_pg
detuplify_pg (d)
set_item_pg
set_item_pg (pg, k, v)
OptimWrapper
A wrapper class for existing PyTorch optimizers
To use an existing PyTorch optimizer, you can define an optimizer function like this:

opt_func = partial(OptimWrapper, opt=torch.optim.SGD)

Or if you already have a built optimizer, pass in only `opt`:

opt = torch.optim.SGD([tensor([1,2,3])], lr=1e-2)
opt_func = OptimWrapper(opt=opt)
When passing a built optimizer to `Learner`, instead of resetting the optimizer, `Learner.fit` will clear the optimizer state if `reset_opt=True` or when calling `Learner.fit` for the first time.

To prevent `Learner` from clearing the optimizer state when calling `Learner.fit` for the first time, assign the optimizer directly to `Learner.opt`:

opt = torch.optim.SGD([tensor([1,2,3])], lr=1e-2)
learn = Learner(..., opt_func=None)
learn.opt = OptimWrapper(opt=opt)