Callbacks
Events
Callbacks can occur at any of these times: after_create, before_fit, before_epoch, before_train, before_batch, after_pred, after_loss, before_backward, after_cancel_backward, after_backward, before_step, after_cancel_step, after_step, after_cancel_batch, after_batch, after_cancel_train, after_train, before_validate, after_cancel_validate, after_validate, after_cancel_epoch, after_epoch, after_cancel_fit, after_fit.
event
def event(
*args, **kwargs
):
All possible events as attributes to get tab-completion and typo-proofing
To ensure that you are referring to an event (that is, the name of one of the times when callbacks are called) that exists, and to get tab completion of event names, use event:
test_eq(event.before_step, 'before_step')
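The idea behind event can be sketched in plain Python without fastai: a namespace whose attributes are the event names, so a misspelled name fails loudly instead of silently never firing. `EVENT_NAMES` and `events` are hypothetical stand-ins written for illustration, not fastai's implementation; the name list is copied from the events listed in this section.

```python
from types import SimpleNamespace

# Hypothetical standalone copy of the event names listed in this section
EVENT_NAMES = (
    "after_create before_fit before_epoch before_train before_batch after_pred "
    "after_loss before_backward after_cancel_backward after_backward before_step "
    "after_cancel_step after_step after_cancel_batch after_batch after_cancel_train "
    "after_train before_validate after_cancel_validate after_validate "
    "after_cancel_epoch after_epoch after_cancel_fit after_fit"
).split()

# Each attribute's value is its own name, so events.before_step == "before_step"
events = SimpleNamespace(**{name: name for name in EVENT_NAMES})

# A typo raises AttributeError instead of silently referring to an event that never fires
try:
    events.befor_step
except AttributeError as err:
    print("caught typo:", err)
```

Because the attributes exist on a real object, an editor or notebook can also offer tab completion on `events.`.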
Callback
def Callback(
after_create:NoneType=None, before_fit:NoneType=None, before_epoch:NoneType=None, before_train:NoneType=None,
before_batch:NoneType=None, after_pred:NoneType=None, after_loss:NoneType=None, before_backward:NoneType=None,
after_cancel_backward:NoneType=None, after_backward:NoneType=None, before_step:NoneType=None,
after_cancel_step:NoneType=None, after_step:NoneType=None, after_cancel_batch:NoneType=None,
after_batch:NoneType=None, after_cancel_train:NoneType=None, after_train:NoneType=None,
before_validate:NoneType=None, after_cancel_validate:NoneType=None, after_validate:NoneType=None,
after_cancel_epoch:NoneType=None, after_epoch:NoneType=None, after_cancel_fit:NoneType=None,
after_fit:NoneType=None
):
Basic class handling tweaks of the training loop by changing a Learner in various events
The training loop is defined in Learner and consists of a minimal set of instructions. Looping through the data, we:
- compute the output of the model from the input
- calculate a loss between this output and the desired target
- compute the gradients of this loss with respect to all the model parameters
- update the parameters accordingly
- zero all the gradients
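Stripped of any tweaks, those five steps can be sketched in pure Python on a toy problem. The one-parameter model `y = w * x`, the data, and the learning rate are illustrative assumptions; a real Learner does the same with PyTorch tensors, autograd, and an optimizer.

```python
# Toy data (illustrative assumption): samples of y = 2 * x
data = [(x, 2.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]

w = 0.0      # the single model parameter
grad = 0.0   # gradient of the loss w.r.t. w (accumulated, PyTorch-style)
lr = 0.1     # learning rate

for epoch in range(50):
    for x, target in data:
        pred = w * x                      # 1. output of the model from the input
        loss = (pred - target) ** 2       # 2. loss between this output and the target
        grad += 2 * (pred - target) * x   # 3. gradient of the loss w.r.t. the parameter
        w -= lr * grad                    # 4. update the parameter
        grad = 0.0                        # 5. zero the gradient

print(round(w, 4))
```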
Any tweak of this training loop is defined in a Callback to avoid over-complicating the code of the training loop, and to make it easy to mix and match different techniques (since they’ll be defined in different callbacks). A callback can implement actions on the following events:
- after_create: called after the Learner is created.
- before_fit: called before starting training or inference, ideal for initial setup.
- before_epoch: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- before_train: called at the beginning of the training part of an epoch.
- before_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
- after_pred: called after computing the output of the model on the batch. It can be used to change that output before it’s fed to the loss.
- after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
- before_backward: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used).
- after_backward: called after the backward pass, but before the update of the parameters. Generally before_step should be used instead.
- before_step: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
- after_step: called after the step and before the gradients are zeroed.
- after_batch: called at the end of a batch, for any clean-up before the next one.
- after_train: called at the end of the training phase of an epoch.
- before_validate: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- after_validate: called at the end of the validation part of an epoch.
- after_epoch: called at the end of an epoch, for any clean-up before the next one.
- after_fit: called at the end of training, for final clean-up.
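To see where each of these events falls, here is a pure-Python sketch that records event names in the order a simplified loop would fire them. `trace_fit` and its default batch counts are hypothetical; it follows the descriptions above (before_backward, the backward pass and the step only happen in training mode) but ignores the after_cancel_* events, the real work done at each point, and after_create, which fires when the Learner is built rather than during fitting.

```python
def trace_fit(n_epoch=1, n_train=2, n_valid=1):
    "Hypothetical trace of the order in which callback events fire during a fit"
    log = []

    def one_batch(training):
        log.append("before_batch")
        log.append("after_pred")          # model output computed
        log.append("after_loss")          # loss computed
        if training:                      # backward pass and step only when training
            log.append("before_backward")
            log.append("after_backward")  # gradients computed
            log.append("before_step")
            log.append("after_step")      # parameters updated; gradients zeroed next
        log.append("after_batch")

    log.append("before_fit")
    for _ in range(n_epoch):
        log.append("before_epoch")
        log.append("before_train")
        for _ in range(n_train):
            one_batch(training=True)
        log.append("after_train")
        log.append("before_validate")
        for _ in range(n_valid):
            one_batch(training=False)
        log.append("after_validate")
        log.append("after_epoch")
    log.append("after_fit")
    return log

for name in trace_fit():
    print(name)
```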
Callback.__call__
def __call__(
event_name
):
Call self.{event_name} if it’s defined
One way to define callbacks is through subclassing:
class _T(Callback):
    def call_me(self): return "maybe"
test_eq(_T()("call_me"), "maybe")
Another way is by passing the callback function to the constructor:
def cb(self): return "maybe"
_t = Callback(before_fit=cb)
test_eq(_t(event.before_fit), "maybe")
Callbacks provide a shortcut to avoid having to write self.learn.bla for any bla attribute we seek; instead, just write self.bla. This only works for getting attributes, not for setting them.
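Both behaviors, calling an event method only if it is defined and falling back to the learner for attribute reads, can be mimicked in a few lines of plain Python. `MiniCallback` and `FakeLearner` are hypothetical names used for illustration; fastai's Callback handles more cases than this, so treat it as a sketch of the idea only.

```python
from types import MethodType

class MiniCallback:
    "Hypothetical minimal callback: event dispatch plus read-only delegation to `learn`"
    learn = None

    def __init__(self, **event_funcs):
        # Bind functions passed to the constructor (e.g. before_fit=cb) as methods
        for name, func in event_funcs.items():
            setattr(self, name, MethodType(func, self))

    def __call__(self, event_name):
        # Call self.<event_name> only if it exists; otherwise do nothing
        method = getattr(self, event_name, None)
        return method() if callable(method) else None

    def __getattr__(self, name):
        # Only reached when normal lookup fails: fall back to the learner
        if name != "learn" and self.learn is not None:
            return getattr(self.learn, name)
        raise AttributeError(name)

class FakeLearner:
    "Hypothetical learner with a single attribute"
    def __init__(self, a): self.a = a

def cb(self): return "maybe"

t = MiniCallback(before_fit=cb)
print(t("before_fit"))   # event defined: calls cb
print(t("after_fit"))    # event not defined: returns None

t.learn = FakeLearner(1)
print(t.a)               # not found on the callback, so read from t.learn
```

Because `__getattr__` is only consulted after normal lookup fails, attributes that really live on the callback always win; writes are not intercepted at all, which is why `self.bla = ...` lands on the callback.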
mk_class('TstLearner', 'a')
class TstCallback(Callback):
    def batch_begin(self): print(self.a)
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
test_stdout(lambda: cb('batch_begin'), "1")
If you want to change the value of an attribute, you have to use self.learn.bla, not self.bla. In the example below, self.a += 1 creates an a attribute of 2 in the callback instead of setting the a of the learner to 2. It also issues a warning that something is probably wrong:
learn.a  # 1
class TstCallback(Callback):
    def batch_begin(self): self.a += 1
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.a, 2)
test_eq(cb.learn.a, 1)
UserWarning: You are shadowing an attribute (a) that exists in the learner. Use `self.learn.a` to avoid this
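A warning like this one can be produced with a `__setattr__` override that checks whether the learner already has an attribute of the same name before storing the value on the callback. `ShadowCheckCallback` and `FakeLearner` are hypothetical classes for illustration, and the message text is modeled on the warning shown here.

```python
import warnings

class FakeLearner:
    "Hypothetical learner with a single attribute"
    def __init__(self, a): self.a = a

class ShadowCheckCallback:
    "Hypothetical callback that warns when a write shadows a learner attribute"
    learn = None

    def __getattr__(self, name):
        # Reads fall back to the learner (only reached if normal lookup fails)
        if self.learn is not None:
            return getattr(self.learn, name)
        raise AttributeError(name)

    def __setattr__(self, name, val):
        # Writes always land on the callback, so warn if that hides a learner attribute
        if self.learn is not None and hasattr(self.learn, name):
            warnings.warn(f"You are shadowing an attribute ({name}) that exists in the learner. "
                          f"Use `self.learn.{name}` to avoid this")
        super().__setattr__(name, val)

cb = ShadowCheckCallback()
cb.learn = FakeLearner(1)        # 'learn' is not a learner attribute, so no warning

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cb.a += 1                    # reads learn.a (1), then writes cb.a = 2 and warns

print(cb.a, cb.learn.a, len(caught))
```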
A proper version needs to write self.learn.a = self.a + 1:
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.learn.a, 2)
Callback.name
def name(
):
Name of the Callback, snake-cased and with ‘Callback’ removed
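As the tests that follow show, the resulting name is in snake_case (ComplicatedNameCallback becomes complicated_name). One way to compute it is to strip a trailing "Callback" from the class name and then convert CamelCase to snake_case with two regular expressions. `callback_name` is a hypothetical helper written for this sketch, not fastai's own code.

```python
import re

def callback_name(cls_name, suffix="Callback"):
    "Hypothetical helper: drop a trailing `suffix`, then convert CamelCase to snake_case"
    base = re.sub(rf"{suffix}$", "", cls_name) or suffix.lower()  # bare 'Callback' -> 'callback'
    s = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", base)               # 'ComplicatedName' -> 'Complicated_Name'
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s).lower()      # split remaining humps, lowercase

print(callback_name("TstCallback"))
print(callback_name("ComplicatedNameCallback"))
print(callback_name("TrainEvalCallback"))
```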
test_eq(TstCallback().name, 'tst')
class ComplicatedNameCallback(Callback): pass
test_eq(ComplicatedNameCallback().name, 'complicated_name')
TrainEvalCallback
def TrainEvalCallback(
after_create:NoneType=None, before_fit:NoneType=None, before_epoch:NoneType=None, before_train:NoneType=None,
before_batch:NoneType=None, after_pred:NoneType=None, after_loss:NoneType=None, before_backward:NoneType=None,
after_cancel_backward:NoneType=None, after_backward:NoneType=None, before_step:NoneType=None,
after_cancel_step:NoneType=None, after_step:NoneType=None, after_cancel_batch:NoneType=None,
after_batch:NoneType=None, after_cancel_train:NoneType=None, after_train:NoneType=None,
before_validate:NoneType=None, after_cancel_validate:NoneType=None, after_validate:NoneType=None,
after_cancel_epoch:NoneType=None, after_epoch:NoneType=None, after_cancel_fit:NoneType=None,
after_fit:NoneType=None
):
Callback that tracks the number of iterations done and properly sets training/eval mode
This Callback is automatically added in every Learner at initialization.
Attributes available to callbacks
When writing a callback, the following attributes of Learner are available:
- model: the model used for training/validation
- dls: the underlying DataLoaders
- loss_func: the loss function used
- opt: the optimizer used to update the model parameters
- opt_func: the function used to create the optimizer
- cbs: the list containing all Callbacks
- dl: current DataLoader used for iteration
- x/xb: last input drawn from self.dl (potentially modified by callbacks). xb is always a tuple (potentially with one element) and x is detuplified. You can only assign to xb.
- y/yb: last target drawn from self.dl (potentially modified by callbacks). yb is always a tuple (potentially with one element) and y is detuplified. You can only assign to yb.
- pred: last predictions from self.model (potentially modified by callbacks)
- loss_grad: last computed loss (potentially modified by callbacks)
- loss: clone of loss_grad used for logging
- n_epoch: the number of epochs in this training
- n_iter: the number of iterations in the current self.dl
- epoch: the current epoch index (from 0 to n_epoch-1)
- iter: the current iteration index in self.dl (from 0 to n_iter-1)
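The relationship between xb and x (and likewise yb and y) can be sketched with a read-only property: x is recomputed from xb on every access, which is why you assign to xb rather than to x. `BatchHolder` and its simplified detuplify logic are assumptions made for illustration.

```python
class BatchHolder:
    "Hypothetical stand-in for the Learner's xb/x attributes"
    def __init__(self, xb):
        self.xb = xb                # always a tuple, possibly with one element

    @property
    def x(self):
        # Simplified 'detuplify': unwrap a one-element tuple, otherwise keep the tuple
        return self.xb[0] if len(self.xb) == 1 else self.xb

b = BatchHolder(("image",))
print(b.x)                          # single input: unwrapped
b.xb = ("image", "metadata")        # a callback changes the input by assigning to xb
print(b.x)                          # two inputs: still a tuple

try:
    b.x = "other"                   # the property has no setter, so this fails
except AttributeError:
    print("x is read-only; assign to xb instead")
```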
The following attributes are added by TrainEvalCallback and should be available unless you went out of your way to remove that callback:
- train_iter: the number of training iterations done since the beginning of this training
- pct_train: from 0. to 1., the fraction of training iterations completed
- training: flag to indicate if we’re in training mode or not
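The bookkeeping behind train_iter and pct_train can be sketched as follows: pct_train is resynchronized to epoch / n_epoch at the start of each training phase and grows by 1 / (n_iter * n_epoch) after every training batch, so it runs from 0 to 1 over the whole fit. `ProgressTracker` and its event-named methods are hypothetical, and this update rule is an assumption consistent with the description above rather than a copy of TrainEvalCallback.

```python
class ProgressTracker:
    "Hypothetical tracker for train_iter / pct_train, driven by event-named methods"
    def __init__(self, n_epoch, n_iter):
        self.n_epoch, self.n_iter = n_epoch, n_iter

    def before_fit(self):
        self.train_iter, self.pct_train = 0, 0.0

    def before_train(self, epoch):
        self.training = True
        self.pct_train = epoch / self.n_epoch      # resync at the start of each epoch

    def before_validate(self):
        self.training = False

    def after_batch(self):
        if not self.training:
            return                                 # validation batches are not counted
        self.pct_train += 1 / (self.n_iter * self.n_epoch)
        self.train_iter += 1

t = ProgressTracker(n_epoch=2, n_iter=4)
t.before_fit()
for epoch in range(t.n_epoch):
    t.before_train(epoch)
    for _ in range(t.n_iter):
        t.after_batch()
    t.before_validate()
    for _ in range(3):
        t.after_batch()                            # 3 validation batches, ignored
print(t.train_iter, t.pct_train)
```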
The following attribute is added by Recorder and should be available unless you went out of your way to remove that callback:
smooth_loss: an exponentially-averaged version of the training loss
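An exponentially-averaged loss is commonly computed as a running average that is bias-corrected for starting at zero. The smoothing factor `beta = 0.98` and the debiasing step below are assumptions for this sketch, not necessarily Recorder's exact settings; `smooth_losses` is a hypothetical helper.

```python
def smooth_losses(losses, beta=0.98):
    "Hypothetical helper: exponentially-weighted moving average of `losses`, debiased"
    avg, out = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss   # running average, biased toward 0 early on
        out.append(avg / (1 - beta ** i))      # divide out the bias from starting at 0
    return out

print(smooth_losses([2.0, 2.0, 2.0]))          # a constant loss stays constant (up to rounding)
print(smooth_losses([3.0, 1.0, 2.0, 0.5]))
```

Without the division by `1 - beta ** i`, the first values would be close to 0 regardless of the actual loss.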
Callbacks control flow
Sometimes we want to skip some of the steps of the training loop: in gradient accumulation, for instance, we don’t always want to step the optimizer and zero the gradients. During an LR finder test, we don’t want to run the validation phase of an epoch. And if we’re training with an early-stopping strategy, we want to be able to interrupt the training loop completely.
This is made possible by raising specific exceptions the training loop will look for (and properly catch).
CancelStepException
def CancelStepException(
*args, **kwargs
):
Skip stepping the optimizer
CancelBatchException
def CancelBatchException(
*args, **kwargs
):
Skip the rest of this batch and go to after_batch
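Skipping the rest of a batch is one way to implement the gradient-accumulation case mentioned earlier: if both the parameter update and the zeroing are skipped, gradients keep accumulating until a batch is allowed through. In this pure-Python sketch, `SkipRestOfBatch` is a hypothetical stand-in for CancelBatchException, and `AccumulateEvery` and `train_with_accumulation` are hypothetical names for illustration.

```python
class SkipRestOfBatch(Exception):
    "Hypothetical stand-in for CancelBatchException"

class AccumulateEvery:
    "Hypothetical hook: let the update through only every `n_acc` batches"
    def __init__(self, n_acc):
        self.n_acc, self.count = n_acc, 0

    def before_step(self):
        self.count += 1
        if self.count < self.n_acc:
            raise SkipRestOfBatch()   # skip the update and the zeroing for this batch
        self.count = 0

def train_with_accumulation(batch_grads, hook, lr=0.1):
    w, grad, steps = 0.0, 0.0, 0
    for g in batch_grads:
        grad += g                     # backward pass: gradients accumulate
        try:
            hook.before_step()
            w -= lr * grad            # step
            steps += 1
            grad = 0.0                # zero the gradients only after a real step
        except SkipRestOfBatch:
            pass                      # rest of the batch skipped; gradient kept
    return w, grad, steps

w, leftover, steps = train_with_accumulation([1.0] * 5, AccumulateEvery(2))
print(w, leftover, steps)
```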
CancelBackwardException
def CancelBackwardException(
*args, **kwargs
):
Skip the backward pass and go to after_backward
CancelTrainException
def CancelTrainException(
*args, **kwargs
):
Skip the rest of the training part of the epoch and go to after_train
CancelValidException
def CancelValidException(
*args, **kwargs
):
Skip the rest of the validation part of the epoch and go to after_validate
CancelEpochException
def CancelEpochException(
*args, **kwargs
):
Skip the rest of this epoch and go to after_epoch
CancelFitException
def CancelFitException(
*args, **kwargs
):
Interrupt training and go to after_fit
You can detect that one of those exceptions occurred, and add code that executes right after it, with the following events:
- after_cancel_backward: reached immediately after a CancelBackwardException before proceeding to after_backward
- after_cancel_step: reached immediately after a CancelStepException before proceeding to after_step
- after_cancel_batch: reached immediately after a CancelBatchException before proceeding to after_batch
- after_cancel_train: reached immediately after a CancelTrainException before proceeding to after_train
- after_cancel_validate: reached immediately after a CancelValidException before proceeding to after_validate
- after_cancel_epoch: reached immediately after a CancelEpochException before proceeding to after_epoch
- after_cancel_fit: reached immediately after a CancelFitException before proceeding to after_fit
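The before / after_cancel / after sequence for each stage follows a simple try/except pattern, sketched here in pure Python. `with_events`, `CancelBatch`, and the `log` list are hypothetical stand-ins for illustration, not fastai's code.

```python
class CancelBatch(Exception):
    "Hypothetical stand-in for CancelBatchException"

def with_events(body, event_type, cancel_exc, log):
    "Run `body` between before_/after_ events, firing after_cancel_ if `cancel_exc` is raised"
    try:
        log.append(f"before_{event_type}")
        body()
    except cancel_exc:
        log.append(f"after_cancel_{event_type}")
    log.append(f"after_{event_type}")     # reached whether or not the stage was cancelled

def cancelled_batch():
    raise CancelBatch()                   # e.g. raised by a callback in the middle of a batch

log = []
with_events(lambda: log.append("...work..."), "batch", CancelBatch, log)
with_events(cancelled_batch, "batch", CancelBatch, log)
print(log)
```

Only the matching exception is caught, so a cancellation meant for an outer stage (say, the whole fit) propagates past the batch level until it reaches the stage that handles it.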
GatherPredsCallback
def GatherPredsCallback(
with_input:bool=False, # Whether to return inputs
with_loss:bool=False, # Whether to return losses
save_preds:Path=None, # Path to save predictions
save_targs:Path=None, # Path to save targets
with_preds:bool=True, # Whether to return predictions
with_targs:bool=True, # Whether to return targets
concat_dim:int=0, # Dimension to concatenate returned tensors
pickle_protocol:int=2, # Pickle protocol used to save predictions and targets
):
Callback that returns all predictions and targets, optionally with_input or with_loss
FetchPredsCallback
def FetchPredsCallback(
ds_idx:int=1, # Index of dataset, 0 for train, 1 for valid, used if `dl` is not present
dl:DataLoader=None, # [`DataLoader`](https://docs.fast.ai/data.load.html#dataloader) used for fetching [`Learner`](https://docs.fast.ai/learner.html#learner) predictions
with_input:bool=False, # Whether to return inputs in [`GatherPredsCallback`](https://docs.fast.ai/callback.core.html#gatherpredscallback)
with_decoded:bool=False, # Whether to return decoded predictions
cbs:Callback | MutableSequence=None, # [`Callback`](https://docs.fast.ai/callback.core.html#callback) to temporarily remove from [`Learner`](https://docs.fast.ai/learner.html#learner)
reorder:bool=True, # Whether to sort prediction results
):
A callback to fetch predictions during the training loop