Learner class and training loop

Basic training functionality

basic_train wraps together the data (in a DataBunch object) with a PyTorch model to define a Learner object. This module also defines the basic training loop used by the fit method. The Learner object is the entry point for most of the Callback objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the train module, notably:

  • Learner.lr_find launches an LR range test that helps you select a good learning rate.
  • Learner.fit_one_cycle trains using the 1cycle policy, which helps you train your model faster.
  • Learner.to_fp16 converts your model to half precision so you can train in mixed precision.

class Learner

Learner(data:DataBunch, model:Module, opt_func:Callable='Adam', loss_func:Callable=None, metrics:Collection[Callable]=None, true_wd:bool=True, bn_wd:bool=True, wd:Floats=0.01, train_bn:bool=True, path:str=None, model_dir:str='models', callback_fns:Collection[Callable]=None, callbacks:Collection[Callback]=<factory>, layer_groups:ModuleList=None)

Trainer for model using data to minimize loss_func with optimizer opt_func.

The main purpose of Learner is to train model using Learner.fit. After every epoch, all metrics will be printed and also made available to callbacks.

The default weight decay will be wd, which will be handled using the method from Fixing Weight Decay Regularization in Adam if true_wd is set (otherwise it's L2 regularization). If bn_wd is False, then weight decay will be removed from batchnorm layers, as recommended in Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. If train_bn, batchnorm layer learnable params are trained even for frozen layer groups.
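To see why the choice matters with Adam, here is a minimal sketch of a single Adam step on one scalar parameter, written in plain Python rather than with fastai or PyTorch. With L2 regularization the `wd * p` term is added to the gradient and then rescaled by Adam's running statistics; with decoupled ("true") weight decay the parameter is shrunk directly. The helper names `adam_step` and `fresh_state` and the defaults `b1=0.9`, `b2=0.999`, `eps=1e-8` (the usual Adam defaults) are assumptions of this sketch, not fastai's optimizer code.

```python
import math

def adam_step(p, g, state, lr, wd, decoupled, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update of scalar parameter p given gradient g (illustrative sketch)."""
    if decoupled:
        # Decoupled ("true") weight decay: shrink the weight directly,
        # outside of Adam's adaptive rescaling.
        p = p - lr * wd * p
    else:
        # L2 regularization: add wd * p to the gradient, so it gets
        # rescaled by Adam's running statistics like any other gradient.
        g = g + wd * p
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * g
    state["v"] = b2 * state["v"] + (1 - b2) * g * g
    m_hat = state["m"] / (1 - b1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - b2 ** state["t"])  # bias-corrected second moment
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)

def fresh_state():
    """Empty optimizer state for one parameter."""
    return {"t": 0, "m": 0.0, "v": 0.0}

# A weight of 1.0 whose loss gradient is zero, with lr=0.01 and wd=0.01
p_l2 = adam_step(1.0, 0.0, fresh_state(), lr=0.01, wd=0.01, decoupled=False)
p_true = adam_step(1.0, 0.0, fresh_state(), lr=0.01, wd=0.01, decoupled=True)
```

With a zero loss gradient, the L2 version still takes a full-size Adam step of about `lr`, because Adam normalizes the tiny `wd * p` gradient, while decoupled weight decay shrinks the weight by exactly `lr * wd * p`.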

To use discriminative layer training, pass a list of nn.Module as layer_groups; each nn.Module will be used to customize the optimization of the corresponding layer group.

If path is provided, all the model files created will be saved in path/model_dir; if not, then they will be saved in data.path/model_dir.

You can pass a list of callbacks that you have already created or, more commonly, pass a list of callback functions to callback_fns: each function is called with the learner (self) on object initialization, and the results are stored as callback objects. For a walk-through, see the training overview page. You may also want to use an application-specific model. For example, if you are working with a vision dataset (here, MNIST), you might want to use the create_cnn function:

from fastai.vision import *
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
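The callback_fns mechanism described above can be illustrated with a small plain-Python sketch. `MiniLearner` and `LossPrinter` are hypothetical stand-ins, not fastai classes; the point is only that each entry of `callback_fns` is called with the learner and its return value is stored alongside any ready-made callbacks. Using `functools.partial` to pass extra arguments is shown as one convenient option.

```python
from functools import partial

class MiniLearner:
    """Hypothetical stand-in for Learner, showing only how callback_fns are consumed."""
    def __init__(self, callbacks=None, callback_fns=None):
        self.callbacks = list(callbacks or [])
        # Each callback function is called with the learner itself; the returned
        # objects are appended to any callbacks that were passed in ready-made.
        self.callbacks += [cb_fn(self) for cb_fn in (callback_fns or [])]

class LossPrinter:
    """Hypothetical callback that keeps a reference back to its learner."""
    def __init__(self, learn, every=1):
        self.learn = learn
        self.every = every

# functools.partial lets a callback function carry extra arguments.
learn = MiniLearner(callbacks=["ready_made_cb"],
                    callback_fns=[LossPrinter, partial(LossPrinter, every=10)])
```

Passing functions rather than objects lets a callback receive the learner it belongs to, which is why this form is usually more convenient.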

Model fitting methods


lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None)

Explore lr from start_lr to end_lr over num_it iterations in learn. If stop_div, stops when loss diverges.

Runs the learning rate finder defined in LRFinder, as discussed in Cyclical Learning Rates for Training Neural Networks.

learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
Min numerical gradient: 1.32E-02
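The LR range test described above can be sketched in plain Python on a toy problem: the learning rate grows geometrically from `start_lr` to `end_lr` over `num_it` steps, one SGD step is taken at each value, and the sweep stops early when the loss blows up. This is a minimal sketch, not fastai's LRFinder; the quadratic loss is a toy stand-in for a real model, and the stopping rule (loss above four times the best loss seen so far) is an assumed heuristic.

```python
def lr_range_test(loss_and_grad, w0, start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True):
    """Sweep the learning rate geometrically from start_lr to end_lr, one SGD step per value."""
    ratio = (end_lr / start_lr) ** (1 / (num_it - 1))
    w, best = w0, float("inf")
    lrs, losses = [], []
    for i in range(num_it):
        lr = start_lr * ratio ** i
        loss, grad = loss_and_grad(w)
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        # Assumed divergence rule: stop once the loss exceeds 4x the best loss so far.
        if stop_div and loss > 4 * best:
            break
        w -= lr * grad
    return lrs, losses

def quadratic(w):
    """Toy loss f(w) = w**2 and its gradient 2w."""
    return w * w, 2 * w

lrs, losses = lr_range_test(quadratic, w0=1.0)
```

On this toy loss, SGD diverges once the learning rate exceeds 1, so the sweep stops a few steps after that point instead of running all 100 iterations.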


fit(epochs:int, lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), wd:Floats=None, callbacks:Collection[Callback]=None)

Fit the model on this learner with lr learning rate, wd weight decay for epochs with callbacks.

Uses discriminative layer training if multiple learning rates or weight decay values are passed. To control training behaviour, use the callback system or one or more of the pre-defined callbacks.

Total time: 00:04

epoch  train_loss  valid_loss  accuracy
1      0.129607    0.082084    0.973013
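The loop that fit runs, and the points where callbacks can step in, can be sketched without any library. This is a simplified illustration, not fastai's implementation: the model is a single weight fitted to y = 2x, the hook names echo fastai's callback events but have hypothetical simplified signatures, and `History` is a hypothetical callback.

```python
class Callback:
    """Hypothetical callback base class: every hook is a no-op unless overridden."""
    def on_train_begin(self):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_batch_end(self, loss):
        pass

    def on_epoch_end(self, epoch, valid_loss):
        pass


def fit(epochs, w, train_data, valid_data, lr, callbacks=()):
    """Fit y = w * x by per-sample SGD on squared error, calling callback hooks along the way."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for x, y in train_data:              # each "batch" is a single sample here
            pred = w * x                     # forward pass
            loss = (pred - y) ** 2
            grad = 2 * (pred - y) * x        # backward pass, computed by hand
            w -= lr * grad                   # optimizer step (plain SGD)
            for cb in callbacks:
                cb.on_batch_end(loss)
        # Validation once per epoch, without updating w
        valid_loss = sum((w * x - y) ** 2 for x, y in valid_data) / len(valid_data)
        for cb in callbacks:
            cb.on_epoch_end(epoch, valid_loss)
    return w


class History(Callback):
    """Records every training loss and each epoch's validation loss."""
    def __init__(self):
        self.batch_losses, self.valid_losses = [], []

    def on_batch_end(self, loss):
        self.batch_losses.append(loss)

    def on_epoch_end(self, epoch, valid_loss):
        self.valid_losses.append(valid_loss)


train_data = [(0.1 * i, 0.2 * i) for i in range(1, 11)]   # points on y = 2x
valid_data = [(0.5, 1.0), (1.5, 3.0)]
hist = History()
w = fit(5, 0.0, train_data, valid_data, lr=0.1, callbacks=[hist])
```

Because the loop itself never changes, new behaviour (logging, scheduling, early stopping) is added by writing another callback rather than by editing fit.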


fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, wd:float=None, callbacks:Optional[Collection[Callback]]=None, tot_epochs:int=None, start_epoch:int=1)

Fit a model following the 1cycle policy.

Use cycle length cyc_len, a per-cycle maximal learning rate max_lr, momentum moms, division factor div_factor, weight decay wd, and optional callbacks callbacks. Uses the OneCycleScheduler callback. Please refer to What is 1-cycle for a conceptual background on the 1-cycle training policy and more technical details on what the method's arguments do.

Total time: 00:04

epoch  train_loss  valid_loss  accuracy
1      0.088884    0.066379    0.978410
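The shape of the 1cycle schedule can be sketched in plain Python. Over the first `pct_start` fraction of iterations the learning rate rises from `max_lr/div_factor` to `max_lr` while momentum falls from `moms[0]` to `moms[1]`; both then reverse for the rest of the cycle. The cosine interpolation and the final learning rate `max_lr / (div_factor * final_div)`, with the hypothetical parameter `final_div=1e4`, are assumptions of this sketch rather than a copy of OneCycleScheduler.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return start + (end - start) * (1 - math.cos(math.pi * pct)) / 2

def one_cycle(n_iter, max_lr, moms=(0.95, 0.85), div_factor=25.0, pct_start=0.3, final_div=1e4):
    """Per-iteration learning rates and momentums for one 1cycle run (sketch)."""
    n_up = int(n_iter * pct_start)
    n_down = n_iter - n_up
    lrs, mom_vals = [], []
    for i in range(n_iter):
        if i < n_up:
            # Warm-up phase: LR goes up, momentum goes down
            pct = i / n_up
            lrs.append(cos_anneal(max_lr / div_factor, max_lr, pct))
            mom_vals.append(cos_anneal(moms[0], moms[1], pct))
        else:
            # Annealing phase: LR goes down towards a tiny value, momentum back up
            pct = (i - n_up) / n_down
            lrs.append(cos_anneal(max_lr, max_lr / (div_factor * final_div), pct))
            mom_vals.append(cos_anneal(moms[1], moms[0], pct))
    return lrs, mom_vals

lr_sched, mom_sched = one_cycle(100, max_lr=3e-3)
```

Momentum moves opposite to the learning rate: while the learning rate is high, lower momentum keeps the updates from overshooting.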

See results


predict(item:ItemBase, **kwargs)

Return predicted class, label and probabilities for item.

predict can be used to get a single prediction from the trained learner on one specific piece of data you are interested in.

learn.data.train_ds[0]
(Image (3, 28, 28), Category 3)

Each element of the dataset is a tuple, where the first element is the data itself, while the second element is the target label. So to get the data, we need to index one more time.

item = learn.data.train_ds[0][0]
pred = learn.predict(item)
(Category 3, tensor(0), tensor([9.9979e-01, 2.0649e-04]))

The first two elements of the tuple are, respectively, the predicted class and its label. The label is an internal integer representation of the class, needed because a class name is a string and cannot be used in computation. To check which class each label corresponds to, run:

learn.data.classes
['3', '7']

So label 0 corresponds to class '3' and label 1 to class '7'.

probs = pred[2]

The last element in the tuple is the predicted probabilities. For a categorization dataset, the number of probabilities returned is the same as the number of classes; probs[i] is the probability that the item belongs to learn.data.classes[i].


You can always check for yourself whether the returned probabilities make sense.
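Turning the returned probabilities back into a class name is an argmax followed by a lookup in the class list. The sketch below uses the probabilities and classes printed above, copied as plain Python lists; `decode_prediction` is a hypothetical helper, not part of fastai.

```python
def decode_prediction(probs, classes):
    """Map a probability vector to (class name, label index) via argmax."""
    label = max(range(len(probs)), key=lambda i: probs[i])
    return classes[label], label

classes = ["3", "7"]                 # learn.data.classes in the example above
probs = [9.9979e-01, 2.0649e-04]     # pred[2] from the example above, as floats
```

Calling `decode_prediction(probs, classes)` gives `("3", 0)`, matching the first two elements returned by predict, and the probabilities sum to 1 up to the rounding of the printed values.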


get_preds(ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False, n_batch:Optional[int]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None) → List[Tensor]

Return predictions and targets on ds_type dataset.

It runs inference with the learner on all the data in the ds_type dataset and returns the predictions and the targets. If n_batch is specified, only that many batches are processed; otherwise every batch is used. If with_loss is True, it also returns the loss for each prediction.

Here is how to get predictions on the default dataset, the validation set:

preds = learn.get_preds()
[tensor([[9.9366e-01, 6.3430e-03],
         [9.9828e-01, 1.7193e-03],
         [9.9993e-01, 7.1130e-05],
         [1.5793e-04, 9.9984e-01],
         [9.0569e-03, 9.9094e-01],
         [9.8014e-01, 1.9864e-02]]), tensor([0, 0, 0,  ..., 1, 1, 1])]

The first element of preds is a tensor that contains all the predictions.

preds[0]
tensor([[9.9366e-01, 6.3430e-03],
        [9.9828e-01, 1.7193e-03],
        [9.9993e-01, 7.1130e-05],
        [1.5793e-04, 9.9984e-01],
        [9.0569e-03, 9.9094e-01],
        [9.8014e-01, 1.9864e-02]])

The second element is a tensor that contains all the target labels.

preds[1]
tensor([0, 0, 0,  ..., 1, 1, 1])

For more details about what each number means, refer to the documentation of predict.

Since get_preds makes predictions on all the data in the ds_type dataset, here the number of predictions equals the number of items in the validation set.

len(preds[0]), len(preds[1])
(2038, 2038)

To get predictions on the training set instead, set the ds_type argument accordingly.

learn.get_preds(ds_type=DatasetType.Train)
[tensor([[9.9973e-01, 2.6554e-04],
         [9.9962e-01, 3.8422e-04],
         [9.9988e-01, 1.1570e-04],
         [9.9922e-01, 7.8436e-04],
         [4.4838e-04, 9.9955e-01],
         [1.3715e-04, 9.9986e-01]]), tensor([0, 0, 0,  ..., 0, 1, 1])]

To get the loss of each prediction along with the predictions and the targets, pass with_loss=True.

learn.get_preds(with_loss=True)
[tensor([[9.9366e-01, 6.3430e-03],
         [9.9828e-01, 1.7193e-03],
         [9.9993e-01, 7.1130e-05],
         [1.5793e-04, 9.9984e-01],
         [9.0569e-03, 9.9094e-01],
         [9.8014e-01, 1.9864e-02]]),
 tensor([0, 0, 0,  ..., 1, 1, 1]),
 tensor([6.3632e-03, 1.7209e-03, 7.1049e-05,  ..., 1.5783e-04, 9.0983e-03,

Note that the third tensor in the output tuple contains the losses.
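Assuming the learner uses cross-entropy loss (the usual choice for single-label classification such as this example), each entry of the third tensor should equal the negative log of the probability assigned to the true class. The plain-Python check below uses the first three rows printed above; `per_item_nll` is a hypothetical helper.

```python
import math

def per_item_nll(probs, targets):
    """Cross-entropy per item, computed from already-softmaxed probabilities."""
    return [-math.log(p[t]) for p, t in zip(probs, targets)]

# First three rows of the predictions and targets printed above
probs = [[9.9366e-01, 6.3430e-03], [9.9828e-01, 1.7193e-03], [9.9993e-01, 7.1130e-05]]
targets = [0, 0, 0]
losses = per_item_nll(probs, targets)
```

The results agree with the first three printed losses to within about 1e-5; the small differences come from the probabilities being printed with only five significant digits.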


validate(dl=None, callbacks=None, metrics=None)

Validate on dl with potential callbacks and metrics.

Return the calculated loss and the metrics of the current model on the given data loader dl. The default data loader dl is the validation dataloader.

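You can check the default metrics of the learner using:

str(learn.metrics)
'[<function accuracy at 0x7f1effc86d08>]'
learn.validate()
[0.06637867, tensor(0.9784)]
learn.validate(learn.data.valid_dl)
[0.06637867, tensor(0.9784)]
learn.validate(learn.data.train_dl)
[0.039573476, tensor(0.9860)]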


show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)

Show rows result of predictions on ds_type dataset.

In the resulting grid (the figure is not reproduced here), each image has two lines of text above it: the top line is the ground truth (the target label) and the line below it is the prediction; the image underneath is the input data itself.



pred_batch(ds_type:DatasetType=<DatasetType.Valid: 2>, batch:Tuple=None, reconstruct:bool=False) → List[Tensor]

Return output of the model on one batch from ds_type dataset.

Note that the number of predictions returned equals the batch size.

preds = learn.pred_batch()

Since the batch contains many predictions, we only look at the first ten:

preds[:10]

tensor([[9.9366e-01, 6.3430e-03],
        [9.9828e-01, 1.7193e-03],
        [9.9993e-01, 7.1130e-05],
        [1.0000e+00, 5.2653e-07],
        [9.9839e-01, 1.6092e-03],
        [1.0000e+00, 9.6659e-07],
        [9.5156e-01, 4.8442e-02],
        [9.9854e-01, 1.4628e-03],
        [9.9937e-01, 6.2854e-04],
        [8.3490e-01, 1.6510e-01]])
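
You can also pass a batch of your own through the batch argument. Here we build one from a single training item with one_item:

item = learn.data.train_ds[0][0]
batch = learn.data.one_item(item)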
(tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0'),
 tensor([0], device='cuda:0'))
learn.pred_batch(batch=batch)
tensor([[9.9979e-01, 2.0649e-04]])


interpret(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, tta=False)

Create a ClassificationInterpretation object from learner on ds_type with tta.

For more details, refer to ClassificationInterpretation.

Model summary


model_summary(m:Learner, n:int=70)

Print a summary of m using an output text width of n characters.

Test time augmentation


TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False) → Tensors

Applies TTA to predict on ds_type dataset.

Applies Test Time Augmentation to learn on the dataset ds_type. We take the average of our regular predictions (with weight beta) and the average of the predictions obtained on augmented versions of the ds_type dataset (with weight 1-beta). The transforms defined for the training set are applied with a few changes: scale controls the zoom scale (which isn't random), and the cropping isn't random but makes sure to capture the four corners of the image. Flipping isn't random either but is applied once to each of those corner images, which makes 8 augmented versions in total.
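The weighting described above can be sketched in plain Python for a single item's class probabilities. The function name and the toy numbers are hypothetical; in practice the augmented list would hold eight prediction vectors, one per corner crop and flip combination, which the sketch also enumerates.

```python
def tta_average(regular, augmented, beta=0.4):
    """Blend regular predictions (weight beta) with the mean of augmented ones (weight 1 - beta)."""
    n = len(augmented)
    aug_mean = [sum(preds[j] for preds in augmented) / n for j in range(len(regular))]
    return [beta * r + (1 - beta) * a for r, a in zip(regular, aug_mean)]

# The eight deterministic augmentations: four corner crops, each with and without a flip
corners = ["top_left", "top_right", "bottom_left", "bottom_right"]
augmentations = [(corner, flip) for corner in corners for flip in (False, True)]

# Toy example with two augmented predictions instead of eight
final = tta_average([0.9, 0.1], [[0.8, 0.2], [0.6, 0.4]], beta=0.4)
```

Here the augmented mean is [0.7, 0.3], so the blended result is 0.4 * [0.9, 0.1] + 0.6 * [0.7, 0.3] = [0.78, 0.22].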

Gradient clipping


clip_grad(learn:Learner, clip:float=0.1) → Learner

Add gradient clipping of clip during training.
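Assuming clipping by global norm (the behaviour of PyTorch's torch.nn.utils.clip_grad_norm_, which is assumed here to be what the clipping callback relies on), the operation can be sketched in plain Python over a flat list of gradient values. `clip_grad_norm` below is a hypothetical helper, not the PyTorch function itself.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale grads in place so their global L2 norm is at most max_norm; return the original norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        # Only shrink; gradients already within the limit are left untouched
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm
```

For example, gradients [3.0, 4.0] have norm 5 and are rescaled to a norm of about 0.1 with max_norm=0.1, while [0.03, 0.04] (norm 0.05) are left as they are. Scaling all gradients by the same factor preserves the update direction.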

Mixed precision training


to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=False, clip:float=None, flat_master:bool=False) → Learner

Put learn in FP16 precision mode.

Uses the MixedPrecision callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), using all NVIDIA recommendations for ensuring speed and accuracy.
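Mixed-precision training typically multiplies the loss by a scale factor before the fp16 backward pass so that small gradients do not underflow, then divides the gradients by the same factor before the fp32 weight update. The sketch below shows one common dynamic policy, assumed here to correspond loosely to the dynamic and max_noskip arguments of to_fp16: halve the scale and skip the step on overflow, and double it after max_noskip consecutive clean steps. The class and its defaults (initial scale 2**16, factor 2) are hypothetical.

```python
import math

class DynamicLossScaler:
    """Sketch of dynamic loss scaling: shrink on overflow, grow after a run of clean steps."""
    def __init__(self, init_scale=2.0 ** 16, max_noskip=1000, factor=2.0):
        self.scale = init_scale
        self.max_noskip = max_noskip
        self.factor = factor
        self.clean_steps = 0

    def update(self, grads):
        """Return True if the optimizer step should run for these (scaled) fp16 grads."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= self.factor   # overflow: reduce the scale and skip this step
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps >= self.max_noskip:
            self.scale *= self.factor   # long run without overflow: try a larger scale
            self.clean_steps = 0
        return True
```

When update returns True, the caller would divide each gradient by the current scale before applying the fp32 update; when it returns False, the step is skipped.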



Put learn back to FP32 precision mode.

Distributed training


distributed(learn:Learner, cuda_id:int, cache_dir:PathOrStr='tmp')

Put learn on distributed training with cuda_id.

Discriminative layer training

When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each layer group (i.e. the parameters of each module in self.layer_groups). See the Universal Language Model Fine-tuning for Text Classification paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a Learner on which you've called split, you can set hyperparameters in four ways:

  1. param = [val1, val2 ..., valn] (n = number of layer groups)
  2. param = val
  3. param = slice(start,end)
  4. param = slice(end)

If you choose to set it in way 1, you must specify exactly as many values as there are layer groups. If you choose way 2, the chosen value is repeated for all layer groups. See Learner.lr_range for an explanation of the slice syntax.

Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call Learner.split in this case, since fastai uses this exact function as the default split for resnet18; this is just to show how to customize it):

# creates 3 layer groups
learn.split(lambda m: (m[0][6], m[1]))
# only randomly initialized head now trainable
learn.freeze()
learn.fit_one_cycle(1)
Total time: 00:04

epoch  train_loss  valid_loss  accuracy
1      0.067769    0.060910    0.979392
# all layers now trainable
learn.unfreeze()
# optionally, separate LR and WD for each group
learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))
Total time: 00:06

epoch  train_loss  valid_loss  accuracy
1      0.022366    0.006872    0.998037


lr_range(lr:Union[float, slice]) → ndarray

Build differential learning rates from lr.

Rather than manually setting an LR for every group, it's often easier to use Learner.lr_range. This is a convenience method that returns one learning rate for each layer group. If you pass slice(start,end) then the first group's learning rate is start, the last is end, and the remaining are evenly geometrically spaced.

If you pass just slice(end) then the last group's learning rate is end, and all the other groups are end/10. For instance (for our learner that has 3 layer groups):

learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))
(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))
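The four ways of specifying a hyperparameter and the slice rules above can be captured in one small function. This is a plain-Python sketch of the documented behaviour, not fastai's implementation (which returns a NumPy array); the hypothetical `lr_range` here takes the number of layer groups explicitly because there is no learner object.

```python
def lr_range(lr, n_groups):
    """Expand a learning-rate spec into one value per layer group (sketch of the documented rules)."""
    if isinstance(lr, slice):
        if n_groups == 1:
            return [lr.stop]
        if lr.start is not None:
            # slice(start, end): geometric spacing from start to end
            ratio = (lr.stop / lr.start) ** (1 / (n_groups - 1))
            return [lr.start * ratio ** i for i in range(n_groups)]
        # slice(end): the last group gets end, all the others end / 10
        return [lr.stop / 10] * (n_groups - 1) + [lr.stop]
    if isinstance(lr, (list, tuple)):
        if len(lr) != n_groups:
            raise ValueError(f"expected {n_groups} values, got {len(lr)}")
        return list(lr)
    return [lr] * n_groups  # a single value is repeated for every group

spread = lr_range(slice(1e-5, 1e-3), 3)
last_only = lr_range(slice(1e-3), 3)
```

These reproduce the two arrays printed above: [1e-05, 1e-04, 1e-03] and [1e-04, 1e-04, 1e-03].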



Unfreeze entire model.

Sets every layer group to trainable (i.e. requires_grad=True).



Freeze up to last layer.

Sets every layer group except the last to untrainable (i.e. requires_grad=False).



Freeze layers up to layer n.
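The freezing rules can be sketched on plain Python objects. `Param`, `freeze_to`, `freeze` and `unfreeze` below are hypothetical stand-ins operating on lists of fake parameters. The sketch assumes that freezing everything except the last group is the same as freezing up to n = -1, that unfreezing everything is freezing up to n = 0, and that with train_bn the batchnorm parameters stay trainable in frozen groups, as described for Learner above.

```python
from dataclasses import dataclass

@dataclass
class Param:
    """Hypothetical stand-in for a model parameter."""
    name: str
    is_bn: bool = False         # True for batchnorm weights and biases
    requires_grad: bool = True

def freeze_to(layer_groups, n, train_bn=True):
    """Freeze groups before index n and make groups from n onward trainable."""
    for group in layer_groups[:n]:
        for p in group:
            # With train_bn, batchnorm params stay trainable even in frozen groups.
            p.requires_grad = train_bn and p.is_bn
    for group in layer_groups[n:]:
        for p in group:
            p.requires_grad = True

def freeze(layer_groups, train_bn=True):
    """Freeze every group except the last one."""
    freeze_to(layer_groups, -1, train_bn)

def unfreeze(layer_groups):
    """Make every group trainable."""
    freeze_to(layer_groups, 0)

groups = [[Param("conv1"), Param("bn1", is_bn=True)], [Param("conv2")], [Param("head")]]
freeze(groups)
```

After `freeze(groups)`, only `head` and the batchnorm parameter `bn1` remain trainable.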



Split the model at split_on.

A convenience method that sets layer_groups based on the result of split_model. If split_on is a function, it calls that function and passes the result to split_model (see above for example).
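The idea behind splitting, as in learn.split(lambda m: (m[0][6], m[1])) in the discriminative layer training example, can be sketched on a flat list of layer names: each layer passed in split_on starts a new group, so two split points give three groups. `split_at` is a hypothetical helper, not fastai's split_model, which works on nn.Module objects.

```python
def split_at(layers, split_on):
    """Split a flat list of layers into groups, starting a new group at each layer in split_on."""
    idxs = [0] + [layers.index(layer) for layer in split_on] + [len(layers)]
    return [layers[start:end] for start, end in zip(idxs[:-1], idxs[1:])]

layers = ["conv1", "conv2", "conv3", "pool", "head"]
groups = split_at(layers, ["conv3", "head"])
```

The result is three groups: ["conv1", "conv2"], ["conv3", "pool"] and ["head"].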

Saving and loading models

Simply call Learner.save and Learner.load to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the path/model_dir directory.
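The "parameters only" design can be sketched in plain Python: the architecture is rebuilt from code, and only a dictionary of weights goes through the file, much like saving and loading a PyTorch state_dict. `TinyModel`, `save` and `load` here are hypothetical stand-ins using pickle, and the `.pth` file suffix is an assumption; the path/model_dir layout follows the description above.

```python
import pickle
import tempfile
from pathlib import Path

class TinyModel:
    """Hypothetical stand-in for an nn.Module: architecture lives in code, weights in a dict."""
    def __init__(self, n_in, n_out):
        self.weights = {"w": [[0.0] * n_in for _ in range(n_out)], "b": [0.0] * n_out}

    def state_dict(self):
        return self.weights

    def load_state_dict(self, state):
        self.weights = state

def save(model, path, name, model_dir="models"):
    """Write only the parameters to path/model_dir/name.pth (suffix assumed)."""
    dest = Path(path) / model_dir / f"{name}.pth"
    dest.parent.mkdir(parents=True, exist_ok=True)
    with open(dest, "wb") as f:
        pickle.dump(model.state_dict(), f)
    return dest

def load(model, path, name, model_dir="models"):
    """Load parameters into an already-constructed model with the same architecture."""
    with open(Path(path) / model_dir / f"{name}.pth", "rb") as f:
        model.load_state_dict(pickle.load(f))
    return model

tmp_dir = tempfile.mkdtemp()
model = TinyModel(n_in=2, n_out=1)
model.weights["w"][0][0] = 1.5                       # pretend this was learned
dest = save(model, tmp_dir, "trained_model")
restored = load(TinyModel(n_in=2, n_out=1), tmp_dir, "trained_model")
```

Note that `load` needs a freshly built `TinyModel` of the same shape, which is exactly why the model must be created the same way before loading weights back in.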



save(name:PathOrStr, return_path:bool=False, with_opt:bool=True)

Save model and optimizer state (if with_opt) with name to self.model_dir.

learn.save("trained_model", return_path=True)


load(name:PathOrStr, device:device=None, strict:bool=True, with_opt:bool=None, purge:bool=False)

Load model and optimizer state (if with_opt) name from self.model_dir using device.

learn = learn.load("trained_model")

Deploying your model

When you are ready to put your model in production, export the minimal state of your Learner with



Export the state of the Learner in self.path/fname.

path = learn.path


load_learner(path:PathOrStr, fname:PathOrStr='export.pkl', test:ItemList=None)

Load a Learner object saved with export_state in path/fname with empty data, optionally add test and load on cpu.

learn = load_learner(path)
learn = load_learner(path, fname='trained_model.pkl')

WARNING: If you used any customized classes when creating your learner, you must define these classes before executing load_learner.

You can find more information and multiple examples in this tutorial.

Other methods



Initializes all weights (except batchnorm) using function init, which will often be from PyTorch's nn.init module.


mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) → Learner

Add mixup https://arxiv.org/abs/1710.09412 to learn.
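The core of mixup, as described in the linked paper, can be sketched in plain Python: draw a weight from a Beta(alpha, alpha) distribution and blend two inputs and their one-hot targets with it. This is the textbook form, not fastai's mixup callback, which may differ in details such as how targets are combined (see stack_y); `mixup` and `sample_lam` are hypothetical helpers.

```python
import random

def mixup(x1, y1, x2, y2, lam):
    """Blend two inputs and their one-hot targets with weight lam."""
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y

def sample_lam(alpha, rng):
    """Draw the mixing weight from Beta(alpha, alpha), as in the mixup paper."""
    return rng.betavariate(alpha, alpha)

rng = random.Random(0)
lam = sample_lam(0.4, rng)
x, y = mixup([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], lam)
```

With a small alpha such as 0.4, the Beta distribution puts most of its mass near 0 and 1, so most mixed samples stay close to one of the two originals.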



Pass item through the model and compute the gradient. Useful if backward_hooks are attached.


create_opt(lr:Floats, wd:Floats=0.0)

Create optimizer with lr learning rate and wd weight decay.

You generally won't need to call this yourself - it's used to create the optim optimizer before fitting the model.


dl(ds_type:DatasetType=<DatasetType.Valid: 2>)

Return DataLoader for DatasetType ds_type.

DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f1efe504780>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f1f16f140d0>)
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f1f696aa4a8>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f1f16f140d0>)

class Recorder

Recorder(learn:Learner) :: LearnerCallback

A LearnerCallback that records epoch, loss, opt and metric data during training.

A Learner creates a Recorder object automatically (you do not need to pass it explicitly to callback_fns) because other callbacks rely on it being available. It stores the smoothed loss and hyperparameter values for each batch and the metrics for each epoch, and provides plotting methods for each. Note that Learner automatically sets an attribute with the snake-cased name of each callback, so you can access this one through Learner.recorder, as shown below.
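The smoothed loss is, assuming the usual approach, an exponential moving average of the per-batch losses with a bias correction for the first few batches. The class below is a hypothetical sketch of that computation; the default beta=0.98 is an assumed value, not taken from this page.

```python
class SmoothedLoss:
    """Exponential moving average of the loss with bias correction (illustrative sketch)."""
    def __init__(self, beta=0.98):
        self.beta = beta
        self.n = 0
        self.mov_avg = 0.0
        self.smooth = 0.0

    def add_value(self, loss):
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * loss
        # Dividing by (1 - beta**n) removes the bias from starting the average at 0
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
        return self.smooth

tracker = SmoothedLoss()
smoothed = [tracker.add_value(2.0) for _ in range(5)]
```

Thanks to the bias correction, a constant loss of 2.0 is reported as 2.0 from the very first batch instead of starting near 0.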

Plotting methods


plot(skip_start:int=10, skip_end:int=5)

Plot learning rate and losses, trimmed between skip_start and skip_end. Optionally plot and return min gradient.

This is mainly used with the learning rate finder, since it shows a scatterplot of loss vs learning rate.

from fastai.vision import *
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
Min numerical gradient: 7.59E-03
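The "Min numerical gradient" value is the learning rate at which the recorded loss falls fastest. Assuming it is computed with finite differences over the recorded points (similar to numpy.gradient), a dependency-free sketch looks like this; `min_grad_lr` and the toy numbers are hypothetical.

```python
def min_grad_lr(lrs, losses):
    """Return the learning rate where the loss curve has its steepest descent.

    Gradients are taken with respect to the step index (the lrs are assumed
    log-spaced): one-sided differences at the ends, central differences inside.
    """
    n = len(losses)
    grads = []
    for i in range(n):
        if i == 0:
            grads.append(losses[1] - losses[0])
        elif i == n - 1:
            grads.append(losses[-1] - losses[-2])
        else:
            grads.append((losses[i + 1] - losses[i - 1]) / 2)
    return lrs[grads.index(min(grads))]

lrs = [10.0 ** k for k in range(-6, 1)]          # 1e-6 ... 1
losses = [1.0, 0.99, 0.9, 0.6, 0.5, 0.55, 2.0]
suggested = min_grad_lr(lrs, losses)
```

On this toy curve the steepest drop is centred on 1e-3, which is the suggested learning rate; it sits well before the point where the loss starts rising again.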



Plot training and validation losses.

Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch.

Total time: 00:22

epoch  train_loss  valid_loss  accuracy
1      0.247247    0.141247    0.954367
2      0.109672    0.078876    0.972522
3      0.065391    0.054635    0.983808
4      0.044042    0.049592    0.981845
5      0.041287    0.049224    0.984298



Plot learning rate, show_moms to include momentum.




Plot metrics collected during training.

Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here.


Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality. Refer to Callback for more details.


on_backward_begin(smooth_loss:Tensor, **kwargs:Any)

Record the loss before any other callback has a chance to modify it.


on_batch_begin(train, **kwargs:Any)

Record learning rate and momentum at beginning of batch.


on_epoch_end(epoch:int, num_batch:int, smooth_loss:Tensor, last_metrics=typing.Collection[typing.Union[torch.Tensor, numbers.Number]], **kwargs:Any) → bool

Save epoch info: num_batch, smooth_loss, metrics.


on_train_begin(pbar:PBar, metrics_names:StrList, **kwargs:Any)

Initialize recording status at beginning of training.

Inner functions

The following functions are used along the way by the Recorder or can be called by other callbacks.



Add metrics to the inner stats.



Add names to the inner metric names.



Format stats before printing.

Module functions

Generally you'll want to use a Learner to train your model, since they provide a lot of functionality and make things easier. However, for ultimate flexibility, you can call the same underlying functions that Learner calls behind the scenes:


fit(epochs:int, model:Module, loss_func:LossFunction, opt:Optimizer, data:DataBunch, callbacks:Optional[Collection[Callback]]=None, metrics:OptMetrics=None)

Fit the model on data and learn using loss_func and opt.

Note that you have to create the Optimizer yourself if you call this function, whereas Learner.fit creates it for you automatically.


train_epoch(model:Module, dl:DataLoader, opt:Optimizer, loss_func:LossFunction)

Simple training of model for 1 epoch of dl using optim opt and loss function loss_func.

You won't generally need to call this yourself - it's what fit calls for each epoch.


validate(model:Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None, average=True, n_batch:Optional[int]=None) → Iterator[Tuple[IntOrTensor, Ellipsis]]

Calculate loss_func of model on dl in evaluation mode.

This is what fit calls after each epoch. You can call it if you want to run inference on a DataLoader manually.


get_preds(model:Module, dl:DataLoader, pbar:Union[MasterBar, ProgressBar, NoneType]=None, cb_handler:Optional[CallbackHandler]=None, activ:Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) → List[Tensor]

Tuple of predictions and targets, and optional losses (if loss_func) using dl, max batches n_batch.


loss_batch(model:Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None, cb_handler:Optional[CallbackHandler]=None) → Tuple[Union[Tensor, int, float, str]]

Calculate loss and metrics for a batch, call out to callbacks as necessary.

You won't generally need to call this yourself - it's what fit and validate call for each batch. It only does a backward pass if you set opt.

Other classes

class LearnerCallback

LearnerCallback(learn) :: Callback

Base class for creating callbacks for a Learner.

class RecordOnCPU

RecordOnCPU() :: Callback

Store the input and target going through the model on the CPU.