Extensions to Learner that make common Callbacks easy to use

Additional training functions

train provides a number of extension methods that are added to Learner (see below for a list and details), along with three simple callbacks: ShowGraph, GradientClipping, and BnFreeze.

Learner extension methods

These methods are automatically added to all Learner objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created.

fit_one_cycle[source][test]

fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, final_div:float=None, wd:float=None, callbacks:Optional[Collection[Callback]]=None, tot_epochs:int=None, start_epoch:int=None)

Tests found for fit_one_cycle:

  • pytest -sv tests/test_train.py::test_fit_one_cycle [source]

Some other tests where fit_one_cycle is used:

  • pytest -sv tests/test_tabular_train.py::test_empty_cont [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split [source]

To run tests please refer to this guide.

Fit a model following the 1cycle policy.
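
For example, a minimal sketch (assuming the MNIST `data` bunch created in the Additional callbacks section below):

learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(4, max_lr=1e-3)  # 4 epochs following the 1cycle policy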

one_cycle_scheduler[source][test]

one_cycle_scheduler(lr_max:float, **kwargs:Any) → OneCycleScheduler

No tests found for one_cycle_scheduler. To contribute a test please refer to this guide and this discussion.

Instantiate a OneCycleScheduler with lr_max.

See OneCycleScheduler for details.
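
A minimal sketch, assuming the returned scheduler is attached as a callback_fn (the `data` bunch and choice of resnet18 are assumptions):

learn = cnn_learner(data, models.resnet18, metrics=accuracy,
    callback_fns=one_cycle_scheduler(lr_max=1e-2, pct_start=0.3))
learn.fit(4)  # lr and momentum now follow the 1cycle schedule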

lr_find[source][test]

lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None)

Tests found for lr_find:

  • pytest -sv tests/test_train.py::test_lr_find [source]
  • pytest -sv tests/test_vision_train.py::test_lrfind [source]

To run tests please refer to this guide.

Explore lr from start_lr to end_lr over num_it iterations in learn. If stop_div, stops when loss diverges.

See LRFinder for details.
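
A minimal sketch (assuming the MNIST `data` bunch created in the Additional callbacks section below):

learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.lr_find()
learn.recorder.plot()  # plot loss against learning rate to pick a value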

to_fp16[source][test]

to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=False, clip:float=None, flat_master:bool=False, max_scale:float=16777216) → Learner

No tests found for to_fp16. To contribute a test please refer to this guide and this discussion.

Put learn in FP16 precision mode.

See MixedPrecision for details.
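
A minimal sketch (assuming a CUDA-capable GPU and the `data` bunch created below):

learn = cnn_learner(data, models.resnet18, metrics=accuracy).to_fp16()
learn.fit_one_cycle(1)  # forward/backward passes run in half precision
learn.to_fp32()         # switch back to full precision, e.g. before exporting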

to_fp32[source][test]

to_fp32(learn:Learner)

No tests found for to_fp32. To contribute a test please refer to this guide and this discussion.

Put learn back to FP32 precision mode.

mixup[source][test]

mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) → Learner

No tests found for mixup. To contribute a test please refer to this guide and this discussion.

Add mixup (https://arxiv.org/abs/1710.09412) to learn. See MixUpCallback for more details.
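
A minimal sketch (the `data` bunch and model choice are assumptions):

learn = cnn_learner(data, models.resnet18, metrics=accuracy).mixup(alpha=0.4)
learn.fit(1)  # each batch is now a convex blend of two samples and their targets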

class ClassificationInterpretation[source][test]

ClassificationInterpretation(learn:Learner, probs:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=<DatasetType.Valid: 2>)

Tests found for ClassificationInterpretation:

  • pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation [source]

Some other tests where ClassificationInterpretation is used:

  • pytest -sv tests/test_tabular_train.py::test_confusion_tabular [source]
  • pytest -sv tests/test_vision_train.py::test_interp [source]

To run tests please refer to this guide.

Interpretation methods for classification models.
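
A minimal sketch (assuming a trained learn such as the ones in the examples below):

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()   # confusion matrix on the validation set
interp.most_confused(min_val=2)  # worst (actual, predicted) pairs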

Additional callbacks

We'll show examples below using our MNIST sample. As usual, the on_something methods are called directly by the fastai library; you don't need to call them yourself.

from fastai.vision import *

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

class ShowGraph[source][test]

ShowGraph(learn) :: LearnerCallback

No tests found for ShowGraph. To contribute a test please refer to this guide and this discussion.

Update a graph of learner stats and metrics after each epoch.

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)
learn.fit(3)

[Training graph: loss and metrics, updated after each epoch]

on_epoch_end[source][test]

on_epoch_end(n_epochs:int, last_metrics:MetricsList, **kwargs) → bool

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

If we have last_metrics, plot them in our pbar graph.

class GradientClipping[source][test]

GradientClipping(learn:Learner, clip:float=0.0) :: LearnerCallback

No tests found for GradientClipping. To contribute a test please refer to this guide and this discussion.

Gradient clipping during training.

learn = cnn_learner(data, models.resnet18, metrics=accuracy,
    callback_fns=partial(GradientClipping, clip=0.1))
learn.fit(1)
Total time: 00:11

epoch  train_loss  valid_loss  accuracy
1      0.131133    0.078190    0.973013

on_backward_end[source][test]

on_backward_end(**kwargs)

No tests found for on_backward_end. To contribute a test please refer to this guide and this discussion.

Clip the gradient before the optimizer step.
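
This amounts to gradient-norm clipping; a minimal sketch of the same idea, using PyTorch's built-in utility (not necessarily the library's exact code):

import torch.nn as nn
# rescale all gradients so their total norm is at most the clip value (here 0.1)
nn.utils.clip_grad_norm_(learn.model.parameters(), 0.1)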

class BnFreeze[source][test]

BnFreeze(learn) :: LearnerCallback

No tests found for BnFreeze. To contribute a test please refer to this guide and this discussion.

Freeze moving average statistics in all non-trainable batchnorm layers.

For batchnorm layers where requires_grad==False, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls eval on these layers).
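
Roughly, the callback walks the model and switches frozen batchnorm layers to eval mode; a simplified sketch (not necessarily the library's exact code):

import torch.nn as nn

def set_bn_eval(m):
    # recursively put batchnorm layers with no trainable params in eval mode
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for child in m.children():
        if isinstance(child, bn_types) and not any(p.requires_grad for p in child.parameters()):
            child.eval()  # stops running mean/var updates for this layer
        set_bn_eval(child)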

learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)
learn.fit(1)
Total time: 00:07

epoch  train_loss  valid_loss  accuracy
1      0.132564    0.078910    0.972031

on_epoch_begin[source][test]

on_epoch_begin(**kwargs:Any)

No tests found for on_epoch_begin. To contribute a test please refer to this guide and this discussion.

Put bn layers in eval mode just after model.train().