The fastai library structures training around a Learner object that binds together a PyTorch model, some data, an optimizer and a loss function; the Learner then lets us launch training.
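To make the idea of "binding together" concrete, here is a toy analogue in plain Python with no PyTorch or fastai involved. ToyLearner, linear_model and mse_loss are hypothetical names, not fastai's API. The model is a 1-D linear function, the optimizer is plain SGD, and the gradient is derived by hand, so it assumes the loss is squared error.

```python
# Toy illustration only: ToyLearner, linear_model and mse_loss are hypothetical
# names and are NOT part of fastai's API.
from typing import Callable, List, Tuple

Batch = Tuple[float, float]   # (input x, target y)
Params = List[float]          # [w, b] for the 1-D linear model y = w*x + b


def linear_model(params: Params, x: float) -> float:
    """The 'model': a 1-D linear function."""
    w, b = params
    return w * x + b


def mse_loss(pred: float, target: float) -> float:
    """The 'loss function': squared error for a single example."""
    return (pred - target) ** 2


class ToyLearner:
    """Bundles data, model, optimizer setting (lr) and loss, then trains with fit()."""

    def __init__(self, data: List[Batch], model: Callable[[Params, float], float],
                 loss_func: Callable[[float, float], float], lr: float = 0.01):
        self.data, self.model, self.loss_func, self.lr = data, model, loss_func, lr
        self.params: Params = [0.0, 0.0]

    def sgd_step(self, x: float, y: float) -> float:
        """One plain SGD update on a single example; returns that example's loss."""
        pred = self.model(self.params, x)
        err = pred - y
        # Hand-written gradients, valid only for squared error on a linear model:
        # d/dw (pred - y)^2 = 2*err*x and d/db (pred - y)^2 = 2*err
        self.params[0] -= self.lr * 2 * err * x
        self.params[1] -= self.lr * 2 * err
        return self.loss_func(pred, y)

    def fit(self, epochs: int) -> List[float]:
        """Run `epochs` passes over the data; return the mean loss of each epoch."""
        history = []
        for _ in range(epochs):
            losses = [self.sgd_step(x, y) for x, y in self.data]
            history.append(sum(losses) / len(losses))
        return history
```

For example, `ToyLearner([(i / 10, 2 * (i / 10) + 1) for i in range(11)], linear_model, mse_loss, lr=0.1).fit(200)` drives the parameters toward w = 2 and b = 1. The point is the shape of the object: once everything is bundled, training is a single `fit(epochs)` call.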
basic_train contains the definition of the Learner class, along with the wrapper around the PyTorch optimizer that the library uses. It also defines the basic training loop that runs each time you call the fit function in fastai (or one of its variants). This training loop is kept to a minimum number of instructions; most of its customization happens in Callback objects.
callback contains the definition of those, as well as the CallbackHandler, which is responsible for the communication between the training loop and the Callback methods. The CallbackHandler maintains a state dictionary so that it can provide each Callback with all the information about the training loop, which makes almost any tweak you can think of easy to implement.
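The division of labor between a minimal loop, a handler holding shared state, and callbacks that read or write that state can be sketched in plain Python. Everything below (the three event names, the state keys, LossRecorder, StopAfter and the loop) is a hypothetical simplification; fastai's real CallbackHandler has many more events and different method signatures.

```python
# Hypothetical simplification of the callback pattern; not fastai's actual classes.
from typing import Any, Dict, List


class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_begin(self, state: Dict[str, Any]) -> None: pass
    def on_batch_end(self, state: Dict[str, Any]) -> None: pass
    def on_epoch_end(self, state: Dict[str, Any]) -> None: pass


class CallbackHandler:
    """Owns the shared state dict and forwards each event to every callback, in order."""
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks
        self.state: Dict[str, Any] = {"epoch": 0, "iteration": 0,
                                      "last_loss": None, "stop_training": False}

    def __call__(self, event: str) -> None:
        for cb in self.callbacks:
            getattr(cb, event)(self.state)


class LossRecorder(Callback):
    """Reads the shared state to log every batch loss."""
    def __init__(self) -> None:
        self.losses: List[float] = []

    def on_batch_end(self, state: Dict[str, Any]) -> None:
        self.losses.append(state["last_loss"])


class StopAfter(Callback):
    """Writes to the shared state to ask the loop to stop after n_iters batches."""
    def __init__(self, n_iters: int) -> None:
        self.n_iters = n_iters

    def on_batch_end(self, state: Dict[str, Any]) -> None:
        if state["iteration"] >= self.n_iters:
            state["stop_training"] = True


def fit(epochs: int, batch_losses: List[float], handler: CallbackHandler) -> None:
    """Bare-bones loop: it only updates the state and fires events.
    `batch_losses` stands in for the losses real forward/backward passes would produce."""
    handler("on_train_begin")
    for epoch in range(epochs):
        handler.state["epoch"] = epoch
        for loss in batch_losses:
            handler.state["iteration"] += 1
            handler.state["last_loss"] = loss
            handler("on_batch_end")
            if handler.state["stop_training"]:
                return
        handler("on_epoch_end")
```

With `CallbackHandler([LossRecorder(), StopAfter(4)])` and three batches per epoch, the recorder sees four losses and training stops early, without any change to `fit` itself. That is the sense in which the loop stays minimal while callbacks carry the customization.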
Each predefined Callback is then implemented in a separate module. Some deal with scheduling the hyperparameters, like callbacks.general_sched. Others allow special kinds of training, like callbacks.fp16 (mixed precision), and some, like callbacks.hooks, are useful for saving internal data.
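Hyperparameter scheduling generally boils down to splitting training into phases and interpolating a value from a start to an end within each phase. The sketch below expands such phases into one value per iteration; anneal_linear, anneal_cos and phased_schedule are hypothetical helpers illustrating the idea, not the API of callbacks.general_sched.

```python
# Hypothetical helpers illustrating phase-based scheduling; not fastai's API.
import math
from typing import Callable, List, Tuple

Anneal = Callable[[float, float, float], float]   # (start, end, pct) -> value


def anneal_linear(start: float, end: float, pct: float) -> float:
    """Linear interpolation from start (pct=0) to end (pct=1)."""
    return start + pct * (end - start)


def anneal_cos(start: float, end: float, pct: float) -> float:
    """Cosine interpolation: flat near both ends, steepest in the middle."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def phased_schedule(phases: List[Tuple[int, float, float, Anneal]]) -> List[float]:
    """Expand (n_iters, start, end, anneal_fn) phases into one value per iteration."""
    values = []
    for n_iters, start, end, fn in phases:
        for i in range(n_iters):
            values.append(fn(start, end, i / n_iters))
    return values
```

For instance, `phased_schedule([(4, 0.0, 1.0, anneal_linear), (2, 1.0, 0.0, anneal_cos)])` gives a linear ramp `0.0, 0.25, 0.5, 0.75` followed by a cosine decay starting at `1.0`. A scheduling callback would then set the optimizer's hyperparameter to the next value at each batch.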
We'll do a quick overview of the key pieces of fastai's training modules; see the separate module docs for details on each. For the training documentation we'll use the classic MNIST dataset, cut down to just 3's and 7's. To minimize the boilerplate in our docs, we've defined a function, untar_data, that grabs the data from URLs.MNIST_SAMPLE, automatically downloading and unzipping it if that hasn't already been done; we then put it in an ImageDataBunch:
from fastai.vision import *
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
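ImageDataBunch.from_folder takes its labels from the directory layout: each split folder (assumed here to be the default train/ and valid/) contains one subfolder per class, such as 3/ and 7/. The plain-Python sketch below shows only that labeling convention. label_from_folders is a hypothetical helper; it loads no images and is not how fastai implements the method.

```python
# Hypothetical helper showing folder-name labeling; it does not load images.
from pathlib import Path
from typing import List, Tuple


def label_from_folders(root: Path, split: str) -> List[Tuple[Path, str]]:
    """Return (file, label) pairs for every file under root/split/<label>/."""
    items = []
    for class_dir in sorted(p for p in (root / split).iterdir() if p.is_dir()):
        for f in sorted(class_dir.iterdir()):
            if f.is_file():
                items.append((f, class_dir.name))   # the folder name is the label
    return items
```

A file at `train/7/b.png` is therefore labeled `"7"`. This is also why a folder-based loader needs no separate label file when the data is organized this way.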
model = simple_cnn((3,16,16,2))
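The tuple passed to simple_cnn reads as a list of channel counts: 3 input channels (RGB), two layers of 16 channels, then 2 outputs, one per class (3 and 7). The arithmetic below traces spatial sizes for a 28×28 MNIST image under an assumption we have not verified here: each layer is a 3×3 convolution with stride 2 and padding 1, and a global pooling step at the end reduces the last feature map to one value per channel. conv_out_size and trace_shapes are hypothetical helpers.

```python
# Shape arithmetic under assumed layer settings (kernel 3, stride 2, padding 1).
from typing import List, Sequence, Tuple


def conv_out_size(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(actns: Sequence[int], size: int) -> List[Tuple[int, int, int]]:
    """(channels, height, width) after each conv layer, for a square input."""
    shapes = [(actns[0], size, size)]
    for c_out in actns[1:]:
        size = conv_out_size(size)
        shapes.append((c_out, size, size))
    return shapes
```

Under these assumptions, `trace_shapes((3, 16, 16, 2), 28)` goes from `(3, 28, 28)` to `(16, 14, 14)`, `(16, 7, 7)` and finally `(2, 4, 4)`. Pooling that last map gives the 2 class scores.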
The most important object for training models is
Learner, which needs to know, at minimum, what data to train with and what model to train.
learn = Learner(data, model)
That's enough to train a model, which is done using fit. If you have a CUDA-capable GPU, it will be used automatically. You have to say how many epochs to train for; for example, learn.fit(1) trains for a single epoch.
To see how our training is going, we can ask it to report various metrics after each epoch. You can pass them to the constructor or set them later, e.g. learn.metrics = [accuracy]. Note that metrics are always calculated on the validation set.
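A metric is simply a function of the model's outputs and the targets. The sketch below computes accuracy in plain Python, as the fraction of examples whose highest-scoring class matches the target. This `accuracy` is a hypothetical stand-in that works on lists; fastai's own accuracy works on tensors.

```python
# Plain-Python stand-in for an accuracy metric; fastai's version works on tensors.
from typing import Sequence


def accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose argmax (predicted class) equals the target class."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)
```

With outputs `[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]` and targets `[1, 0, 0, 0]`, the predictions are `[1, 0, 1, 0]`, so accuracy is 0.75. Because the metric is evaluated on held-out validation data, it reflects generalization rather than how well the model fits the training set.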
You can use
callbacks to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method.
cb = OneCycleScheduler(learn, lr_max=0.01)
learn.fit(1, callbacks=cb)
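Leslie Smith's 1cycle policy raises the learning rate from a low value up to lr_max, then anneals it to well below where it started, while momentum moves in the opposite direction. The sketch below computes such a per-iteration (learning rate, momentum) schedule in plain Python. one_cycle is a hypothetical helper, not OneCycleScheduler's implementation. The cosine shape of both phases and the defaults (pct_start=0.3, div=25, final_div=1e4, moms=(0.95, 0.85)) are our assumptions.

```python
# Hypothetical 1cycle-style schedule; phase shapes and defaults are assumptions.
import math
from typing import List, Tuple


def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def one_cycle(n_iters: int, lr_max: float, pct_start: float = 0.3,
              div: float = 25.0, final_div: float = 1e4,
              moms: Tuple[float, float] = (0.95, 0.85)) -> List[Tuple[float, float]]:
    """Return (lr, momentum) for each iteration of a 1cycle-style schedule."""
    n_up = max(1, int(n_iters * pct_start))   # warm-up iterations
    n_down = n_iters - n_up                   # annealing iterations
    lr_start, lr_end = lr_max / div, lr_max / (div * final_div)
    sched = []
    for i in range(n_up):                     # phase 1: lr rises, momentum falls
        p = i / n_up
        sched.append((cos_anneal(lr_start, lr_max, p), cos_anneal(moms[0], moms[1], p)))
    for i in range(n_down):                   # phase 2: lr decays, momentum recovers
        p = i / n_down
        sched.append((cos_anneal(lr_max, lr_end, p), cos_anneal(moms[1], moms[0], p)))
    return sched
```

For 100 iterations with `lr_max=0.01`, the learning rate starts at 0.0004, peaks at 0.01 on iteration 30, and ends far below its starting value. At the peak, momentum is at its lowest (0.85).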
The Recorder callback is automatically added for you, and you can use it to see what happened during training, e.g. by calling learn.recorder.plot_lr() to plot the learning rate schedule.
Note that if you're training a model for one of our supported applications, there's a lot of help available to you in the application modules. For instance, let's use create_cnn (from vision) to quickly fine-tune a pre-trained ImageNet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!).
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(1)