Overview of fastai training modules, including Learner, metrics, and callbacks

Training modules overview

The fastai library structures training around a Learner object that binds together a PyTorch model, some data, an optimizer and a loss function; together, these let us launch training.

basic_train contains the definition of this Learner class, along with the wrapper around the PyTorch optimizer that the library uses. It defines the basic training loop that is used each time you call the fit function in fastai (or one of its variants). This training loop is kept to a minimal number of instructions, and most of its customization happens in Callback objects.
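To make the idea of a minimal loop plus callback hooks concrete, here is a pure-Python sketch that fits a single weight w in y = w * x with SGD. It is not fastai's code: the fit function, the toy model and the LossLogger class are hypothetical, and only the hook names (on_train_begin, on_epoch_begin, on_batch_end, on_epoch_end) are borrowed from fastai's Callback API. The loop itself never changes; extra behavior comes from the callbacks passed in.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_begin(self): pass
    def on_epoch_begin(self, epoch): pass
    def on_batch_end(self, loss): pass
    def on_epoch_end(self, epoch): pass


def fit(epochs, w, batches, lr, callbacks=()):
    """Fit y = w * x with plain SGD on squared error, calling hooks along the way."""
    for cb in callbacks: cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks: cb.on_epoch_begin(epoch)
        for x, y in batches:
            pred = w * x
            loss = (pred - y) ** 2
            grad = 2 * (pred - y) * x   # derivative of the loss with respect to w
            w -= lr * grad              # the "optimizer step"
            for cb in callbacks: cb.on_batch_end(loss)
        for cb in callbacks: cb.on_epoch_end(epoch)
    return w


class LossLogger(Callback):
    """Hypothetical callback: records every batch loss and prints one per epoch."""
    def __init__(self):
        self.losses = []

    def on_batch_end(self, loss):
        self.losses.append(loss)

    def on_epoch_end(self, epoch):
        print(f"epoch {epoch}: last batch loss {self.losses[-1]:.6f}")


# Usage: the data follow y = 2x, so w should converge to about 2.
logger = LossLogger()
w = fit(10, 0.0, [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)], lr=0.05, callbacks=[logger])
```

Logging, scheduling or early stopping can all be added this way without touching fit.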

callback contains the definition of these, as well as the CallbackHandler, which is responsible for communication between the training loop and the Callback functions. It maintains a state dictionary so that it can provide each Callback with all the information about the training loop, which makes it easy to implement almost any tweak you can think of.
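The hypothetical sketch below shows the shape of that idea in plain Python: a handler keeps one state dict, passes it as keyword arguments to each callback, and merges any dict a callback returns back into the state. The class names and state keys are illustrative assumptions, not fastai's exact internals.

```python
class CallbackHandler:
    """Hypothetical sketch: shares one state dict among all callbacks."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.state = {"iteration": 0, "last_loss": None, "stop_training": False}

    def __call__(self, event, **updates):
        """Update the state from the training loop, then run `event` on each callback."""
        self.state.update(updates)
        for cb in self.callbacks:
            hook = getattr(cb, event, None)
            if hook is None:
                continue                     # this callback does not handle the event
            changes = hook(**self.state)
            if changes:                      # a returned dict modifies the shared state
                self.state.update(changes)
        return self.state


class StopWhenLossBelow:
    """Hypothetical callback: asks the loop to stop once the loss is small enough."""
    def __init__(self, threshold):
        self.threshold = threshold

    def on_batch_end(self, last_loss, **kwargs):
        if last_loss < self.threshold:
            return {"stop_training": True}
        return None


# Usage: the loop only reports losses; the callback decides when to stop.
handler = CallbackHandler([StopWhenLossBelow(0.1)])
seen = []
for i, loss in enumerate([0.5, 0.3, 0.08, 0.05]):
    state = handler("on_batch_end", iteration=i, last_loss=loss)
    seen.append(loss)
    if state["stop_training"]:
        break
```

Because every callback sees the same dict, one callback can react to values that the loop or another callback has set.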

In callbacks, each Callback is then implemented in its own module. Some deal with scheduling hyperparameters, such as callbacks.one_cycle, callbacks.lr_finder or callbacks.general_sched. Others enable special kinds of training, such as callbacks.fp16 (mixed precision) or callbacks.rnn. The Recorder and callbacks.hooks are useful for saving internal data.

train then provides useful helper functions built on those callbacks. Lastly, metrics contains all the functions you might want to call to evaluate your results.
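As an example of what such a function computes, here is a plain-Python sketch of classification accuracy: take the highest-scoring class for each prediction and count how often it matches the target. fastai's own accuracy works on tensors; this list-based version is only an illustration.

```python
def accuracy(preds, targets):
    """Fraction of predictions whose highest-scoring class equals the target index."""
    if len(preds) != len(targets) or not targets:
        raise ValueError("need the same, non-zero number of predictions and targets")
    correct = 0
    for scores, target in zip(preds, targets):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += predicted == target
    return correct / len(targets)


# Usage: two classes (say, 3 vs 7); the third prediction is wrong.
preds = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
targets = [0, 1, 1, 1]
acc = accuracy(preds, targets)   # 0.75
```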

Walk-through of key functionality

We'll do a quick overview of the key pieces of fastai's training modules. See the separate module docs for details on each. We'll use the classic MNIST dataset for the training documentation, cut down to just 3's and 7's. To minimize the boilerplate in our docs, we've defined a function, untar_data, that grabs the data from URLs.MNIST_SAMPLE, automatically downloading and unzipping it if that hasn't already been done. We then put the data in an ImageDataBunch.

from fastai.vision import *
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)

Basic training with Learner

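We can create a minimal CNN using simple_cnn (see models for details on creating models). The tuple gives the number of channels at each stage: 3 input channels, 16 after each of the first two convolutions, and 2 outputs, one per class: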

model = simple_cnn((3,16,16,2))
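To see what (3,16,16,2) produces, the sketch below tracks the feature-map shape through each convolution. It assumes, based on our reading of simple_cnn's defaults, 3x3 kernels with stride 2 and padding 1, and 28x28 MNIST images; the two helper functions are hypothetical and only apply the standard convolution output-size formula. The final 2-channel map is then pooled down to one value per class.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Standard convolution output size: floor((n + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def simple_cnn_shapes(actns, size):
    """(channels, height, width) before and after each conv layer, for square inputs."""
    shapes = [(actns[0], size, size)]
    for out_channels in actns[1:]:
        size = conv_out(size)
        shapes.append((out_channels, size, size))
    return shapes


shapes = simple_cnn_shapes((3, 16, 16, 2), 28)
# [(3, 28, 28), (16, 14, 14), (16, 7, 7), (2, 4, 4)]
```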

The most important object for training models is Learner, which needs to know, at minimum, what data to train with and what model to train.

learn = Learner(data, model)

That's enough to train a model, which is done using fit. If you have a CUDA-capable GPU, it will be used automatically. You have to specify how many epochs to train for:

learn.fit(1)
Total time: 00:02
epoch  train_loss  valid_loss
1      0.141339    0.121598    (00:02)
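As far as we understand fastai v1, the train_loss column is not a plain average of the epoch's batch losses but an exponentially weighted moving average with bias correction (we believe the default is beta=0.98). The hypothetical class below sketches that computation; dividing by 1 - beta**n removes the pull toward the zero starting value during the first batches.

```python
class SmoothedLoss:
    """Hypothetical helper: bias-corrected exponential moving average of batch losses."""
    def __init__(self, beta=0.98):
        self.beta = beta
        self.n = 0
        self.mov_avg = 0.0
        self.smooth = None

    def add_value(self, val):
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        # mov_avg starts at 0, so early values are too small; dividing by
        # (1 - beta**n) corrects for that bias.
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
        return self.smooth


# Usage: a sudden drop in the batch loss only moves the smoothed value part of the way.
s = SmoothedLoss()
first = s.add_value(1.0)    # equals the first value after bias correction
second = s.add_value(0.0)   # about 0.495, not 0.0
```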

Viewing metrics

To see how our training is going, we can ask it to report various metrics after each epoch. You can pass them to the constructor, or set them later. Note that metrics are always calculated on the validation set.

learn.metrics = [accuracy]
learn.fit(1)
Total time: 00:02
epoch  train_loss  valid_loss  accuracy
1      0.109016    0.091778    0.969578  (00:02)
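Because the validation set is processed in batches, and the last batch may be smaller than the others, a metric has to be combined across batches with some care. As far as we understand, fastai v1 averages each metric weighted by batch size. The hypothetical helpers below sketch that approach, so a mean-type metric such as accuracy comes out the same as if it had been computed over the whole validation set at once.

```python
def average_metric(metric, batches):
    """Average a per-batch metric over all batches, weighting each batch by its size."""
    total, count = 0.0, 0
    for preds, targets in batches:
        bs = len(targets)
        total += metric(preds, targets) * bs   # undo the per-batch mean
        count += bs
    return total / count


def match_rate(preds, targets):
    """Hypothetical metric: fraction of predicted class indices equal to the targets."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Usage: a 4-item batch (3 correct) and a 2-item batch (2 correct).
valid_batches = [([0, 1, 1, 0], [0, 1, 0, 0]),
                 ([1, 1], [1, 1])]
result = average_metric(match_rate, valid_batches)   # 5/6, not (0.75 + 1.0) / 2
```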

Extending training with callbacks

You can use callbacks to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method.

cb = OneCycleScheduler(learn, lr_max=0.01)
learn.fit(1, callbacks=cb)
Total time: 00:02
epoch  train_loss  valid_loss  accuracy
1      0.091946    0.068201    0.974975  (00:02)
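The sketch below shows the general shape of a 1cycle schedule: the learning rate warms up from lr_max / div_factor to lr_max, then anneals down to a much smaller value, while momentum moves in the opposite direction. The defaults (div_factor=25, pct_start=0.3, moms=(0.95, 0.85), a final rate of lr_max / (div_factor * 1e4)) and the cosine interpolation are assumptions modeled on what we believe fastai v1's OneCycleScheduler uses; the functions themselves are hypothetical, so check the callbacks.one_cycle docs for the real behavior.

```python
import math


def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle(step, total_steps, lr_max, div_factor=25.0, pct_start=0.3,
              final_div=None, moms=(0.95, 0.85)):
    """Return (lr, momentum) at `step` for a 1cycle-style schedule (sketch)."""
    if final_div is None:
        final_div = div_factor * 1e4
    n_warm = round(total_steps * pct_start)          # number of warm-up steps
    if step < n_warm:
        pct = step / n_warm
        lr = annealing_cos(lr_max / div_factor, lr_max, pct)
        mom = annealing_cos(moms[0], moms[1], pct)   # momentum goes down while lr goes up
    else:
        pct = (step - n_warm) / max(1, total_steps - n_warm)
        lr = annealing_cos(lr_max, lr_max / final_div, pct)
        mom = annealing_cos(moms[1], moms[0], pct)   # and back up while lr comes down
    return lr, mom


# Usage: a 100-step schedule peaking at lr_max=0.01.
schedule = [one_cycle(i, 100, lr_max=0.01) for i in range(100)]
peak = max(range(100), key=lambda i: schedule[i][0])
```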

The Recorder callback is automatically added for you, and you can use it to see what happened in your training, e.g.:

learn.recorder.plot_lr(show_moms=True)
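plot_lr draws values that the Recorder has stored batch by batch during training. The hypothetical class below sketches only that bookkeeping (it does no plotting and is not fastai's Recorder API); plotting lists like these against the batch index gives the kind of learning-rate and momentum curves shown by plot_lr(show_moms=True). The values fed in are illustrative, not real training results.

```python
class MiniRecorder:
    """Hypothetical, stripped-down recorder: one entry per training batch."""
    def __init__(self):
        self.lrs, self.moms, self.losses = [], [], []

    def on_batch_end(self, lr, mom, loss):
        self.lrs.append(lr)
        self.moms.append(mom)
        self.losses.append(loss)

    def peak_lr(self):
        """Return (batch index, lr, momentum) at the highest learning rate recorded."""
        i = max(range(len(self.lrs)), key=self.lrs.__getitem__)
        return i, self.lrs[i], self.moms[i]


# Usage with illustrative hyperparameter values.
rec = MiniRecorder()
for lr, mom, loss in [(0.001, 0.95, 0.70), (0.004, 0.90, 0.40),
                      (0.010, 0.85, 0.25), (0.006, 0.88, 0.18),
                      (0.002, 0.93, 0.15)]:
    rec.on_batch_end(lr, mom, loss)
```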

Extending Learner with train

Many of the callbacks can be used more easily by taking advantage of the Learner extensions in train. For instance, instead of creating OneCycleScheduler manually as above, you can simply call Learner.fit_one_cycle:

learn.fit_one_cycle(1)
Total time: 00:02
epoch  train loss  valid loss  accuracy
0      0.044362    0.045060    0.984298  (00:02)
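Methods like fit_one_cycle are thin wrappers: they create the scheduling callback and pass it to fit. As far as we can tell, fastai v1's train module writes these helpers as plain functions that take a Learner and then attaches them to the class. The sketch below shows that pattern with hypothetical stand-ins (ToyLearner, OneCycleStub, and an arbitrary default max_lr of 3e-3); none of it is fastai code.

```python
class ToyLearner:
    """Hypothetical stand-in for Learner: fit just records how it was called."""
    def __init__(self):
        self.calls = []

    def fit(self, epochs, lr, callbacks=None):
        self.calls.append((epochs, lr, list(callbacks or [])))


class OneCycleStub:
    """Hypothetical placeholder for a 1cycle scheduling callback."""
    def __init__(self, learn, lr_max):
        self.learn, self.lr_max = learn, lr_max


def fit_one_cycle(learn, cyc_len, max_lr=3e-3, callbacks=None):
    """Build the scheduler callback, add it to any user callbacks, then call fit."""
    callbacks = list(callbacks or [])
    callbacks.append(OneCycleStub(learn, max_lr))
    learn.fit(cyc_len, max_lr, callbacks=callbacks)


# Attach the function to the class, so it can be called as a method.
ToyLearner.fit_one_cycle = fit_one_cycle

learn = ToyLearner()
learn.fit_one_cycle(1)
```

Keeping such helpers as plain functions means the core class stays small, while users still get the convenient learn.fit_one_cycle(...) call.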

Applications

Note that if you're training a model for one of our supported applications, there's a lot of help available to you in the application modules.

For instance, let's use create_cnn (from vision) to quickly fine-tune a pre-trained ImageNet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!).

learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)
Total time: 00:09
epoch  train loss  valid loss  accuracy
0      0.093473    0.068315    0.976938  (00:09)