Various callbacks to customize training behavior

class ShortEpochCallback

ShortEpochCallback(pct=0.01, short_valid=True) :: Callback

Fit just pct of an epoch, then stop

learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss time
0 00:00
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
epoch train_loss valid_loss time
0 12.395771 00:00
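To make the cut-off rule concrete, here is a framework-free sketch of the logic ShortEpochCallback describes. After each batch it checks what fraction of the current phase has run. Once that fraction reaches pct, it cancels training, and it also cancels validation unless short_valid=False. The exception classes, ShortEpochSketch and run_epoch are hypothetical stand-ins, not fastai's Learner internals. Because the batch index starts at 0, the batch that crosses the threshold still runs.

```python
# Hypothetical, framework-free sketch of the "fit just pct of an epoch" rule.
# None of these names come from fastai; they only mirror the idea.

class CancelTrain(Exception):
    """Signals that the training phase should stop early."""

class CancelValid(Exception):
    """Signals that the validation phase should stop early."""

class ShortEpochSketch:
    def __init__(self, pct=0.01, short_valid=True):
        self.pct, self.short_valid = pct, short_valid

    def after_batch(self, i, n_iter, training):
        # Let batches run until the fraction of the phase seen reaches pct.
        if i / n_iter < self.pct:
            return
        if training:
            raise CancelTrain()
        if self.short_valid:
            raise CancelValid()

def run_epoch(n_train, n_valid, cb):
    """Run one toy epoch and return how many batches each phase processed."""
    seen = {"train": 0, "valid": 0}
    phases = [("train", n_train, CancelTrain), ("valid", n_valid, CancelValid)]
    for phase, n_iter, cancel in phases:
        try:
            for i in range(n_iter):
                seen[phase] += 1  # stand-in for forward/backward on batch i
                cb.after_batch(i, n_iter, training=(phase == "train"))
        except cancel:
            pass  # phase cut short; continue with the next one
    return seen

print(run_epoch(25, 8, ShortEpochSketch(pct=0.1)))                     # {'train': 4, 'valid': 2}
print(run_epoch(25, 8, ShortEpochSketch(pct=0.1, short_valid=False)))  # {'train': 4, 'valid': 8}
```

With short_valid=False only the training phase is shortened, so all eight validation batches run.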

class GradientAccumulation

GradientAccumulation(n_acc=32) :: Callback

Accumulate gradients before updating weights

learn = synth_learner()
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
epoch train_loss valid_loss time
0 10.566907 3.633753 00:00
1 5.525984 0.397483 00:00
epoch train_loss valid_loss time
0 0.476599 0.397483 00:00
1 0.478213 0.397483 00:00
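The arithmetic behind gradient accumulation can be checked on a toy problem. The plain-Python sketch below uses hypothetical helpers grad_mse and train_accumulated on a 1-D linear model y_hat = w * x with mean-squared-error loss. It is a simplified illustration, not fastai's implementation. Each mini-batch's mean gradient is scaled by bs / n_acc and summed, and an SGD step is taken only once n_acc samples have been seen. Accumulating two mini-batches of 2 samples with n_acc=4 then gives the same update as one step on the full batch of 4. A very large n_acc never steps, which matches the unchanged valid_loss in the second run above.

```python
import math

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w over one mini-batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def train_accumulated(w, batches, lr, n_acc):
    """Plain SGD that only updates w after n_acc samples have contributed."""
    acc_grad, count = 0.0, 0
    for xs, ys in batches:
        bs = len(xs)
        # Weight each batch by bs/n_acc so the sum is the mean gradient
        # over all n_acc accumulated samples.
        acc_grad += grad_mse(w, xs, ys) * bs / n_acc
        count += bs
        if count >= n_acc:
            w -= lr * acc_grad          # take one optimizer step ...
            acc_grad, count = 0.0, 0    # ... then reset the accumulator
    return w

xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
batches = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]

w_acc = train_accumulated(0.0, batches, lr=0.01, n_acc=4)  # 2 batches of 2, one step
w_full = 0.0 - 0.01 * grad_mse(0.0, xs, ys)                # 1 batch of 4, one step
print(math.isclose(w_acc, w_full))                          # True
print(train_accumulated(0.0, batches, lr=0.01, n_acc=10**6))  # 0.0: never stepped
```

Note that the second mini-batch's gradient is computed with the weights that have not yet been updated, which is exactly what makes the accumulated step equal to the large-batch step.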



set_bn_eval(m:Module, use_eval=True)

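Set BatchNorm layers in eval mode for all recursive children of m. Only non-trainable BatchNorm layers (those whose parameters don't require gradients) are affected.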
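The recursion can be sketched without PyTorch. Module and BatchNorm below are hypothetical stand-ins for torch.nn classes, with a single requires_grad flag standing in for the layer's parameters. Following BnFreeze's description, only non-trainable BatchNorm layers are switched. The walk descends into every child, so deeply nested layers are reached too.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""
    def __init__(self, *children, requires_grad=True):
        self.children = list(children)
        self.requires_grad = requires_grad  # stands in for the layer's parameters
        self.training = True

class BatchNorm(Module):
    """Hypothetical stand-in for a BatchNorm layer."""

def set_bn_eval(m, use_eval=True):
    """Recursively put non-trainable BatchNorm layers under m in eval (or train) mode."""
    for child in m.children:
        if isinstance(child, BatchNorm) and not child.requires_grad:
            child.training = not use_eval
        set_bn_eval(child, use_eval)  # descend into nested children

frozen_bn = BatchNorm(requires_grad=False)
trainable_bn = BatchNorm()
model = Module(Module(Module(frozen_bn)), trainable_bn)  # frozen_bn is nested 3 levels deep

set_bn_eval(model)
print(frozen_bn.training, trainable_bn.training)  # False True
set_bn_eval(model, use_eval=False)
print(frozen_bn.training)  # True
```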

class BnFreeze

BnFreeze(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_backward=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Freeze moving average statistics in all non-trainable batchnorm layers.

BnFreeze is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.

Learner.freeze() doesn't suffice here because the BatchNorm layers are trainable by default, and the running mean and variance of the batches are still tracked. For the feature extractors to match fully, you need to set train_bn=False, and these statistics need to be frozen as well, which is precisely what BnFreeze does.
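Freezing the weights is not enough because BatchNorm's running statistics are buffers rather than parameters. They are updated on every forward pass in training mode, whether or not the layer's weights receive gradients. PyTorch updates the running mean as running_mean = (1 - momentum) * running_mean + momentum * batch_mean, with momentum defaulting to 0.1. The sketch below mimics that rule for a single feature. TinyBatchNorm is hypothetical and tracks only the mean (no variance or normalization).

```python
class TinyBatchNorm:
    """Hypothetical one-feature BatchNorm that tracks only the running mean."""
    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.running_mean = 0.0             # buffer: not a trainable parameter
        self.weight, self.bias = 1.0, 0.0   # affine params, left frozen in this demo
        self.training = True

    def forward(self, batch):
        if self.training:
            mean = sum(batch) / len(batch)
            # Train mode: the buffer moves toward the batch mean on every call,
            # even though weight and bias are never updated.
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
        else:
            mean = self.running_mean        # eval mode: reuse the stored statistic
        return [self.weight * (x - mean) + self.bias for x in batch]

bn = TinyBatchNorm()
bn.forward([1.0, 2.0, 3.0])     # train mode, weights frozen
after_train = bn.running_mean   # 0.9 * 0.0 + 0.1 * 2.0 = 0.2

bn.training = False             # what putting the layer in eval mode achieves
bn.forward([10.0, 20.0, 30.0])
after_eval = bn.running_mean    # still 0.2
print(after_train, after_eval)
```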

from fastai.vision.all import *
path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)

We first demonstrate the mismatch of the running stats when using only train_bn=False, by creating a Learner...:

learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)

...and grab the first BatchNorm layer, and store its running mean:

m = learn1.model[0][1].running_mean.clone()

You can see that the running mean has now changed:
learn1.fit(1, lr=0.02)
test_ne(learn1.model[0][1].running_mean, m)
epoch train_loss valid_loss time
0 1.058304 0.713414 00:02

When we use the BnFreeze callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.

learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.clone()
learn1.fit(1, lr=0.02)
test_eq(learn1.model[0][1].running_mean, m)
epoch train_loss valid_loss time
0 0.540841 0.432421 00:02
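It also matters when BnFreeze does its work. In fastai, BnFreeze is ordered to run after TrainEvalCallback, which puts the model in training mode at the start of each training phase, and it calls set_bn_eval in before_train. If the order were reversed, model.train() would flip the frozen BatchNorm layers straight back into training mode. The plain-Python sketch below shows this with hypothetical stand-ins (Layer, Model, train_eval_cb, bn_freeze_cb, before_train). It is not fastai's callback machinery.

```python
class Layer:
    """Hypothetical layer with a training flag."""
    def __init__(self, is_bn=False, frozen=False):
        self.is_bn, self.frozen, self.training = is_bn, frozen, True

class Model:
    def __init__(self, layers):
        self.layers = layers

    def train(self):
        for layer in self.layers:
            layer.training = True       # switches *every* layer, BatchNorm included

def train_eval_cb(model):
    model.train()                       # like putting the model in train mode

def bn_freeze_cb(model):
    for layer in model.layers:
        if layer.is_bn and layer.frozen:
            layer.training = False      # like set_bn_eval on non-trainable BN layers

def before_train(model, callbacks):
    for cb in callbacks:                # callbacks run in list order
        cb(model)

bn, head = Layer(is_bn=True, frozen=True), Layer()
model = Model([bn, head])

before_train(model, [train_eval_cb, bn_freeze_cb])  # freeze runs last
frozen_ok = bn.training                              # False: stats stay frozen

before_train(model, [bn_freeze_cb, train_eval_cb])  # freeze runs first
frozen_broken = bn.training                          # True: model.train() undid it
print(frozen_ok, frozen_broken, head.training)
```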