Extensions to Learner that easily implement Callback

Additional training functions

train provides a number of extension methods that are added to Learner (see below for a list and details), along with three simple callbacks:

Learner extension methods

These methods are automatically added to all Learner objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created.


fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, wd:float=None, callbacks:Optional[Collection[Callback]]=None, **kwargs)

Fit a model following the 1cycle policy. See OneCycleScheduler for details.
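The shape of the schedule can be sketched in plain Python. This is a hedged approximation only: the exact annealing functions and final learning rate live in OneCycleScheduler, and the one_cycle_lr helper with its final_div parameter below is illustrative, not part of the library.

```python
import math

def one_cycle_lr(t, max_lr=3e-3, div_factor=25.0, pct_start=0.3, final_div=1e4):
    """Learning rate at training progress t in [0, 1] under a 1cycle-style
    schedule: cosine warmup from max_lr/div_factor up to max_lr over the
    first pct_start of training, then cosine annealing down to a small
    final value (final_div is illustrative)."""
    def cos_anneal(start, end, frac):
        # cosine interpolation from start (frac=0) to end (frac=1)
        return end + (start - end) / 2 * (math.cos(math.pi * frac) + 1)
    low = max_lr / div_factor
    if t < pct_start:                                   # warmup phase
        return cos_anneal(low, max_lr, t / pct_start)
    # annealing phase, down from the peak
    return cos_anneal(max_lr, max_lr / final_div, (t - pct_start) / (1 - pct_start))
```

(OneCycleScheduler anneals the momentum in the opposite direction over the same two phases, between the two values given in moms.)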


lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, **kwargs:Any)

Explore lr from start_lr to end_lr over num_it iterations in learn. If stop_div, stops when loss explodes.

See LRFinder for details.
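The geometric spacing of the learning rates the finder sweeps through can be sketched as follows. lr_sweep is a hypothetical helper, not a library function; the real LRFinder multiplies the learning rate each batch while recording the loss, stopping early when stop_div triggers.

```python
def lr_sweep(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Exponentially spaced learning rates from start_lr to end_lr,
    one per iteration, as used by an LR-finder-style sweep."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]
```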


to_fp16(learn:Learner, loss_scale:float=512.0, flat_master:bool=False) → Learner

Put learn in FP16 precision mode.

See MixedPrecision for details.
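The core trick, loss scaling, can be illustrated with a tiny PyTorch snippet. This is a sketch of the principle only; the real callback also keeps FP32 master copies of the weights and casts the model to half precision.

```python
import torch

# Sketch of loss scaling: small FP16 gradients can underflow to zero, so
# we scale the loss up before backward, then scale the gradients back
# down before the optimizer step.
loss_scale = 512.0
w = torch.tensor([1.0], requires_grad=True)        # stand-in for a weight
x, y = torch.tensor([2.0]), torch.tensor([3.0])

loss = ((w * x - y) ** 2).mean()                   # toy squared error
(loss * loss_scale).backward()                     # backward on the scaled loss
w.grad.div_(loss_scale)                            # unscale -> true gradient
```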


mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) → Learner

Add mixup https://arxiv.org/abs/1710.09412 to learn.

See MixUpCallback for more details.
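The blending itself can be sketched with NumPy. mixup_batch is an illustrative helper, not the library API; fastai's MixUpCallback additionally handles details such as the exact sampling of the coefficients and adapting the loss function rather than the targets.

```python
import numpy as np

def mixup_batch(x, y, alpha=0.4, seed=0):
    """Sketch of mixup: blend each example (and its target) with a
    randomly chosen partner, using weights drawn from Beta(alpha, alpha)."""
    rng = np.random.default_rng(seed)
    lam = rng.beta(alpha, alpha, size=len(x))
    perm = rng.permutation(len(x))
    lam_x = lam.reshape(-1, *([1] * (x.ndim - 1)))  # broadcast over features
    x_mix = lam_x * x + (1 - lam_x) * x[perm]
    y_mix = lam * y + (1 - lam) * y[perm]           # mix targets the same way
    return x_mix, y_mix
```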

A last extension method comes from the module tta.


TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False) → Tensors

Applies Test Time Augmentation to learn on the dataset ds_type. We take a weighted average of our regular predictions (with weight beta) and of the predictions obtained from augmented versions of the dataset (with weight 1-beta). The transforms chosen for the training set are applied with a few changes: scale controls the scale for zoom (which isn't random), the cropping isn't random but we make sure to get the four corners of the image, and flipping isn't random but is applied once to each of those corner images (for 8 augmented versions in total).
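The final averaging step reduces to one line; tta_combine below is a hypothetical helper showing only the beta-weighted blend of the regular predictions with the mean of the augmented ones.

```python
import numpy as np

def tta_combine(preds, aug_preds, beta=0.4):
    """Blend regular predictions with the mean of the augmented
    predictions, weighted beta / (1 - beta) as described above."""
    return beta * preds + (1 - beta) * np.mean(aug_preds, axis=0)
```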

We'll show examples below using our MNIST sample.

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

class ShowGraph[source]

ShowGraph(learn:Learner) :: LearnerCallback

Update a graph of learner stats and metrics after each epoch.

learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)

Training graph


on_epoch_end(n_epochs:int, last_metrics:MetricsList, **kwargs) → bool

If we have last_metrics, plot them in self.pbar. Set the size of the graph with n_epochs.

class GradientClipping[source]

GradientClipping(learn:Learner, clip:float) :: LearnerCallback

To do gradient clipping during training.

Clips gradients to a maximum norm of clip during training. For instance:

learn = create_cnn(data, models.resnet18, metrics=accuracy,
    callback_fns=partial(GradientClipping, clip=0.1))
Total time: 00:11
epoch  train loss  valid loss  accuracy
0      0.086958    0.038721    0.989696  (00:11)



Clip the gradients after they are computed but before the optimizer step.
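In plain Python, norm-based clipping (what torch.nn.utils.clip_grad_norm_ does, and what this callback delegates to) looks roughly like this; clip_grad_norm below is a sketch over a flat list of gradient values, not the library function.

```python
import math

def clip_grad_norm(grads, clip=0.1):
    """If the total L2 norm of the gradients exceeds clip, rescale them
    all by the same factor so the norm equals clip; otherwise leave
    them unchanged."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > clip:
        scale = clip / total
        return [g * scale for g in grads]
    return list(grads)
```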

class BnFreeze[source]

BnFreeze(learn:Learner) :: LearnerCallback

Freeze moving average statistics in all non-trainable batchnorm layers.

For batchnorm layers where requires_grad==False, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls eval on these layers).
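The pattern can be sketched directly in PyTorch. set_bn_eval is an illustrative helper mirroring what the callback does at the start of each epoch, not the library function itself.

```python
import torch.nn as nn

def set_bn_eval(module: nn.Module):
    """Recursively put frozen (non-trainable) batchnorm layers in eval
    mode so their running mean/variance stop updating during training."""
    for child in module.children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm) and \
                not any(p.requires_grad for p in child.parameters()):
            child.eval()
        set_bn_eval(child)
```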

learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
Total time: 00:07
epoch  train loss  valid loss  accuracy
0      0.079278    0.041832    0.985280  (00:07)



Puts the batchnorm layers back in eval mode after the model has been set to train mode.