Implementation of the LR Range test from Leslie Smith

Learning Rate Finder

The learning rate finder plots the relationship between learning rate and loss for a Learner. The idea is to reduce the guesswork involved in picking a good starting learning rate.

Overview:

  1. First, run lr_find: learn.lr_find()
  2. Plot the learning rate against the loss: learn.recorder.plot()
  3. Pick a learning rate before the loss diverges, then start training

Technical Details (first described by Leslie Smith):

Train the Learner over a few iterations. Start with a very low start_lr and increase it exponentially at each mini-batch until it reaches a very high end_lr. The Recorder records the loss at each iteration. Plot those losses against the learning rate to find a good value before the loss diverges.
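A minimal sketch of such a schedule, assuming the learning rate grows exponentially (a constant multiplicative step per mini-batch), so the iterations are spread evenly across each order of magnitude. The helper name `exp_lr_schedule` is hypothetical and not part of fastai's API.

```python
import math
from typing import List

def exp_lr_schedule(start_lr: float, end_lr: float, num_it: int) -> List[float]:
    """Return num_it learning rates growing geometrically from start_lr to end_lr.

    Hypothetical helper: iteration i uses start_lr * (end_lr / start_lr) ** (i / (num_it - 1)),
    so the first value is start_lr and the last is end_lr.
    """
    if num_it < 2:
        raise ValueError("num_it must be at least 2")
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

lrs = exp_lr_schedule(1e-7, 10, 100)
# Consecutive learning rates differ by a constant factor, so on a log10 axis
# the points are evenly spaced.
step = lrs[1] / lrs[0]
print(f"{lrs[0]:.1e} -> {lrs[-1]:.1e}, x{step:.4f} per batch")
```

Because each step multiplies the learning rate by the same factor, plotting the recorded losses on a log-scaled learning-rate axis gives every decade the same number of points.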

Choosing a good learning rate

For a more intuitive explanation, please check out Sylvain Gugger's post.

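from fastai.vision import *  # fastai v1: untar_data, URLs, ImageDataBunch, Learner, simple_cnn, accuracy
path = untar_data(URLs.MNIST_SAMPLE)  # small subset of MNIST containing only 3s and 7s
data = ImageDataBunch.from_folder(path)
# Build a fresh, untrained learner each time so the experiments below don't interfere
def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
learn = simple_learner()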

First we run this command to launch the search:

lr_find

lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, **kwargs:Any)

Explore the learning rate from start_lr to end_lr over num_it iterations in learn. If stop_div is True, stop when the loss diverges.

learn.lr_find(stop_div=False, num_it=200)
LR Finder complete, type {learner_name}.recorder.plot() to see the graph.

Then we plot the loss against the learning rate. We're interested in finding a good order of magnitude for the learning rate, so we plot with a log scale.

learn.recorder.plot()

Then, we choose a value that is approximately in the middle of the sharpest downward slope. In this case, training with 3e-2 looks like it should work well:
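The "middle of the sharpest downward slope" can also be approximated numerically. The sketch below is a hypothetical heuristic, not a fastai function: it estimates the slope of the loss with respect to log10(lr) by forward differences and returns the learning rate where that slope is most negative. The loss values in the usage lines are made up for illustration. Real loss curves are noisy, so treat the result as a starting point for reading the plot rather than a replacement for it.

```python
import math
from typing import Sequence

def steepest_descent_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the lr at which the loss falls fastest per decade of learning rate.

    Hypothetical helper: uses forward differences of loss against log10(lr)
    and reports the midpoint (on a log axis) of the steepest segment.
    """
    if len(lrs) != len(losses) or len(lrs) < 2:
        raise ValueError("need two equal-length sequences with at least 2 points")
    xs = [math.log10(lr) for lr in lrs]
    best_i, best_slope = 0, math.inf
    for i in range(len(xs) - 1):
        slope = (losses[i + 1] - losses[i]) / (xs[i + 1] - xs[i])
        if slope < best_slope:
            best_i, best_slope = i, slope
    # The geometric mean of the two endpoints is the midpoint on a log axis.
    return math.sqrt(lrs[best_i] * lrs[best_i + 1])

lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [0.70, 0.69, 0.60, 0.20, 0.15, 2.00]  # made-up illustrative curve
lr = steepest_descent_lr(lrs, losses)
print(f"suggested lr ~ {lr:.1e}")
```

Note that in this made-up curve the minimum loss sits at 1e-1, one decade higher than the suggested value and right next to where the loss explodes.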

simple_learner().fit(2, 3e-2)
Total time: 00:03
epoch  train_loss  valid_loss  accuracy
1      0.070224    0.039051    0.986752  (00:01)
2      0.038105    0.043696    0.985280  (00:01)

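Don't just pick the minimum value from the plot! By that point the loss is usually about to diverge. Here, training with 1e-0 leaves the model stuck at chance level: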

learn = simple_learner()
learn.fit(2, 1e-0)
Total time: 00:03
epoch  train_loss  valid_loss  accuracy
1      0.724437    0.693147    0.495584  (00:01)
2      0.693758    0.693147    0.495584  (00:01)
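The stuck validation loss of 0.693147 is a telltale sign. Assuming the default cross-entropy loss for this two-class problem, it is exactly the loss of a model that assigns probability 0.5 to each class, that is ln 2. This matches the roughly 50% accuracy above.

```python
import math

# Cross-entropy of a prediction that puts probability 0.5 on the true class,
# i.e. a model that cannot tell 3s from 7s.
chance_loss = -math.log(0.5)
print(f"{chance_loss:.6f}")  # 0.693147
```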

Picking a value before the downward slope results in slow training:

learn = simple_learner()
learn.fit(2, 1e-3)
Total time: 00:03
epoch  train_loss  valid_loss  accuracy
1      0.184354    0.168152    0.940137  (00:01)
2      0.146272    0.143661    0.946516  (00:01)

class LRFinder

LRFinder(learn:Learner, start_lr:float=1e-07, end_lr:float=10, num_it:int=100, stop_div:bool=True) :: LearnerCallback

Causes learn to go on a mock training run from start_lr to end_lr for num_it iterations. If stop_div is True, training is interrupted when the loss diverges. Weight changes are reverted once the run is complete.
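Reverting the weights means the exploration must not leave any trace on the model. A minimal sketch of the snapshot-and-restore idea, using a plain dict of parameter lists in place of a real model; `explore_then_restore` and `big_lr_step` are hypothetical names, and this is not a description of how fastai stores its snapshot.

```python
import copy
from typing import Callable, Dict, List

Params = Dict[str, List[float]]

def explore_then_restore(params: Params, update: Callable[[Params], None], steps: int) -> None:
    """Apply `update` in place `steps` times, then put the original values back."""
    snapshot = copy.deepcopy(params)  # deep copy: in-place updates must not touch the snapshot
    try:
        for _ in range(steps):
            update(params)  # exploration mutates the weights in place
    finally:
        # Restore even if exploration stops early (e.g. on divergence or an error).
        params.clear()
        params.update(snapshot)

weights: Params = {"w": [0.5, -0.2], "b": [0.1]}

def big_lr_step(p: Params) -> None:
    """Stand-in for an update at a very high learning rate: scales weights in place."""
    for i, x in enumerate(p["w"]):
        p["w"][i] = x * 3.0

explore_then_restore(weights, big_lr_step, steps=5)
print(weights)
```

A shallow copy would not be enough here: `big_lr_step` edits the lists in place, so the snapshot would change along with the live weights.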

on_train_end

on_train_end(**kwargs:Any)

Clean up the learn model weights disturbed during the LRFinder exploration.

on_batch_end

on_batch_end(iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)

Determine whether the loss has run away and training should stop.
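A sketch of what "the loss has run away" can mean in practice, under assumed choices: the raw mini-batch loss is smoothed with a bias-corrected exponential moving average (beta = 0.98), and training stops when the smoothed loss exceeds four times the best smoothed loss seen so far, or stops being finite. Both constants and the `DivergenceMonitor` class are assumptions for illustration.

```python
import math

class DivergenceMonitor:
    """Hypothetical helper: smooth raw losses and flag when they run away."""

    def __init__(self, beta: float = 0.98, factor: float = 4.0) -> None:
        self.beta, self.factor = beta, factor
        self.n = 0
        self.avg = 0.0
        self.best = math.inf

    def update(self, loss: float) -> bool:
        """Record one mini-batch loss; return True if training should stop."""
        self.n += 1
        self.avg = self.beta * self.avg + (1 - self.beta) * loss
        smooth = self.avg / (1 - self.beta ** self.n)  # corrects the bias from starting at 0
        if not math.isfinite(smooth):
            return True  # NaN or inf: the loss has certainly exploded
        self.best = min(self.best, smooth)
        return smooth > self.factor * self.best

mon = DivergenceMonitor()
# Flat loss for 20 batches, then a loss that grows by 50% per batch.
losses = [0.7] * 20 + [0.7 * 1.5 ** k for k in range(1, 40)]
stop_at = next(i for i, l in enumerate(losses) if mon.update(l))
print(f"stopped at batch {stop_at}")
```

Smoothing keeps a single noisy batch from ending the search early, at the cost of reacting a few batches after the raw loss starts to climb.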

on_train_begin

on_train_begin(pbar, **kwargs:Any)

Initialize optimizer and learner hyperparameters.

on_epoch_end

on_epoch_end(**kwargs:Any)

Tell the Learner whether it needs to stop.
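To see how these hooks divide the work, here is a self-contained toy that mimics the lifecycle in plain Python. Everything in it is hypothetical: `ToyLRFinder`, `toy_loss` and `run` stand in for fastai's callback, model and training loop, the hook signatures are simplified, and the divergence check uses the raw loss rather than a smoothed one.

```python
import math
from typing import List

def toy_loss(lr: float) -> float:
    """Made-up loss curve: flat for tiny lr, lowest near 0.15, explodes for large lr."""
    return 0.7 / (1 + 100 * lr) + lr ** 2

class ToyLRFinder:
    """Hypothetical stand-in for an LR-finder callback driven by `run` below."""

    def __init__(self, start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True):
        self.start_lr, self.end_lr = start_lr, end_lr
        self.num_it, self.stop_div = num_it, stop_div
        self.lrs: List[float] = []
        self.losses: List[float] = []

    def lr_at(self, i: int) -> float:
        # Exponential schedule from start_lr (i = 0) to end_lr (i = num_it - 1).
        return self.start_lr * (self.end_lr / self.start_lr) ** (i / (self.num_it - 1))

    def on_train_begin(self, weights: dict) -> None:
        self.saved = dict(weights)  # snapshot so the run leaves no trace
        self.best = math.inf
        self.stop = False

    def on_batch_end(self, iteration: int, loss: float) -> None:
        self.lrs.append(self.lr_at(iteration))
        self.losses.append(loss)
        self.best = min(self.best, loss)
        done = iteration + 1 >= self.num_it
        diverged = self.stop_div and (math.isnan(loss) or loss > 4 * self.best)
        self.stop = done or diverged

    def on_epoch_end(self) -> bool:
        return self.stop  # tell the loop whether to stop training

    def on_train_end(self, weights: dict) -> None:
        weights.clear()
        weights.update(self.saved)  # undo the exploratory updates

def run(finder: ToyLRFinder, weights: dict) -> None:
    finder.on_train_begin(weights)
    iteration = 0
    while True:  # each pass is one "epoch" over a tiny 10-batch dataset
        for _ in range(10):
            lr = finder.lr_at(iteration)
            weights["w"] -= lr  # stand-in for an optimizer step
            finder.on_batch_end(iteration, toy_loss(lr))
            iteration += 1
            if finder.stop:
                break
        if finder.on_epoch_end():
            break
    finder.on_train_end(weights)

weights = {"w": 1.0}
finder = ToyLRFinder()
run(finder, weights)
print(len(finder.losses), f"{finder.lrs[-1]:.2g}", weights)
```

With stop_div left on, the toy run ends well before reaching end_lr, once the loss climbs past four times its best value; the weights come back unchanged either way.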