Overview of fastai training modules, including Learner, metrics, and callbacks

Training modules overview

The fastai library structures its training process around the Learner class, which binds together a PyTorch model, the data, an optimizer, and a loss function; the Learner object then lets us launch training.

basic_train defines this Learner class, along with the wrapper around the PyTorch optimizer that the library uses. It defines the basic training loop that is used each time you call the fit method (or one of its variants) in fastai. This training loop is very bare-bones and has very few lines of code; you can customize it by supplying an optional Callback argument to the fit method.

callback defines the Callback class and the CallbackHandler class, which is responsible for the communication between the training loop and the Callback's methods. The CallbackHandler maintains a state dictionary that provides each Callback object with all the information about the training loop it belongs to, putting any imaginable tweak of the training loop within your reach.
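To make the event-dispatch idea concrete, here is a minimal sketch in plain Python (the class and method names are illustrative assumptions, not fastai's actual implementation): a handler relays named events to every registered callback, and all callbacks share one state dictionary.

```python
# Illustrative sketch (NOT fastai's actual code): a callback handler
# that lets callbacks observe a bare-bones training loop via shared state.

class Callback:
    def on_epoch_begin(self, state): pass
    def on_batch_end(self, state): pass
    def on_epoch_end(self, state): pass

class CallbackHandler:
    """Relays training-loop events to every registered callback,
    sharing a single state dictionary."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.state = {}
    def __call__(self, event):
        for cb in self.callbacks:
            getattr(cb, event)(self.state)

class LossTracker(Callback):
    # Records every batch loss it sees via the shared state dict.
    def __init__(self): self.losses = []
    def on_batch_end(self, state): self.losses.append(state['loss'])

def fit(n_epochs, batches, handler):
    # A toy loop: "training" on a batch just means reading its loss.
    for epoch in range(n_epochs):
        handler.state['epoch'] = epoch
        handler('on_epoch_begin')
        for loss in batches:
            handler.state['loss'] = loss
            handler('on_batch_end')
        handler('on_epoch_end')

tracker = LossTracker()
fit(2, [0.9, 0.5, 0.3], CallbackHandler([tracker]))
print(tracker.losses)  # every batch loss, seen twice (two epochs)
```

Because every callback reads and writes the same state dictionary, a callback can also influence the loop (e.g. change the learning rate or skip a step), which is how fastai's schedulers work.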

callbacks implements each predefined Callback class of the fastai library in a separate module. Some modules deal with scheduling the hyperparameters, like callbacks.one_cycle, callbacks.lr_finder and callbacks.general_sched. Others allow special kinds of training, like callbacks.fp16 (mixed precision) and callbacks.rnn. The Recorder and callbacks.hooks are useful for saving internal data generated in the training loop.

train then uses these callbacks to implement useful helper functions. Lastly, metrics contains all the functions and classes you might want to use to evaluate your training results; simpler metrics are implemented as functions, while more complicated ones are subclasses of Callback. For more details on implementing metrics as Callbacks, please refer to creating your own metrics.
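The split between function metrics and Callback metrics can be sketched in plain Python (the names below are illustrative assumptions, not the real fastai API): a function metric is applied per batch, while a Callback-style metric accumulates a weighted running average across an epoch.

```python
# Illustrative sketch (plain Python, NOT the real fastai API): a simple
# metric as a function, and a running-average metric in Callback style.

def accuracy(preds, targets):
    # Fraction of predictions that match their target.
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

class AverageMetric:
    # Accumulates a batch-level metric over an epoch, weighted by batch size,
    # so that unevenly sized batches are averaged correctly.
    def __init__(self, func):
        self.func = func
    def on_epoch_begin(self):
        self.total, self.count = 0.0, 0
    def on_batch_end(self, preds, targets):
        self.total += self.func(preds, targets) * len(targets)
        self.count += len(targets)
    def on_epoch_end(self):
        return self.total / self.count

metric = AverageMetric(accuracy)
metric.on_epoch_begin()
metric.on_batch_end([3, 7, 7], [3, 7, 3])   # 2/3 correct
metric.on_batch_end([3], [3])               # 1/1 correct
print(metric.on_epoch_end())                # (2 + 1) / 4 = 0.75
```

Stateful accumulation like this is exactly why more complicated metrics need the Callback machinery rather than a plain per-batch function.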

Walk-through of key functionalities

We'll do a quick overview of the key pieces of fastai's training modules. See the separate module docs for details on each.


Import required modules and prepare data:

from fastai.vision import *

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

URLs.MNIST_SAMPLE is a small subset of the classic MNIST dataset, containing images of just 3's and 7's, for demo and documentation purposes. Common datasets can be downloaded with untar_data, which we use here to create an ImageDataBunch object.

Basic training with Learner

We can create a minimal CNN using simple_cnn (see models for details on creating models):

model = simple_cnn((3,16,16,2))

The Learner class plays a central role in training models; when you create a Learner you need to specify at the very minimum the data and model to use.

learn = Learner(data, model)

These are enough to create a Learner object and then train a model using its fit method. If you have a CUDA-enabled GPU, it will be used automatically. To call the fit method, you have to at least specify how many epochs to train for:

learn.fit(1)

Total time: 00:03

epoch train_loss valid_loss
1 0.124981 0.097195

Viewing metrics

To see how our training is going, we can request that it report various metrics after each epoch. You can pass metrics to the constructor, or set them later. Note that metrics are always calculated on the validation set.

learn = Learner(data, model, metrics=accuracy)
learn.fit(1)

Total time: 00:02

epoch train_loss valid_loss accuracy
1 0.081563 0.062798 0.976938

Extending training with callbacks

You can use callbacks to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method.

cb = OneCycleScheduler(learn, lr_max=0.01)
learn.fit(1, callbacks=cb)

Total time: 00:02

epoch train_loss valid_loss accuracy
1 0.055955 0.045469 0.984298
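As a rough illustration of what the 1cycle policy does to the learning rate, here is a short sketch in plain Python. The exact numbers (the default divisor, warm-up fraction, and cosine interpolation) are assumptions for illustration, not fastai's exact schedule:

```python
import math

# Illustrative sketch of the 1cycle idea (parameter values are assumptions,
# not fastai's exact schedule): the learning rate climbs from lr_max/div
# up to lr_max, then anneals back down over the rest of training.

def one_cycle_lr(pct, lr_max=0.01, div=25.0, warmup=0.3):
    """Learning rate at `pct` (0..1) of the way through training."""
    lr_min = lr_max / div
    if pct < warmup:                       # ascending phase
        t = pct / warmup
    else:                                  # descending phase
        t = 1 - (pct - warmup) / (1 - warmup)
    # Cosine interpolation between lr_min and lr_max.
    return lr_min + (lr_max - lr_min) * (1 - math.cos(math.pi * t)) / 2

print(one_cycle_lr(0.0))   # start of training: lr_max / div
print(one_cycle_lr(0.3))   # peak of the cycle: lr_max
print(one_cycle_lr(1.0))   # end of training: back down to lr_max / div
```

The large learning rates in the middle of the cycle act as a regularizer, which is part of why this schedule often trains faster than a flat learning rate.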

The Recorder callback is automatically added for you, and you can use it to see what happened in your training, e.g.:

learn.recorder.plot_lr()


Extending Learner with train

Many of the callbacks can be used more easily by taking advantage of the Learner extensions in train. For instance, instead of creating a OneCycleScheduler manually as above, you can simply call Learner.fit_one_cycle:

learn.fit_one_cycle(1)

Total time: 00:03

epoch train_loss valid_loss accuracy
1 0.040535 0.035062 0.986752


Note that if you're training a model for one of our supported applications, there's a lot of help available to you in the application modules:

For instance, let's use cnn_learner (from vision) to quickly fine-tune a pre-trained Imagenet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!).

learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)

Total time: 02:06

epoch train_loss valid_loss accuracy
1 0.163659 0.112767 0.958783