learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())  # stops both training and validation after just pct (default 0.01) of the epoch
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))  # training is cut short, but validation runs in full
learn = synth_learner()

learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]

learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
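For intuition, here is a minimal plain-PyTorch sketch of the idea (an illustration, not fastai's actual implementation): gradients are summed over successive batches, and the optimizer only steps once at least n_acc samples have been processed. The names model, opt, loss_func, dl, and n_acc are hypothetical placeholders.

count = 0
opt.zero_grad()
for xb, yb in dl:
    loss = loss_func(model(xb), yb)
    loss.backward()       # gradients from each batch are summed into .grad
    count += len(xb)
    if count >= n_acc:    # step only once n_acc samples have been seen
        opt.step()
        opt.zero_grad()
        count = 0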
Freeze moving average statistics in all non-trainable batchnorm layers.
BnFreeze is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.
Learner.freeze() doesn't suffice here as the BatchNorm layers are trainable by default, and the running mean and std of the batches are still tracked. For the feature extractors to fully match, you need to set train_bn=False, and these stats need to be frozen as well, which is precisely the function of BnFreeze.
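Conceptually, its effect resembles the following sketch (an approximation of the effect, not fastai's exact code): before each training batch, any BatchNorm layer whose parameters are all frozen is switched to eval mode, so its running mean and std are no longer updated.

import torch.nn as nn

def freeze_bn_stats(model):
    "Put frozen BatchNorm layers in eval mode so their running stats stop updating."
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            if not any(p.requires_grad for p in m.parameters()):
                m.eval()  # eval mode: running_mean/running_var are left untouched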
from fastai.vision.all import *
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)
We first demonstrate the mismatch of the running stats when using only train_bn=False, by creating a Learner without the BnFreeze callback:
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
# grab the first BatchNorm layer of the body and store its running mean
m = learn1.model[0][1].running_mean.clone()
You can see that after one epoch of training, the running mean has changed:
learn1.fit(1, lr=0.02)
test_ne(learn1.model[0][1].running_mean, m)
When we use the BnFreeze callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.clone()
learn1.fit(1, lr=0.02)
test_eq(learn1.model[0][1].running_mean, m)
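With the callback attached, the body's running statistics are identical before and after training, so two models sharing this feature extractor stay fully interchangeable.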