from fastai.test_utils import *
This is the decorator we will use for all of our scheduling functions: it transforms a function taking `(start, end, pos)` into something taking `(start, end)` and returning a function that depends on `pos`.
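For instance, a toy schedule decorated with it can be built with `(start, end)` and then called with `pos` alone (a sketch: `toy_lin` is a hypothetical name used only for illustration, not part of the library):
@annealer
def toy_lin(start, end, pos):
    "Hypothetical linear schedule, to show the (start, end, pos) pattern"
    return start + pos * (end - start)
f = toy_lin(0., 2.)  # partially applied with (start, end)
test_close([f(0.), f(0.5), f(1.)], [0., 1., 2.])  # f now only depends on pos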
annealings = "NO LINEAR COS EXP".split()
p = torch.linspace(0.,1,100)
fns = [SchedNo, SchedLin, SchedCos, SchedExp]
for fn, t in zip(fns, annealings):
    plt.plot(p, [fn(2, 1e-2)(o) for o in p], label=t)
f = SchedPoly(2,1e-2,0.5)
plt.plot(p, [f(o) for o in p], label="POLY(0.5)")
plt.legend();
sched = SchedLin(0, 2)
test_eq(L(map(sched, [0., 0.25, 0.5, 0.75, 1.])), [0., 0.5, 1., 1.5, 2.])
sched = SchedCos(0, 2)
test_close(L(map(sched, [0., 0.25, 0.5, 0.75, 1.])), [0., 0.29289, 1., 1.70711, 2.])
sched = SchedNo(0, 2)
test_close(L(map(sched, [0., 0.25, 0.5, 0.75, 1.])), [0., 0., 0., 0., 0.])
sched = SchedExp(1, 2)
test_close(L(map(sched, [0., 0.25, 0.5, 0.75, 1.])), [1., 1.18921, 1.41421, 1.68179, 2.])
sched = SchedPoly(0, 2, 2)
test_close(L(map(sched, [0., 0.25, 0.5, 0.75, 1.])), [0., 0.125, 0.5, 1.125, 2.])
p = torch.linspace(0.,1,100)
pows = [0.5,1.,2.]
for e in pows:
    f = SchedPoly(2, 0, e)
    plt.plot(p, [f(o) for o in p], label=f'power {e}')
plt.legend();
`pcts` must be a list of positive numbers that add up to 1 and has the same length as `scheds`. The generated function will use `scheds[0]` from 0 to `pcts[0]`, then `scheds[1]` from `pcts[0]` to `pcts[0]+pcts[1]`, and so forth.
p = torch.linspace(0.,1,100)
f = combine_scheds([0.3,0.7], [SchedCos(0.3,0.6), SchedCos(0.6,0.2)])
plt.plot(p, [f(o) for o in p]);
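As a quick numeric sketch of the hand-off described above (reusing `f` from the cell just plotted), the combined schedule follows the first cosine up to pos=0.3 and the second one afterwards:
test_close([f(0.), f(0.15), f(0.3), f(1.)], [0.3, 0.45, 0.6, 0.2])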
p = torch.linspace(0.,1,100)
f = combine_scheds([0.3,0.2,0.5], [SchedLin(0.,1.), SchedNo(1.,1.), SchedCos(1., 0.)])
plt.plot(p, [f(o) for o in p]);
This is a useful helper function for the 1cycle policy: `pct` is used for the start-to-middle part, `1-pct` for the middle-to-end part. It handles floats or collections of floats. For example:
f = combined_cos(0.25,0.5,1.,0.)
plt.plot(p, [f(o) for o in p]);
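As a sanity check of the shape described above (a sketch reusing `f` from the previous cell), the schedule starts at 0.5, peaks at 1. when pos equals pct=0.25, and anneals back to 0.:
test_close([f(0.), f(0.125), f(0.25), f(1.)], [0.5, 0.75, 1., 0.])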
`scheds` is a dictionary with one key for each hyper-parameter you want to schedule, with either a scheduler or a list of schedulers as values (in the second case, the list must have the same length as the number of parameter groups of the optimizer).
learn = synth_learner()
sched = {'lr': SchedLin(1e-3, 1e-2)}
learn.fit(1, cbs=ParamScheduler(sched))
n = len(learn.dls.train)
test_close(learn.recorder.hps['lr'], [1e-3 + (1e-2-1e-3) * i/n for i in range(n)])
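For the second case described above (one scheduler per parameter group), a hedged sketch, assuming a model whose optimizer is split into two parameter groups; the synthetic learner here only has one, so the fit call is left commented out:
per_group_sched = {'lr': [SchedLin(1e-3, 1e-2), SchedCos(1e-4, 1e-3)]}  # one schedule per group
# learn.fit(1, cbs=ParamScheduler(per_group_sched))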
The 1cycle policy was introduced by Leslie N. Smith et al. in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates. It schedules the learning rate with a cosine annealing from `lr_max/div` to `lr_max`, then to `lr_max/div_final` (pass an array to `lr_max` if you want to use differential learning rates), and the momentum with cosine annealing according to the values in `moms`. The first phase takes `pct_start` of the training. You can optionally pass additional `cbs` and `reset_opt`.
learn = synth_learner(lr=1e-2)
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit_one_cycle(2)
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss
lrs,moms = learn.recorder.hps['lr'],learn.recorder.hps['mom']
test_close(lrs, [combined_cos(0.25,1e-2/25,1e-2,1e-7)(i/20) for i in range(20)])
test_close(moms, [combined_cos(0.25,0.95,0.85,0.95)(i/20) for i in range(20)])
learn = synth_learner()
learn.fit_one_cycle(2)
learn.recorder.plot_sched()
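The description above notes that `lr_max` can be an array for differential learning rates; a hedged sketch, assuming a learner whose optimizer has several parameter groups (e.g. a pretrained vision model):
# Sketch only: one maximum learning rate per parameter group (a slice is expanded
# to log-spaced values across the groups).
# learn.fit_one_cycle(2, lr_max=slice(1e-5, 1e-3))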
learn = synth_learner()
learn.fit_flat_cos(2)
learn.recorder.plot_sched()
This schedule was introduced by Ilya Loshchilov et al. in SGDR: Stochastic Gradient Descent with Warm Restarts. It consists of `n_cycles` cosine annealings from `lr_max` (defaults to the `Learner` lr) to 0, the `i`-th cycle lasting `cycle_len * cycle_mult**i` epochs (the first one is `cycle_len` epochs long, then we multiply the length by `cycle_mult` at each cycle). You can optionally pass additional `cbs` and `reset_opt`. For example, `fit_sgdr(3, 1)` with the default `cycle_mult=2` runs cycles of 1, 2 and 4 epochs, 7 in total.
learn = synth_learner()
with learn.no_logging(): learn.fit_sgdr(3, 1)
test_eq(learn.n_epoch, 7)
iters = [k * len(learn.dls.train) for k in [0,1,3,7]]
for i in range(3):
    n = iters[i+1]-iters[i]
    # The start of a cycle can coincide with the 0 at the end of the previous cycle because of rounding errors, so we test from iters[i]+1
    test_close(learn.recorder.lrs[iters[i]+1:iters[i+1]], [SchedCos(learn.lr, 0)(k/n) for k in range(1,n)])
learn.recorder.plot_sched()
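The optional arguments can also be passed explicitly; a hedged sketch with the defaults written out (`lr_max` otherwise falls back to the Learner's lr):
# Sketch: same cycle lengths as above (1, 2 and 4 epochs), with an explicit lr_max.
# learn.fit_sgdr(3, 1, lr_max=1e-2, cycle_mult=2)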
learn.fine_tune(1)
To enable resuming from a checkpoint, make sure to save the model and optimizer state. This can be done with `SaveModelCallback`, setting `with_opt=True`. If training is interrupted, define `learn` using the same parameters as before, load the model from the checkpoint, and pass `start_epoch` to the `fit` call. Training will be resumed from `start_epoch` with a properly scheduled `lr`.
with tempfile.TemporaryDirectory() as d:
    learn1 = synth_learner(path=d, cbs=SaveModelCallback(with_opt=True, fname="ckpt"))
    learn1.fit_one_cycle(5, cbs=InterruptCallback(2))
    learn2 = synth_learner(path=d)
    learn2 = learn2.load("ckpt")
    learn2.fit_one_cycle(5, start_epoch=2)
    fig, axs = plt.subplots(1,2, sharey=True)
    axs[0].plot(learn1.recorder.lrs)
    axs[1].plot(learn2.recorder.lrs)
from fastai.vision.all import *
set_seed(99, True)
path = untar_data(URLs.PETS)/'images'
image_files = get_image_files(path)
if sys.platform == "win32" and IN_NOTEBOOK:
    image_files = random.choices(image_files, k=int(len(image_files)/8))
    print("Randomly select 1/8 files in NOTEBOOK on Windows to save time")
# pickle can't serialize lambda functions, so use a named function for the labels
def _label_func(x):
    return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, image_files, valid_pct=0.2,
    label_func=_label_func, item_tfms=Resize(224))
learn = vision_learner(dls, resnet18)
learn.fit(1)
learn.opt.state_dict()['state'][1]['grad_avg']
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=Path(d))
    init_a,init_b = learn.model.a,learn.model.b
    with learn.no_logging(): learn.fit(20, cbs=LRFinder(num_it=100))
    assert len(learn.recorder.lrs) <= 100
    test_eq(len(learn.recorder.lrs), len(learn.recorder.losses))
    # Check that training stops early if the loss diverges
    if len(learn.recorder.lrs) < 100: assert learn.recorder.losses[-1] > 4 * min(learn.recorder.losses)
    # Test the exponential schedule
    test_eq(learn.recorder.lrs, [SchedExp(1e-7, 10)(i/100) for i in range_of(learn.recorder.lrs)])
    # No validation data
    test_eq([len(v) for v in learn.recorder.values], [1 for _ in range_of(learn.recorder.values)])
    # Model loaded back properly
    test_eq(learn.model.a, init_a)
    test_eq(learn.model.b, init_b)
    test_eq(learn.opt.state_dict()['state'], [{}, {}])
There are a few methodologies for suggesting a learning rate automatically and these, as we will see, can be passed into `lr_find`. Currently four methods are supported; to write your own, it should look like a function that accepts `LRFinder`'s returned `lrs` and `losses`, as well as `num_it`. Your function should return an `x,y` coordinate that can be plotted, such as below:
def myfunc(lrs:list, losses:list, num_it:int) -> tuple[float, tuple[float, int]]:
    ...
    return suggestion, (suggestion, loss_idx)
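To make the interface concrete, here is a minimal hypothetical suggestion function (a sketch, not one of the built-in methods): it suggests a tenth of the learning rate found at the lowest recorded loss and returns that point for plotting.
def min_over_10(lrs:list, losses:list, num_it:int) -> tuple[float, tuple[float, float]]:
    "Hypothetical: suggest the lr at the minimum loss, divided by 10"
    idx = min(range(len(losses)), key=lambda i: losses[i])
    suggestion = float(lrs[idx]) / 10
    return suggestion, (suggestion, float(losses[idx]))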
If there are any more parameters to be passed in, you should pass your `func` as a partial and specify them yourself, such as:
def myfunc(lrs:list, losses:list, num_it:int, pct_reduction:float) -> tuple[float, tuple[float, int]]:
    ...
    return suggestion, (suggestion, loss_idx)
f = partial(myfunc, pct_reduction=.2)
The `valley` algorithm was developed by ESRI and takes the steepest slope roughly 2/3 through the longest valley in the LR plot, and is also the default for `Learner.lr_find`.
First introduced by Leslie N. Smith in Cyclical Learning Rates for Training Neural Networks, the LR Finder trains the model with exponentially growing learning rates from `start_lr` to `end_lr` for `num_it` iterations and stops in case of divergence (unless `stop_div=False`), then plots the losses vs the learning rates on a log scale. A variety of learning rate suggestion algorithms can be passed into the function; by default we use the `valley` approach.
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=Path(d))
    weights_pre_lr_find = L(learn.model.parameters())
    lr_min, lr_steep, lr_valley, lr_slide = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
    weights_post_lr_find = L(learn.model.parameters())
    test_eq(weights_pre_lr_find, weights_post_lr_find)
print(f"Minimum/10:\t{lr_min:.2e}\nSteepest point:\t{lr_steep:.2e}\nLongest valley:\t{lr_valley:.2e}\nSlide interval:\t{lr_slide:.2e}")