Extensions to Learner that easily implement Callback

Additional training functions

train provides a number of extension methods that are added to Learner (see below for a list and details), along with three simple callbacks: ShowGraph, GradientClipping and BnFreeze.

Learner extension methods

These methods are automatically added to all Learner objects created after importing this module. They give convenient access to several callbacks without requiring you to create those callbacks manually.
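The sketch below shows the usual Python idiom for this kind of extension, assuming the methods are attached by plain assignment onto the class. A module-level function whose first parameter is the learner becomes a method once it is set as a class attribute. ToyLearner and toy_fit_one_cycle are hypothetical stand-ins, not fastai code.

```python
class ToyLearner:
    """Hypothetical stand-in for fastai's Learner."""
    def __init__(self) -> None:
        self.callbacks: list = []

def toy_fit_one_cycle(learn: ToyLearner, cyc_len: int, max_lr: float = 3e-3) -> str:
    """A plain function whose first argument is the learner (hypothetical)."""
    # Record the callback this helper would normally create for the user.
    learn.callbacks.append(("one_cycle", max_lr))
    return f"fit {cyc_len} epoch(s) with max_lr={max_lr}"

# Attaching the function to the class turns it into a method:
# learn.fit_one_cycle(1) now behaves like toy_fit_one_cycle(learn, 1).
ToyLearner.fit_one_cycle = toy_fit_one_cycle
```

Because the learner is passed implicitly as the first argument, the user never has to build the callback object by hand.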


fit_one_cycle(`learn`:Learner, `cyc_len`:int, `max_lr`:Union[float, Collection[float], slice]=`slice(None, 0.003, None)`, `moms`:Point=`(0.95, 0.85)`, `div_factor`:float=`25.0`, `pct_start`:float=`0.3`, `wd`:float=`None`, `callbacks`:Optional[Collection[Callback]]=`None`, `kwargs`)

Fit a model following the 1cycle policy.
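The plain-Python sketch below outlines the schedule the 1cycle policy produces. It assumes cosine annealing in both phases. Over the first pct_start fraction of the n_iter iterations (cyc_len epochs times the batches per epoch), the learning rate rises from max_lr/div_factor to max_lr. It then falls to a much smaller value, while momentum moves the opposite way between the two moms values. The final_div=1e4 end point and the exact placement of each step are our assumptions, not taken from OneCycleScheduler.

```python
import math
from typing import List, Tuple

def annealing_cos(start: float, end: float, pct: float) -> float:
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_schedule(n_iter: int, max_lr: float,
                       moms: Tuple[float, float] = (0.95, 0.85),
                       div_factor: float = 25.0, pct_start: float = 0.3,
                       final_div: float = 1e4) -> Tuple[List[float], List[float]]:
    """Per-iteration learning rates and momentums for a 1cycle-style schedule."""
    low_lr = max_lr / div_factor
    n_up = int(n_iter * pct_start)   # warm-up phase: lr goes up, momentum goes down
    n_down = n_iter - n_up           # annealing phase: lr goes down, momentum goes up
    lrs, mom_vals = [], []
    for i in range(n_up):
        pct = i / n_up
        lrs.append(annealing_cos(low_lr, max_lr, pct))
        mom_vals.append(annealing_cos(moms[0], moms[1], pct))
    for i in range(n_down):
        pct = i / max(n_down - 1, 1)
        lrs.append(annealing_cos(max_lr, low_lr / final_div, pct))
        mom_vals.append(annealing_cos(moms[1], moms[0], pct))
    return lrs, mom_vals
```

Under these assumptions, with 100 iterations and max_lr=0.01, the learning rate starts at 0.0004 and peaks at 0.01 on iteration 30. At the same point, momentum bottoms out at 0.85 before climbing back to 0.95.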


one_cycle_scheduler(`lr_max`:float, `kwargs`:Any) → OneCycleScheduler

Instantiate a OneCycleScheduler with lr_max.

See OneCycleScheduler for details.
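Callbacks such as OneCycleScheduler take the Learner as their first argument. One plausible reason for a helper like this is to bind lr_max up front so that the learner can supply itself later, the same way ShowGraph or partial(GradientClipping, clip=0.1) are passed as callback_fns in the examples further down. The sketch below shows that factory pattern with functools.partial. ToyOneCycleScheduler, toy_one_cycle_scheduler and ToyLearner are hypothetical stand-ins, not fastai classes.

```python
from functools import partial
from typing import Any, Callable, List, Tuple

class ToyOneCycleScheduler:
    """Hypothetical callback that takes the learner as its first argument."""
    def __init__(self, learn: Any, lr_max: float,
                 moms: Tuple[float, float] = (0.95, 0.85)) -> None:
        self.learn, self.lr_max, self.moms = learn, lr_max, moms

def toy_one_cycle_scheduler(lr_max: float, **kwargs: Any) -> Callable[[Any], ToyOneCycleScheduler]:
    # Bind every argument except the learner, which is not known yet.
    return partial(ToyOneCycleScheduler, lr_max=lr_max, **kwargs)

class ToyLearner:
    """Hypothetical learner that turns callback factories into callbacks."""
    def __init__(self, callback_fns: List[Callable[[Any], Any]]) -> None:
        self.callback_fns = callback_fns

    def make_callbacks(self) -> List[Any]:
        # Each factory is called with the learner itself.
        return [cb_fn(self) for cb_fn in self.callback_fns]
```

Because lr_max is bound ahead of time, the same factory can be handed to several learners, and each learner supplies itself when its callbacks are built.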


lr_find(`learn`:Learner, `start_lr`:Floats=`1e-07`, `end_lr`:Floats=`10`, `num_it`:int=`100`, `stop_div`:bool=`True`, `kwargs`:Any)

Explore learning rates from start_lr to end_lr over num_it iterations in learn. If stop_div is True, stop as soon as the loss diverges.

See LRFinder for details.
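The rough sketch below shows a learning-rate range test in plain Python. It assumes learning rates spaced geometrically from start_lr to end_lr, a bias-corrected moving average of the loss, and a divergence rule of "smoothed loss above 4 times the best so far". The smoothing factor beta=0.98 and the 4× threshold are our assumptions. toy_loss is a hypothetical stand-in for one training step at a given learning rate. A real finder trains the model at each step, so its loss also depends on the steps before.

```python
import math
from typing import Callable, List, Tuple

def exp_lr_schedule(start_lr: float, end_lr: float, num_it: int) -> List[float]:
    """Learning rates spaced geometrically from start_lr to end_lr (num_it >= 2)."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def lr_range_test(loss_at: Callable[[float], float], start_lr: float = 1e-7,
                  end_lr: float = 10.0, num_it: int = 100, stop_div: bool = True,
                  beta: float = 0.98, div_thresh: float = 4.0) -> Tuple[List[float], List[float]]:
    """Record a smoothed loss per learning rate; optionally stop on divergence."""
    lrs, losses = [], []
    avg, best = 0.0, math.inf
    for i, lr in enumerate(exp_lr_schedule(start_lr, end_lr, num_it)):
        loss = loss_at(lr)                        # stands in for one training step
        avg = beta * avg + (1 - beta) * loss
        smooth = avg / (1 - beta ** (i + 1))      # bias-corrected moving average
        lrs.append(lr)
        losses.append(smooth)
        best = min(best, smooth)
        if stop_div and (math.isnan(smooth) or smooth > div_thresh * best):
            break
    return lrs, losses

def toy_loss(lr: float) -> float:
    """Hypothetical loss: flat until lr reaches 1e-2, then growing quadratically."""
    return 1.0 if lr < 1e-2 else 1.0 + (lr / 1e-2) ** 2
```

With toy_loss and stop_div=True, the loop ends well before the 100th iteration. With stop_div=False, it runs all 100 learning rates.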


to_fp16(`learn`:Learner, `loss_scale`:float=`512.0`, `flat_master`:bool=`False`) → Learner

Put learn in FP16 precision mode.

See MixedPrecision for details.
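to_fp16 takes a loss_scale argument (512.0 by default). A common reason for such a parameter in mixed-precision training is static loss scaling. Small gradients underflow to zero in FP16, so the loss is multiplied by a constant before the backward pass, and the gradients are divided by the same constant once copied to higher-precision master weights. The sketch below simulates a single gradient value with the half-precision format ('e') of Python's struct module. It is a numeric illustration under that assumption, not MixedPrecision's code.

```python
import struct

def round_to_half(x: float) -> float:
    """Round a float to the nearest IEEE 754 half-precision (FP16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def fp16_grad_roundtrip(grad: float, loss_scale: float) -> float:
    """Simulate one gradient value under static loss scaling."""
    # Scaling the loss scales every gradient by the same factor; the backward
    # pass then produces that scaled gradient in FP16.
    half_grad = round_to_half(grad * loss_scale)
    # Copy to a higher-precision master gradient (a Python float here,
    # FP32 in practice) and undo the scaling.
    return half_grad / loss_scale
```

A gradient of 1e-8 is below half of FP16's smallest subnormal (about 6e-8), so without scaling it rounds to 0.0. With loss_scale=512 it survives with a rounding error of roughly 0.1%.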



Put learn back to FP32 precision mode.


mixup(`learn`:Learner, `alpha`:float=`0.4`, `stack_x`:bool=`False`, `stack_y`:bool=`True`) → Learner

Add mixup https://arxiv.org/abs/1710.09412 to learn.

See MixUpCallback for more details.
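The minimal sketch below applies mixup to a batch of plain Python lists, following the paper linked above. Each sample is blended with a randomly chosen partner using a λ drawn from Beta(alpha, alpha), and the loss becomes the same blend of the losses against both targets. Two details are assumptions: keeping both targets plus λ as a tuple is our reading of stack_y=True, and we believe MixUpCallback applies the λ = max(λ, 1 - λ) step but have not confirmed it. mixup_batch and mixup_loss are hypothetical names.

```python
import random
from typing import Any, Callable, List, Optional, Sequence, Tuple

def mixup_batch(xs: Sequence[Sequence[float]], ys: Sequence[Any], alpha: float = 0.4,
                rng: Optional[random.Random] = None) -> Tuple[List[List[float]], List[tuple]]:
    """Blend each sample with a shuffled partner; return mixed inputs and stacked targets."""
    rng = rng if rng is not None else random.Random()
    perm = list(range(len(xs)))
    rng.shuffle(perm)                     # perm[i] is the partner of sample i
    mixed, targets = [], []
    for i, j in enumerate(perm):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1 - lam)           # assumption: keep the original sample dominant
        mixed.append([lam * a + (1 - lam) * b for a, b in zip(xs[i], xs[j])])
        targets.append((ys[i], ys[j], lam))  # both targets kept, mixed in the loss
    return mixed, targets

def mixup_loss(loss_fn: Callable[[Any, Any], float], pred: Any, target: tuple) -> float:
    """Weight the loss against each target by the same lambda used for the inputs."""
    y1, y2, lam = target
    return lam * loss_fn(pred, y1) + (1 - lam) * loss_fn(pred, y2)
```

Mixing the losses rather than one-hot labels keeps the sketch usable with any loss function that accepts a single target.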

Additional callbacks

We'll show examples below using our MNIST sample. As usual, the on_something methods are called directly by the fastai library; there's no need to call them yourself.

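from functools import partial
from fastai.vision import *
from fastai.train import *   # ShowGraph, GradientClipping, BnFreeze

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)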
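The hook pattern itself can be sketched without fastai. In the toy below, toy_fit plays the library's role and calls each callback's hooks at fixed points, so user code only overrides the hooks it needs. ToyCallback, toy_fit and StopAfter are hypothetical. We treat a True return from on_epoch_end as "stop training", but that is an assumption suggested by the bool return type of on_epoch_end in the ShowGraph reference.

```python
from typing import List

class ToyCallback:
    """Hypothetical base class: every hook is a no-op by default."""
    def on_train_begin(self, **kwargs) -> None:
        pass

    def on_epoch_end(self, epoch: int, **kwargs) -> bool:
        return False  # False means "keep training"

def toy_fit(n_epochs: int, callbacks: List[ToyCallback]) -> int:
    """Hypothetical training loop; returns the number of epochs actually run."""
    for cb in callbacks:
        cb.on_train_begin(n_epochs=n_epochs)
    for epoch in range(n_epochs):
        # ... forward pass, backward pass and optimizer step would go here ...
        stop = False
        for cb in callbacks:  # every callback sees the event, even if one asks to stop
            stop = cb.on_epoch_end(epoch=epoch, n_epochs=n_epochs) or stop
        if stop:
            return epoch + 1
    return n_epochs

class StopAfter(ToyCallback):
    """Hypothetical callback that ends training after max_epochs epochs."""
    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def on_epoch_end(self, epoch: int, **kwargs) -> bool:
        return epoch + 1 >= self.max_epochs
```

For example, toy_fit(10, [StopAfter(3)]) runs three epochs, while a plain ToyCallback lets all epochs run.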

class ShowGraph

ShowGraph(`learn`) :: LearnerCallback

Update a graph of learner stats and metrics after each epoch.

learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)

Training graph


on_epoch_end(`n_epochs`:int, `last_metrics`:MetricsList, `kwargs`) → bool

If we have last_metrics, plot them in our progress bar (pbar) graph.

class GradientClipping

GradientClipping(`learn`:Learner, `clip`:float=`0.0`) :: LearnerCallback

Gradient clipping during training.

learn = create_cnn(data, models.resnet18, metrics=accuracy,
    callback_fns=partial(GradientClipping, clip=0.1))
learn.fit(1)
Total time: 00:11

epoch  train_loss  valid_loss  accuracy
1      0.131133    0.078190    0.973013



Clip the gradient before the optimizer step.
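The pure-Python sketch below clips by global norm. It assumes clip is the maximum L2 norm of all gradients taken together, in the style of PyTorch's torch.nn.utils.clip_grad_norm_, and that a falsy clip (such as the 0.0 default in the signature) disables clipping. Gradients are lists of floats here rather than tensors, and clip_grad_norm is a hypothetical name.

```python
import math
from typing import List

def clip_grad_norm(grads: List[List[float]], clip: float) -> float:
    """Scale all gradients in place so their global L2 norm is at most `clip`.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if clip:  # a falsy clip (e.g. 0.0) leaves the gradients untouched
        coef = clip / (total_norm + 1e-6)  # small epsilon avoids division by zero
        if coef < 1:
            for grad in grads:
                for k in range(len(grad)):
                    grad[k] *= coef
    return total_norm
```

Every gradient is rescaled by the same factor, so the update keeps its direction and only its size is capped.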

class BnFreeze

BnFreeze(`learn`) :: LearnerCallback

Freeze moving average statistics in all non-trainable batchnorm layers.

For batchnorm layers where requires_grad==False, you generally don't want to update the moving average statistics, so that the model's statistics don't drift out of sync with its pretrained weights. This callback automates that freezing of statistics (internally, it calls eval() on these layers).

learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
learn.fit(1)
Total time: 00:07

epoch  train_loss  valid_loss  accuracy
1      0.132564    0.078910    0.972031



Put bn layers in eval mode just after model.train().
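The PyTorch sketch below shows the freezing logic, assuming "non-trainable" means every parameter of the batchnorm layer has requires_grad set to False. set_frozen_bn_eval and build_demo_model are hypothetical helpers, not fastai functions. The freeze has to be re-applied after model.train(), because train() switches every submodule, batchnorm layers included, back to training mode.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_frozen_bn_eval(model: nn.Module) -> None:
    """Put every batchnorm layer whose parameters are all frozen into eval mode."""
    for m in model.modules():
        if not isinstance(m, BN_TYPES):
            continue
        params = list(m.parameters(recurse=False))
        if params and not any(p.requires_grad for p in params):
            m.eval()  # running_mean / running_var are no longer updated

def build_demo_model() -> nn.Sequential:
    """Tiny model whose first batchnorm layer is frozen, as if pretrained."""
    model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.ReLU(),
                          nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4))
    for p in model[1].parameters():
        p.requires_grad_(False)
    return model
```

In a training loop, set_frozen_bn_eval would run right after each call to model.train(). The trainable batchnorm layer keeps updating its statistics, while the frozen one does not.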