Callbacks that make decisions based on how metrics evolve during training

Tracking Callbacks

This module groups the callbacks that track one of the metrics computed at the end of each epoch and use it to make a decision about training. To show examples of use, we'll use our sample of MNIST and a simple CNN model.

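from fastai.vision import *      # fastai v1: untar_data, URLs, ImageDataBunch, simple_cnn, Learner, accuracy
from fastai.callbacks import *   # the tracking callbacks documented below

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)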

class TerminateOnNaNCallback

TerminateOnNaNCallback() :: Callback

A Callback that terminates training if the loss is NaN.

Sometimes training diverges and the loss becomes NaN. In that case there's no point in continuing, so this callback stops training.

model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy])
learn.fit_one_cycle(2,1e4)  # deliberately huge learning rate (1e4) so that training diverges
Total time: 00:04
epoch  train_loss  valid_loss  accuracy
1      nan         nan         0.504416  (00:02)
2      nan         nan         0.504416  (00:02)

With the callback, training stops as soon as the loss becomes NaN instead of running through every remaining epoch.

model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy], callbacks=[TerminateOnNaNCallback()])
learn.fit(2,1e4)
epoch  train_loss  valid_loss  accuracy
Interrupted
Epoch/Batch (0/5): Invalid loss, terminating training.
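The check behind this callback is simple. The following is a framework-free sketch, not fastai's implementation: nan_stop_index is a hypothetical helper that scans a sequence of per-batch losses and reports the batch at which a guard like TerminateOnNaNCallback would stop training.

```python
import math
from typing import Iterable, Optional


def nan_stop_index(batch_losses: Iterable[float]) -> Optional[int]:
    """Return the index of the first batch whose loss is NaN, or None.

    None means no NaN was seen, so training would run to completion.
    """
    for i, loss in enumerate(batch_losses):
        if math.isnan(loss):
            # A NaN guard would stop training at this batch.
            return i
    return None


# The loss blows up on the third batch (index 2).
print(nan_stop_index([0.69, 0.68, float("nan"), 0.67]))  # prints 2
print(nan_stop_index([0.69, 0.68, 0.67]))                 # prints None
```

Note that math.isnan is False for infinities, so this sketch, like the callback's description, only reacts to NaN.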

class EarlyStoppingCallback

EarlyStoppingCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', min_delta:int=0, patience:int=0) :: TrackerCallback

A TrackerCallback that terminates training when the monitored quantity stops improving.

This callback tracks the quantity named by monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). Training stops after patience epochs in which the quantity hasn't improved by at least min_delta.

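from functools import partial

model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy],
                callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])
learn.fit(50,1e-42)  # tiny learning rate: accuracy barely moves, so early stopping kicks in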
epoch  train_loss  valid_loss  accuracy
1      0.692837    0.692778    0.496565
2      0.692831    0.692778    0.496565
3      0.692877    0.692778    0.496565
Epoch 4: early stopping
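To make the interaction between patience and min_delta concrete, here is a plain-Python simulation of the stopping rule on a list of per-epoch metric values. early_stop_epoch is a hypothetical helper, and two details are assumptions rather than taken from fastai's source: an epoch only counts as an improvement if it beats the best value so far by more than min_delta, and training stops once the number of consecutive non-improving epochs exceeds patience.

```python
import operator
from typing import Optional, Sequence


def early_stop_epoch(values: Sequence[float], mode: str = "min",
                     min_delta: float = 0.0, patience: int = 0) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    better = operator.lt if mode == "min" else operator.gt
    # Penalize the current value by min_delta so it must beat best by more than that margin.
    delta = -min_delta if mode == "min" else min_delta
    best, wait = None, 0
    for epoch, current in enumerate(values):
        if best is None or better(current - delta, best):
            best, wait = current, 0  # improvement: reset the patience counter
        else:
            wait += 1
            if wait > patience:      # assumption: stop once wait exceeds patience
                return epoch
    return None


# A flat accuracy: the first epoch sets the best value, every later one counts against patience.
print(early_stop_epoch([0.5] * 10, mode="max", min_delta=0.01, patience=3))  # prints 4
```

Because the comparison uses the best value seen so far rather than the previous epoch, a slow drift that never clears min_delta in total still triggers a stop.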

class SaveModelCallback

SaveModelCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', every:str='improvement', name:str='bestmodel') :: TrackerCallback

A TrackerCallback that saves the model when the monitored quantity is at its best.

This callback tracks the quantity named by monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). The model is saved under name either each time the quantity improves (every='improvement') or at the end of every epoch (every='epoch'). If every='improvement', the best model is loaded back at the end of training.
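The difference between the two every settings shows up when listing which epochs would write a checkpoint. save_epochs is a hypothetical, framework-free helper; it assumes a strict comparison against the best value seen so far (no tolerance, since SaveModelCallback takes no min_delta) and also returns the epoch of the best value, which is the checkpoint reloaded at the end when every='improvement'.

```python
import operator
from typing import List, Sequence, Tuple


def save_epochs(values: Sequence[float], mode: str = "min",
                every: str = "improvement") -> Tuple[List[int], int]:
    """Return (epochs that would write a checkpoint, epoch of the best value).

    The best epoch is -1 when values is empty.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if every not in ("improvement", "epoch"):
        raise ValueError("every must be 'improvement' or 'epoch'")
    better = operator.lt if mode == "min" else operator.gt
    saved: List[int] = []
    best, best_epoch = None, -1
    for epoch, current in enumerate(values):
        improved = best is None or better(current, best)
        if improved:
            best, best_epoch = current, epoch
        # 'epoch' saves unconditionally; 'improvement' only saves new bests.
        if every == "epoch" or improved:
            saved.append(epoch)
    return saved, best_epoch


losses = [0.9, 0.7, 0.8, 0.6]
print(save_epochs(losses))                 # prints ([0, 1, 3], 3)
print(save_epochs(losses, every="epoch"))  # prints ([0, 1, 2, 3], 3)
```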

class ReduceLROnPlateauCallback

ReduceLROnPlateauCallback(learn:Learner, monitor:str='val_loss', mode:str='auto', patience:int=0, factor:float=0.2, min_delta:int=0) :: TrackerCallback

A TrackerCallback that reduces the learning rate when a metric has stopped improving.

This callback tracks the quantity named by monitor during the training of learn. mode can be forced to 'min' or 'max'; with the default 'auto', it tries to determine whether the quantity should be as low as possible (like the validation loss) or as high as possible (like accuracy). If the quantity hasn't improved by at least min_delta for patience epochs, the learning rate is multiplied by factor (with the default factor=0.2, it is divided by 5).
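Here is a plain-Python sketch of the learning-rate schedule this produces. plateau_lrs is a hypothetical helper, not fastai code, and it makes three assumptions: an improvement must beat the best value so far by more than min_delta, the reduction happens once the count of non-improving epochs exceeds patience, and that count restarts after each reduction.

```python
import operator
from typing import List, Sequence


def plateau_lrs(values: Sequence[float], lr: float, mode: str = "min",
                patience: int = 0, factor: float = 0.2,
                min_delta: float = 0.0) -> List[float]:
    """Return the learning rate in effect after each epoch's plateau check."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    better = operator.lt if mode == "min" else operator.gt
    # Penalize the current value by min_delta so it must beat best by more than that margin.
    delta = -min_delta if mode == "min" else min_delta
    best, wait, lrs = None, 0, []
    for current in values:
        if best is None or better(current - delta, best):
            best, wait = current, 0
        else:
            wait += 1
            if wait > patience:
                lr *= factor  # reduce the learning rate
                wait = 0      # assumption: start a fresh patience window
        lrs.append(lr)
    return lrs


# The loss plateaus for two epochs, so with patience=1 the rate drops once, at the third epoch.
print(plateau_lrs([1.0, 1.0, 1.0, 0.5], lr=0.1, patience=1))
```

Unlike early stopping, a plateau here does not end training: the counter restarts, so a long plateau keeps shrinking the learning rate every patience + 1 epochs.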

class TrackerCallback

TrackerCallback(learn:Learner, monitor:str='val_loss', mode:str='auto') :: LearnerCallback

A LearnerCallback that keeps track of the best value in monitor.
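All of the callbacks above rely on this bookkeeping. MetricTracker is a hypothetical, framework-free stand-in for it, not fastai's class: it resolves mode='auto' with an assumed rule (minimize when the monitored name contains 'loss', maximize otherwise) and keeps the best value seen so far.

```python
import math
import operator
from typing import Dict, Optional


class MetricTracker:
    """Hypothetical tracker that remembers the best value of one metric."""

    def __init__(self, monitor: str = "val_loss", mode: str = "auto"):
        if mode not in ("auto", "min", "max"):
            raise ValueError(f"unknown mode {mode!r}")
        if mode == "auto":
            # Assumed heuristic: anything named like a loss should go down.
            mode = "min" if "loss" in monitor else "max"
        self.monitor = monitor
        self.mode = mode
        self.operator = operator.lt if mode == "min" else operator.gt
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, metrics: Dict[str, float]) -> bool:
        """Record one epoch's metrics; return True if the monitored value improved."""
        current: Optional[float] = metrics.get(self.monitor)
        if current is None:
            return False  # monitored metric not available this epoch
        if self.operator(current, self.best):
            self.best = current
            return True
        return False


tracker = MetricTracker(monitor="accuracy")
print(tracker.mode)                        # prints max
print(tracker.update({"accuracy": 0.50}))  # prints True
print(tracker.update({"accuracy": 0.48}))  # prints False
```

Subclasses only need to decide what to do when update reports an improvement (save a checkpoint) or repeatedly reports none (stop training, lower the learning rate).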