Learner, Metrics, Callbacks

Basic class for handling the training loop

You probably want to jump directly to the definition of Learner.

Utility functions



 replacing_yield (o, attr, val)

Context manager to temporarily replace an attribute

class _A:
    def __init__(self, a): self.a = a
    def a_changed(self, v): return replacing_yield(self, 'a', v)

a = _A(42)
with a.a_changed(32):
    test_eq(a.a, 32)
test_eq(a.a, 42)
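As a rough illustration of the pattern (not fastai's actual implementation), a context manager with this behavior could be written with contextlib. The names replacing_attr_sketch and Box are hypothetical and exist only for this example.

```python
from contextlib import contextmanager

@contextmanager
def replacing_attr_sketch(o, attr, val):
    "Hypothetical sketch: temporarily set `o.attr` to `val`, then restore it."
    old = getattr(o, attr)
    setattr(o, attr, val)
    try:
        yield
    finally:
        # Restore the original value even if the body raised
        setattr(o, attr, old)

class Box:
    "Hypothetical holder used only for this illustration."
    def __init__(self, a): self.a = a

box = Box(42)
with replacing_attr_sketch(box, 'a', 32):
    inside = box.a   # value seen inside the block
after = box.a        # value seen once the block has exited
```

The try/finally is a design choice of this sketch: it guarantees the attribute is restored even when the body of the with block fails.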



 mk_metric (m)

Convert m to an AvgMetric, unless it’s already a Metric

See the class Metric below for more information.



 save_model (file, model, opt, with_opt=True, pickle_protocol=2,

Save model to file along with opt (if available, and if with_opt)

file can be a Path object, a string or an opened file object. pickle_protocol and torch_save_kwargs are passed along to torch.save.



 load_model (file, model, opt, with_opt=True, device=None, strict=True,

Load model from file along with opt (if available, and if with_opt)

file can be a Path object, a string or an opened file object. If a device is passed, the model is loaded on it, otherwise it’s loaded on the CPU.

If strict is True, the file must contain weights for exactly the parameter keys in model. If strict is False, only the keys that are in the saved model are loaded into model.

You can pass in other kwargs to torch.load through torch_load_kwargs.
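To make the effect of strict more concrete, here is a simplified sketch that uses plain dictionaries in place of real state dicts. The function load_weights_sketch is hypothetical, and raising KeyError is an assumption of this sketch rather than what the library raises.

```python
def load_weights_sketch(model_state, saved_state, strict=True):
    "Hypothetical sketch of strict vs. non-strict key matching between two dicts."
    missing = [k for k in model_state if k not in saved_state]
    unexpected = [k for k in saved_state if k not in model_state]
    if strict and (missing or unexpected):
        # Assumption of this sketch: any mismatch is an error in strict mode
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    # Copy only the keys present on both sides
    for k, v in saved_state.items():
        if k in model_state:
            model_state[k] = v
    return missing, unexpected

model_state = {'weight': 0.0, 'bias': 0.0}
saved_state = {'weight': 1.5}          # no entry for 'bias'
missing, unexpected = load_weights_sketch(model_state, saved_state, strict=False)
```

With strict=False the call succeeds and leaves bias untouched; the same call with strict=True would raise in this sketch.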



 SkipToEpoch (epoch:int)

Skip training up to epoch



Group together a model, some dls and a loss_func to handle training

opt_func will be used to create an optimizer when Learner.fit is called, with lr as a default learning rate. splitter is a function that takes self.model and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is trainable_params, which returns all trainable parameters of the model.

cbs is one or a list of Callbacks to pass to the Learner. Callbacks are used for every tweak of the training loop. Each Callback is registered as an attribute of Learner under its snake-case name (for example, TrainEvalCallback becomes train_eval). At creation, all the callbacks in defaults.callbacks (TrainEvalCallback, Recorder and ProgressCallback) are associated to the Learner.

metrics is an optional list of metrics, that can be either functions or Metrics (see below).

path and model_dir are used to save and/or load models. Often path will be inferred from dls, but you can override it or pass a Path object to model_dir. Make sure you can write in path/model_dir!

wd is the default weight decay used when training the model; moms, the default momentums used in Learner.fit_one_cycle. wd_bn_bias controls if weight decay is applied to BatchNorm layers and bias.

Lastly, train_bn controls if BatchNorm layers are trained even when they are supposed to be frozen according to the splitter. Our empirical experiments have shown that it’s the best behavior for those layers in transfer learning.

PyTorch interop

You can use regular PyTorch functionality for most of the arguments of the Learner, although the experience will be smoother with pure fastai objects and you will be able to use the full functionality of the library. The expectation is that the training loop will work smoothly even if you did not use fastai end to end. What you might lose are interpretation objects or showing functionality. The list below explains how to use plain PyTorch objects for all the arguments and what you might lose.

The most important is opt_func. If you are not using a fastai optimizer, you will need to write a function that wraps your PyTorch optimizer in an OptimWrapper. See the optimizer module for more details. This is to ensure the library’s schedulers/freeze API work with your code.

  • dls is a DataLoaders object, that you can create from standard PyTorch dataloaders. By doing so, you will lose all showing functionality like show_batch/show_results. You can check the data block API or the mid-level data API tutorial to learn how to use fastai to gather your data!
  • model is a standard PyTorch model. You can use any model you like; just make sure it accepts the number of inputs you have in your DataLoaders and returns as many outputs as you have targets.
  • loss_func can be any loss function you like. It needs to be one of fastai’s if you want to use Learner.predict or Learner.get_preds, or you will have to implement special methods (see more details after the BaseLoss documentation).

Training loop

Now let’s look at the main thing the Learner class implements: the training loop.



 Learner.fit (n_epoch, lr=None, wd=None, cbs=None, reset_opt=False,

Fit self.model for n_epoch using cbs. Optionally reset_opt.

Uses lr and wd if they are provided; otherwise uses the default values given by the lr and wd attributes of Learner.

All the examples use synth_learner which is a simple Learner training a linear regression model.

#Training a few epochs should make the model better
learn = synth_learner(lr=0.1)
learn.model = learn.model.cpu()
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(10)  #train between the two loss measurements
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss, (final_loss,init_loss)



 Learner.one_batch (i, b)

Train or evaluate self.model on batch (xb,yb)

This is an internal method called by Learner.fit. If passed, i is the index of this iteration in the epoch. In training mode, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation. Training or validation is controlled internally by the TrainEvalCallback through the training attribute.

Nothing is returned, but the attributes x, y, pred, loss of the Learner are set with the proper values:

b = learn.dls.one_batch()
learn.one_batch(0, b)
test_eq(learn.x, b[0])
test_eq(learn.y, b[1])
out = learn.model(learn.x)
test_eq(learn.pred, out)
test_eq(learn.loss, learn.loss_func(out, b[1]))
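The five stages of a training step listed above (predictions, loss, gradients, parameter update, zeroing the gradients) can be sketched without any library for a one-dimensional linear regression. This is only an illustration: train_step_sketch and its arguments are hypothetical, and the gradients are written out by hand instead of coming from autograd.

```python
def train_step_sketch(params, grads, xb, yb, lr=0.1):
    "Hypothetical sketch of one training step for y = a*x + b with an MSE loss."
    a, b = params['a'], params['b']
    n = len(xb)
    pred = [a * x + b for x in xb]                                      # 1. predictions
    loss = sum((p - y) ** 2 for p, y in zip(pred, yb)) / n              # 2. loss
    grads['a'] = sum(2 * (p - y) * x for p, y, x in zip(pred, yb, xb)) / n  # 3. gradients
    grads['b'] = sum(2 * (p - y) for p, y in zip(pred, yb)) / n
    params['a'] -= lr * grads['a']                                      # 4. parameter update
    params['b'] -= lr * grads['b']
    grads['a'] = grads['b'] = 0.0                                       # 5. zero the gradients
    return pred, loss

xb, yb = [0., 1., 2., 3.], [1., 3., 5., 7.]   # data generated by y = 2*x + 1
params, grads = {'a': 0.0, 'b': 0.0}, {'a': 0.0, 'b': 0.0}
losses = [train_step_sketch(params, grads, xb, yb)[1] for _ in range(50)]
```

In validation mode, only stages 1 and 2 would run.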



 Learner.all_batches ()

Train or evaluate self.model on all the batches of self.dl



 Learner.create_opt ()

Create an optimizer with default hyper-parameters

This method is called internally to create the optimizer. The hyper-parameters are then adjusted by what you pass to Learner.fit or by your particular schedulers (see callback.schedule).

learn = synth_learner(n_train=5, cbs=VerboseCallback())
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)
learn = synth_learner(n_train=5, cbs=VerboseCallback(), opt_func=partial(OptimWrapper, opt=torch.optim.Adam))
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)
wrapper_lr = 1
learn = synth_learner(n_train=5, cbs=VerboseCallback(), opt_func=partial(OptimWrapper, opt=torch.optim.Adam, lr=wrapper_lr))
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], wrapper_lr)

Callback handling

We only describe the basic functionality linked to Callbacks here. To learn more about Callbacks and how to write them, check the callback.core module documentation.

Let’s first see how the Callbacks become attributes of Learner:

#Test init with callbacks
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert hasattr(tst_learn, ('train_eval'))

tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert hasattr(tst_learn, ('tst'))



 Learner.__call__ (event_name)

Call event_name for all Callbacks in self.cbs

This is how the Callbacks are called internally. For instance, a VerboseCallback just prints the event names (which can be useful for debugging):

learn = synth_learner(cbs=VerboseCallback())
learn('after_fit')
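The dispatch idea can be sketched in a few lines of plain Python. This is a simplified stand-in, not the library's code: MiniLearnerSketch and EchoCallbackSketch are hypothetical, and sorting on a single order attribute is an assumption that ignores any finer ordering rules described in callback.core.

```python
class EchoCallbackSketch:
    "Hypothetical callback that records the events it receives."
    order = 0
    def __init__(self, log): self.log = log
    def before_fit(self): self.log.append('before_fit')
    def after_fit(self):  self.log.append('after_fit')

class MiniLearnerSketch:
    "Hypothetical, much-simplified stand-in for the dispatch logic."
    def __init__(self, cbs): self.cbs = list(cbs)
    def __call__(self, event_name):
        # Call `event_name` on each callback that defines it, lowest `order` first
        for cb in sorted(self.cbs, key=lambda c: c.order):
            handler = getattr(cb, event_name, None)
            if handler is not None:
                handler()

log = []
mini = MiniLearnerSketch([EchoCallbackSketch(log)])
mini('before_fit')
mini('before_epoch')   # no handler defined for this event, so nothing happens
mini('after_fit')
```

Callbacks only need to define the events they care about; every other event is skipped for them.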



 Learner.add_cb (cb)

Add cb to the list of Callback and register self as their learner

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)



 Learner.add_cbs (cbs)

Add cbs to the list of Callback and register self as their learner

learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)



 Learner.added_cbs (cbs)

learn = synth_learner()
test_eq(len(learn.cbs), 1)
with learn.added_cbs(TestTrainEvalCallback()):
    test_eq(len(learn.cbs), 2)



 Learner.ordered_cbs (event)

List of Callbacks, in order, for an event in the training loop

By order, we mean using the internal ordering of the Callbacks (see callback.core for more information on how it works).

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
learn.ordered_cbs('before_fit')
[TrainEvalCallback, TestTrainEvalCallback]



 Learner.remove_cb (cb)

Remove cb from the list of Callback and deregister self as their learner

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 1)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)

cb can simply be the class of the Callback we want to remove (in which case all instances of that callback are removed).

learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
learn.remove_cb(TestTrainEvalCallback)
test_eq(len(learn.cbs), 1)
assert not getattr(learn,'test_train_eval',None)



 Learner.remove_cbs (cbs)

Remove cbs from the list of Callback and deregister self as their learner

Elements of cbs can either be types of callbacks or actual callbacks of the Learner.

learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback() for _ in range(3)])
cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)



 Learner.removed_cbs (cbs)

Elements of cbs can either be types of callbacks or actual callbacks of the Learner.

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
with learn.removed_cbs(learn.cbs[1]):
    test_eq(len(learn.cbs), 1)
test_eq(len(learn.cbs), 2)



 Learner.show_training_loop ()

Show each step in the training loop

At each step, callbacks are shown in order, which can help debugging.

learn = synth_learner()
learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback]
  Start Epoch Loop
     - before_epoch   : []
    Start Train
       - before_train   : [TrainEvalCallback]
      Start Batch Loop
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback]
      End Batch Loop
    End Train
     - after_cancel_train: []
     - after_train    : []
    Start Valid
       - before_validate: [TrainEvalCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: []
     - after_validate : []
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : []
End Fit
 - after_cancel_fit: []
 - after_fit      : []



 before_batch_cb (f)

Shortcut for creating a Callback on the before_batch event, which takes and returns xb,yb

In order to change the data passed to your model, you will generally want to hook into the before_batch event, like so:

class TstCallback(Callback):
    def before_batch(self):
        self.learn.xb = self.xb + 1000
        self.learn.yb = self.yb - 1000

Since that is so common, we provide the before_batch_cb decorator to make it easier.

@before_batch_cb
def cb(self, xb, yb): return xb+1000,yb-1000




 Learner.save (file, with_opt=True, pickle_protocol=2)

Save model and optimizer state (if with_opt) to self.path/self.model_dir/file

file can be a Path, a string or a buffer. pickle_protocol is passed along to torch.save.



 Learner.load (file, device=None, with_opt=True, strict=True)

Load model and optimizer state (if with_opt) from self.path/self.model_dir/file using device

file can be a Path, a string or a buffer. Use device to load the model/optimizer state on a device different from the one it was saved.

with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=d)
    learn.fit(1)
    #Test save created a file
    learn.save('tmp')
    assert (Path(d)/'models/tmp.pth').exists()
    #Test load did load the model
    learn1 = synth_learner(path=d)
    learn1 = learn1.load('tmp')
    test_eq(learn.a, learn1.a)
    test_eq(learn.b, learn1.b)
    test_eq(learn.opt.state_dict(), learn1.opt.state_dict())



 Learner.export (fname='export.pkl', pickle_module=pickle, pickle_protocol=2)

Export the content of self without the items and the optimizer state for inference

The Learner is saved in self.path/fname, using pickle_protocol. Note that serialization in Python saves the names of functions, not the code itself. Therefore, any custom code you have for models, data transformation, loss function etc… should be put in a module that you will import in your training environment before exporting, and in your deployment environment before loading it.
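The point about names versus code can be checked with the standard pickle module alone. The snippet below is a small demonstration, independent of fastai: a module-level function is stored as a reference to its module and name, while a lambda, which has no importable name, cannot be pickled at all.

```python
import math
import pickle

# A module-level function is pickled as a reference (module name + function name)
data = pickle.dumps(math.sqrt, protocol=2)
stores_names = b'math' in data and b'sqrt' in data
restored = pickle.loads(data)   # works because `math` can be imported here

# A lambda has no importable name, so pickling it fails
try:
    pickle.dumps(lambda x: x + 1, protocol=2)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False
```

Unpickling succeeds only if the same module and name can be resolved in the loading environment, which is why custom code needs to live in an importable module.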



 load_learner (fname, cpu=True, pickle_module=<module 'pickle' from

Load a Learner object in fname, by default putting it on the cpu


load_learner requires all your custom code be in the exact same place as when exporting your Learner (the main script, or the module you imported it from).



 Learner.to_detach (b, cpu=True, gather=True)

Calls to_detach if self.dl provides a .to_detach function otherwise calls global to_detach

fastai provides to_detach, which by default detaches tensor gradients and gathers (calling maybe_gather) tensors from all ranks if running in distributed data parallel (DDP) mode.

When running in DDP mode all ranks need to have the same batch size, and DistributedDL takes care of padding batches as needed; however when gathering all tensors (e.g. for calculating metrics, inference, etc.) we need to discard the padded items. DistributedDL provides a method to_detach that removes padding appropriately.

Calling the learner’s to_detach method will attempt to find a to_detach method in the learner’s last used DataLoader dl and use that one if found, otherwise it will resort to the vanilla to_detach.



 Metric ()

Blueprint for defining a metric

Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can’t be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class AvgMetric, otherwise you’ll need to implement the following methods.


If your Metric has state depending on tensors, don’t forget to store it on the CPU to avoid any potential memory leaks.
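To see why some metrics cannot simply be averaged over batches, consider the following sketch of a precision metric built on the reset/accumulate/value pattern. PrecisionSketch is hypothetical, and for simplicity its accumulate takes predictions and targets directly instead of a learn object.

```python
class PrecisionSketch:
    "Hypothetical metric following the reset/accumulate/value pattern."
    def reset(self):
        self.tp, self.fp = 0, 0
    def accumulate(self, preds, targs):
        # Keep running counts rather than per-batch precision values
        for p, t in zip(preds, targs):
            if p == 1 and t == 1: self.tp += 1
            if p == 1 and t == 0: self.fp += 1
    @property
    def value(self):
        total = self.tp + self.fp
        return self.tp / total if total else None

batches = [([1, 1, 1, 1], [1, 1, 1, 0]),   # batch precision: 3/4
           ([1, 0, 0, 0], [0, 0, 0, 0])]   # batch precision: 0/1
metric = PrecisionSketch()
metric.reset()
for preds, targs in batches:
    metric.accumulate(preds, targs)

true_precision = metric.value              # 3 true positives out of 5 positive predictions
per_batch_mean = (3 / 4 + 0 / 1) / 2       # averaging batch precisions gives a different number
```

The two results differ (0.6 versus 0.375), so the metric has to carry state across batches.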



 Metric.reset ()

Reset inner state to prepare for new computation



 Metric.accumulate (learn)

Use learn to update the state with new results



 Metric.value ()



 Metric.name ()



 AvgMetric (func)

Average the values of func taking into account potential different batch sizes

learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25): 
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
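Taking batch sizes into account matters when the last batch is smaller than the others. A minimal sketch of the weighting, with a hypothetical helper weighted_average_sketch:

```python
def weighted_average_sketch(batch_values, batch_sizes):
    "Hypothetical sketch: average per-batch values, weighting each by its batch size."
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

values, sizes = [1.0, 4.0], [30, 10]       # second batch is smaller
weighted = weighted_average_sketch(values, sizes)
naive = sum(values) / len(values)          # ignores batch sizes
```

Here the weighted result is 1.75, while the naive mean of the two batch values is 2.5.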



 AvgLoss ()

Average the losses taking into account potential different batch sizes

tst = AvgLoss()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25): 
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())



 AvgSmoothLoss (beta=0.98)

Smooth average of the losses (exponentially weighted with beta)

tst = AvgSmoothLoss()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4): 
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)
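The quantity checked in the test above is a debiased exponentially weighted average. The same computation can be sketched in plain Python; smooth_losses_sketch is a hypothetical helper that mirrors the arithmetic of the test, not the class itself.

```python
def smooth_losses_sketch(losses, beta=0.98):
    "Hypothetical sketch of a debiased exponentially weighted average."
    val, out = 0.0, []
    for i, loss in enumerate(losses, start=1):
        val = beta * val + (1 - beta) * loss
        out.append(val / (1 - beta ** i))   # correct the bias toward the zero start
    return out

smoothed = smooth_losses_sketch([2.0, 2.0, 2.0])
```

Dividing by 1 - beta**i compensates for starting the running value at zero: with a constant loss, the smoothed value equals that loss from the very first step.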



 ValueMetric (func, metric_name=None)

Use to include a pre-calculated metric value (for instance calculated in a Callback) and returned by func

def metric_value_fn(): return 5e-3

vm = ValueMetric(metric_value_fn, 'custom_value_metric')
test_eq(vm.value, 5e-3)
test_eq(vm.name, 'custom_value_metric')

vm = ValueMetric(metric_value_fn)
test_eq(vm.name, 'metric_value_fn')

Recorder



 Recorder (add_time=True, train_metrics=False, valid_metrics=True,

Callback that registers statistics (lr, loss and metrics) during training

By default, metrics are computed on the validation set only, although that can be changed by adjusting train_metrics and valid_metrics. beta is the weight used to compute the exponentially weighted average of the losses (which gives the smooth_loss attribute to Learner).

The logger attribute of a Learner determines what happens to those metrics. By default, it just prints them:

#Test printed output
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_train=5, metrics=tst_metric, default_cbs=False, cbs=[TrainEvalCallback, Recorder])
# pat = r"[tensor\(\d.\d*\), tensor\(\d.\d*\), tensor\(\d.\d*\), 'dd:dd']"
pat = r"\[\d, \d+.\d+, \d+.\d+, \d+.\d+, '\d\d:\d\d'\]"
test_stdout(lambda: learn.fit(1), pat, regex=True)




 Recorder.before_fit ()

Prepare state for training



 Recorder.before_epoch ()

Set timer if self.add_time=True



 Recorder.before_validate ()

Reset loss and metrics state



 Recorder.after_batch ()

Update all metrics and records lr and smooth loss in training



 Recorder.after_epoch ()

Store and log the loss/metric values

Plotting tools



 Recorder.plot_loss (skip_start=5, with_valid=True, log=False,
                     show_epochs=False, ax=None)

Plot the losses from skip_start onward. Optionally pass log=True for a logarithmic axis, show_epochs=True to indicate epochs, and a matplotlib axis ax to plot on.



 CastToTensor (after_create=None, before_fit=None, before_epoch=None,
               before_train=None, before_batch=None, after_pred=None,
               after_loss=None, before_backward=None,
               after_cancel_backward=None, after_backward=None,
               before_step=None, after_cancel_step=None, after_step=None,
               after_cancel_batch=None, after_batch=None,
               after_cancel_train=None, after_train=None,
               before_validate=None, after_cancel_validate=None,
               after_validate=None, after_cancel_epoch=None,
               after_epoch=None, after_cancel_fit=None, after_fit=None)

Cast Subclassed Tensors to Tensor

Workaround for bug in PyTorch where subclassed tensors, such as TensorBase, train up to ~20% slower than Tensor when passed to a model. Added to Learner by default.

CastToTensor’s order is right before MixedPrecision so callbacks which make use of fastai’s tensor subclasses still can use them.

If inputs are not a subclassed tensor or tuple of tensors, you may need to cast inputs in Learner.xb and Learner.yb to Tensor via your own callback or in the dataloader before Learner performs the forward pass.

If the CastToTensor workaround interferes with custom code, it can be removed:

learn = Learner(...)
learn.remove_cb(CastToTensor)

You should verify your inputs are of type Tensor or implement a cast to Tensor via a custom callback or dataloader if CastToTensor is removed.

Inference functions



 Learner.validate (ds_idx=1, dl=None, cbs=None)

Validate on dl with potential new cbs.

#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
res = learn.validate()
test_eq(res[0], res[1])
x,y = learn.dls.valid_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y), 1e-3)



 Learner.get_preds (ds_idx:int=1, dl=None, with_input:bool=False,
                    with_decoded:bool=False, with_loss:bool=False,
                    act=None, inner:bool=False, reorder:bool=True,
                    save_preds:Path=None, save_targs:Path=None,
                    with_preds:bool=True, with_targs:bool=True,
                    concat_dim:int=0, pickle_protocol:int=2)

Get the predictions and targets on the ds_idx-th dataset or dl, optionally with_input and with_loss

Name             Type                                 Default   Details
ds_idx           int                                  1         DataLoader to use for predictions if dl is None. 0: train. 1: valid
dl               NoneType                             None      DataLoader to use for predictions, defaults to ds_idx=1 if None
with_input       bool                                 False     Return inputs with predictions
with_decoded     bool                                 False     Return decoded predictions
with_loss        bool                                 False     Return per item loss with predictions
act              NoneType                             None      Apply activation to predictions, defaults to self.loss_func’s activation
inner            bool                                 False     If False, create progress bar, show logger, use temporary cbs
reorder          bool                                 True      Reorder predictions on dataset indices, if applicable
cbs              Callback | MutableSequence | None    None      Temporary Callbacks to apply during prediction
save_preds       Path                                 None      Path to save predictions
save_targs       Path                                 None      Path to save targets
with_preds       bool                                 True      Whether to return predictions
with_targs       bool                                 True      Whether to return targets
concat_dim       int                                  0         Dimension to concatenate returned tensors
pickle_protocol  int                                  2         Pickle protocol used to save predictions and targets
Returns          tuple

with_decoded will also return the decoded predictions using the decodes function of the loss function (if it exists). For instance, fastai’s CrossEntropyLossFlat takes the argmax of predictions in its decodes.

Depending on the loss_func attribute of Learner, an activation function will be picked automatically so that the predictions make sense. For instance if the loss is a case of cross-entropy, a softmax will be applied, or if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with act.
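For reference, the two activations mentioned above can be written in plain Python as follows. These are generic textbook definitions with hypothetical names, shown only to make clear what the activation does to raw model outputs.

```python
import math

def softmax_sketch(logits):
    "Turn a list of raw scores into probabilities that sum to 1."
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid_sketch(x):
    "Squash a single raw score into the interval (0, 1)."
    return 1 / (1 + math.exp(-x))

probs = softmax_sketch([2.0, 1.0, 0.1])      # multi-class case
prob = sigmoid_sketch(0.0)                   # binary case
```

Softmax makes the outputs of a multi-class model comparable as probabilities across classes, while sigmoid treats each output independently.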

save_preds and save_targs should be used when your predictions are too big to fit all in memory. Give a Path object that points to a folder where the predictions and targets will be saved.

concat_dim is the batch dimension, where all the tensors will be concatenated.

inner is an internal attribute that tells get_preds it’s called internally, inside another training loop, to avoid recursion errors.


If you want to use the option with_loss=True on a custom loss function, make sure you have implemented a reduction attribute that supports ‘none’

#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
preds,targs = learn.get_preds()
x,y = learn.dls.valid_ds.tensors
test_eq(targs, y)
test_close(preds, learn.model(x))

preds,targs = learn.get_preds(act = torch.sigmoid)
test_eq(targs, y)
test_close(preds, torch.sigmoid(learn.model(x)))



 Learner.predict (item, rm_type_tfms=None, with_input=False)

Prediction on item, fully decoded, loss function decoded and probabilities

It returns a tuple of three elements with, in reverse order:

  • the prediction from the model, potentially passed through the activation of the loss function (if it has one)
  • the decoded prediction, using the potential decodes method from it
  • the fully decoded prediction, using the transforms used to build the Datasets/DataLoaders

rm_type_tfms is a deprecated argument that should not be used and will be removed in a future version. with_input will add the decoded inputs to the result.

class _FakeLossFunc(Module):
    reduction = 'none'
    def forward(self, x, y): return F.mse_loss(x,y)
    def activation(self, x): return x+1
    def decodes(self, x):    return 2*x

class _Add1(Transform):
    def encodes(self, x): return x+1
    def decodes(self, x): return x-1
learn = synth_learner(n_train=5)
dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))
learn.dls = DataLoaders(dl, dl)
learn.loss_func = _FakeLossFunc()

inp = tensor([2.])
out = learn.model(inp).detach()+1  #applying model + activation
dec = 2*out                        #decodes from loss function
full_dec = dec-1                   #decodes from _Add1
test_eq(learn.predict(inp), [full_dec,dec,out])
test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])



 Learner.show_results (ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs)

Show some predictions on ds_idx-th dataset or dl

Will show max_n samples (unless the batch size of ds_idx or dl is less than max_n, in which case it will show as many samples as the batch contains) and shuffle the data unless you pass shuffle=False. kwargs are application-dependent.

We can’t show an example on our synthetic Learner, but check all the beginner tutorials, which will show you how that method works across applications.

The last functions in this section are used internally for inference, but should be less useful to you.



 Learner.no_logging ()

learn = synth_learner(n_train=5, metrics=tst_metric)
with learn.no_logging():
    test_stdout(lambda: learn.fit(1), '')
test_eq(learn.logger, print)



 Learner.loss_not_reduced ()

This requires your loss function to either have a reduction attribute or a reduction argument (like all fastai and PyTorch loss functions).

Transfer learning



 Learner.unfreeze ()

Unfreeze the entire model



 Learner.freeze ()

Freeze up to last parameter group



 Learner.freeze_to (n)

Freeze parameter groups up to n
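The effect of freezing parameter groups up to n can be pictured with a small sketch on plain dictionaries. freeze_to_sketch and the group names are hypothetical; the sketch also ignores train_bn, which can keep BatchNorm layers trainable inside frozen groups.

```python
def freeze_to_sketch(groups, n):
    "Hypothetical sketch: mark groups before index `n` as frozen, the rest as trainable."
    cut = n if n >= 0 else len(groups) + n   # support negative indices like list slicing
    for i, g in enumerate(groups):
        g['trainable'] = i >= cut
    return groups

groups = [{'name': 'body_early'}, {'name': 'body_late'}, {'name': 'head'}]

freeze_to_sketch(groups, -1)                 # freeze everything up to the last group
frozen_state = [g['trainable'] for g in groups]

freeze_to_sketch(groups, 0)                  # nothing before index 0, so all trainable
unfrozen_state = [g['trainable'] for g in groups]
```

In this picture, n=-1 matches "freeze up to last parameter group" and n=0 matches "unfreeze the entire model".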




 Learner.tta (ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None,
              beta=0.25, use_max=False)

Return predictions on the ds_idx dataset or dl using Test Time Augmentation

In practice, we get the predictions n times with the transforms of the training set and average those. The final predictions are (1-beta) multiplied by this average + beta multiplied by the predictions obtained with the transforms of the dataset. Set beta to None to get a tuple of the predictions and tta results. You can also use the maximum of all predictions instead of an average by setting use_max=True.
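The blending formula described above can be sketched on plain lists. tta_blend_sketch is a hypothetical helper that covers only the averaging case with a numeric beta; it does not model use_max or beta=None.

```python
def tta_blend_sketch(aug_preds, plain_preds, beta=0.25):
    "Hypothetical sketch: average the augmented predictions, then blend with the plain ones."
    n = len(aug_preds)
    # Average the n augmented passes element by element
    aug_mean = [sum(col) / n for col in zip(*aug_preds)]
    # final = (1-beta) * augmented average + beta * plain predictions
    return [(1 - beta) * a + beta * p for a, p in zip(aug_mean, plain_preds)]

aug_preds = [[0.2, 0.8],      # predictions from the first augmented pass
             [0.4, 0.6]]      # predictions from the second augmented pass
plain_preds = [0.5, 0.5]      # predictions with the dataset's own transforms
final = tta_blend_sketch(aug_preds, plain_preds, beta=0.25)
```

With these numbers the augmented average is [0.3, 0.7], so the blended result is approximately [0.35, 0.65].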

If you want to use new transforms, you can pass them with item_tfms and batch_tfms.