Implement callbacks using hooks

Hook callbacks

This module provides standalone classes (Hook and Hooks) for registering PyTorch hooks, a callback (HookCallback) that registers them and deregisters them automatically, and some pre-defined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass.

We'll start by looking at the pre-defined hook ActivationStats, then we'll see how to create our own.

class ActivationStats

ActivationStats(`learn`:Learner, `modules`:Sequence[Module]=`None`, `do_remove`:bool=`True`) :: HookCallback

Callback that records the mean and std of activations.

ActivationStats saves statistics of the layer activations (their mean and standard deviation) in self.stats for all modules passed to it. If no modules are passed, it hooks every module of the model that has a weight. For instance:

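from fastai.vision import *                        # untar_data, URLs, ImageDataBunch, Learner, simple_cnn
from fastai.callbacks.hooks import ActivationStats

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
#learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)
learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)
learn.fit(1)                                       # one epoch; the stats are recorded while training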
Total time: 00:02

epoch  train_loss  valid_loss
1      0.112384    0.083544

The saved stats are a FloatTensor of shape (2, num_modules, num_batches), whose first axis holds (mean, std). In this example there are 193 batches and 3 hooked modules:

(193, 3)
torch.Size([2, 3, 193])

So stats[1][-2] gives the standard deviation (axis 0 == 1) of the second-to-last hooked layer (axis 1 == -2) for each batch (axis 2).
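As a minimal sketch of that indexing, the snippet below uses a random tensor with the same (2, num_modules, num_batches) layout in place of the real recorded stats; the variable names are ours, not fastai's.

```python
import torch

# Hypothetical stand-in for the recorded stats, with the same layout:
# axis 0 = (mean, std), axis 1 = hooked module, axis 2 = batch.
num_modules, num_batches = 3, 193
stats = torch.rand(2, num_modules, num_batches)

# Standard deviation (axis 0 == 1) of the second-to-last module (axis 1 == -2),
# one value per batch (all of axis 2).
std_second_last = stats[1][-2]
print(std_second_last.shape)  # torch.Size([193])

# To plot it (assuming matplotlib is installed):
# import matplotlib.pyplot as plt; plt.plot(std_second_last.numpy())
```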


Internal implementation


hook(`m`:Module, `i`:Tensors, `o`:Tensors) → Tuple[Rank0Tensor, Rank0Tensor]

Take the mean and std of o.
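The hook itself is simple. Below is a plain-PyTorch sketch (not fastai's code) of a mean/std hook; the names mean_std and store_stats are hypothetical. In fastai the Hook wrapper stores whatever the hook function returns, but a raw PyTorch forward hook that returns a non-None value replaces the module's output, so here we append the result to a list instead.

```python
import torch
import torch.nn as nn

def mean_std(m, i, o):
    "Mean and standard deviation of the module output o, as 0-dim tensors."
    return o.mean(), o.std()

# A raw forward hook must return None to leave the module's output unchanged,
# so this wrapper stores the statistics rather than returning them.
stored = []
def store_stats(m, i, o):
    stored.append(mean_std(m, i, o.detach()))

layer = nn.Linear(4, 8)
handle = layer.register_forward_hook(store_stats)
out = layer(torch.randn(16, 4))  # the forward pass triggers store_stats once
handle.remove()                  # deregister once we're done
```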

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.



Initialize stats.


on_batch_end(`train`, `kwargs`)

Take the stored results and put them in self.stats.



Polish the final result.
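To make the three callback steps concrete, here is a plain-PyTorch sketch of the same lifecycle: start with an empty list, append one (mean, std) pair per hooked module at the end of each batch, then turn the list into a tensor and permute it to the (2, num_modules, num_batches) layout described earlier. The model, the five fake batches and the variable names are assumptions for illustration; fastai drives these steps from its Callback methods rather than a hand-written loop, and its on_batch_end also receives a train flag that this sketch does not model.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
modules = [m for m in model if hasattr(m, "weight")]  # the two Linear layers

latest = {}  # most recent (mean, std) for each hooked module
def hook(m, i, o):
    latest[m] = (o.mean().item(), o.std().item())  # returns None: output untouched

handles = [m.register_forward_hook(hook) for m in modules]

stats = []                       # "Initialize stats" at the start of training
for _ in range(5):               # five fake training batches
    model(torch.randn(16, 4))
    stats.append([latest[m] for m in modules])  # end of batch: keep this batch's results

for h in handles:                # deregister the hooks at the end of training
    h.remove()

# "Polish the final result": (batches, modules, 2) -> (2, modules, batches)
stats = torch.tensor(stats).permute(2, 1, 0)
print(stats.shape)  # torch.Size([2, 2, 5])
```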

class Hook

Hook(`m`:Module, `hook_func`:HookFunc, `is_forward`:bool=`True`, `detach`:bool=`True`)

Create a hook on m with hook_func.

Registers a PyTorch hook, which you must deregister manually. Your hook_func is called automatically each time the forward pass (or the backward pass, if is_forward is False) of module m runs, and its return value is placed in self.stored.



Remove the hook from the model.

Deregister the hook, unless it has already been removed.
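A simplified, hypothetical SimpleHook class shows the same idea in plain PyTorch: register a function on the module, keep its result in stored, and make remove safe to call twice. It is a sketch under our own assumptions (for instance, it uses register_full_backward_hook for the backward case), not fastai's Hook.

```python
import torch
import torch.nn as nn

def _detach(x):
    "Detach a tensor, or every tensor inside a tuple/list (None entries are kept)."
    if isinstance(x, torch.Tensor):
        return x.detach()
    if isinstance(x, (tuple, list)):
        return type(x)(_detach(t) for t in x)
    return x

class SimpleHook:
    "Hypothetical, simplified Hook: registers hook_func on m and keeps its result in self.stored."
    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored, self.removed = hook_func, detach, None, False
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self.hook_fn)

    def hook_fn(self, module, inp, out):
        if self.detach:
            inp, out = _detach(inp), _detach(out)
        self.stored = self.hook_func(module, inp, out)  # returns None: PyTorch keeps the real output

    def remove(self):
        "Deregister the hook, unless it has already been removed."
        if not self.removed:
            self.handle.remove()
            self.removed = True

layer = nn.Linear(3, 5)
h = SimpleHook(layer, lambda m, i, o: o.shape)
layer(torch.randn(2, 3))
print(h.stored)  # torch.Size([2, 5])
h.remove()
h.remove()  # safe: the second call does nothing
```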

class Hooks

Hooks(`ms`:ModuleList, `hook_func`:HookFunc, `is_forward`:bool=`True`, `detach`:bool=`True`)

Create several hooks on the modules in ms with hook_func.

Acts as a collection (e.g. len(hooks) and hooks[i]) and an iterable (e.g. for hook in hooks) over a group of hooks, one for each module in ms, which can all be removed together. Use stored to get the results of all hooks. hook_func and is_forward behave the same as in Hook. See the source code of HookCallback for a simple example.



Remove the hooks from the model.

Deregister all hooks created by this class, if not previously called.
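Below is a hypothetical sketch of a Hooks-like group in plain PyTorch: it holds one hook object per module, supports len, indexing and iteration, exposes all results through stored, and removes every hook at once. For brevity it only handles forward hooks on modules that return a single tensor; SimpleHooks and _OneHook are our names, not fastai's.

```python
import torch
import torch.nn as nn

class _OneHook:
    "One forward hook whose latest result lives in .stored."
    def __init__(self, m, hook_func):
        self.hook_func, self.stored = hook_func, None
        self.handle = m.register_forward_hook(self._fn)

    def _fn(self, m, inp, out):
        self.stored = self.hook_func(m, inp, out.detach())  # assumes a single-tensor output

class SimpleHooks:
    "Hypothetical, simplified Hooks: one _OneHook per module, removable as a group."
    def __init__(self, ms, hook_func):
        self.hooks = [_OneHook(m, hook_func) for m in ms]

    def __len__(self): return len(self.hooks)
    def __getitem__(self, k): return self.hooks[k]
    def __iter__(self): return iter(self.hooks)

    @property
    def stored(self):
        "Results of all hooks, in module order."
        return [h.stored for h in self.hooks]

    def remove(self):
        for h in self.hooks:
            h.handle.remove()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooks = SimpleHooks(list(model), lambda m, i, o: o.shape)
model(torch.randn(10, 4))
print(len(hooks), hooks.stored)
hooks.remove()
```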

Convenience functions for hooks


hook_output(`module`:Module, `detach`:bool=`True`, `grad`:bool=`False`) → Hook

Return a Hook that stores activations of module in self.stored

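Function that creates a Hook for module that simply stores the output of the layer. With grad=True, the hook is registered on the backward pass instead, so it stores the gradients of the layer's output.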
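In plain PyTorch, storing a layer's activations and, on the backward pass, the gradients of its output can be sketched with one forward and one backward hook. This is an illustration rather than fastai's code: the model is made up, and we keep only the first gradient tensor the backward hook receives (the layer here has a single output).

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
layer = model[2]
stored = {}

def save_output(m, inp, out):          # forward hook: the layer's activations
    stored["activations"] = out.detach()

def save_grad(m, grad_in, grad_out):   # backward hook: gradients w.r.t. the layer's output
    stored["gradients"] = grad_out[0].detach()

fwd = layer.register_forward_hook(save_output)
bwd = layer.register_full_backward_hook(save_grad)
loss = model(torch.randn(5, 4)).sum()
loss.backward()                        # d(sum)/d(output) is 1 for every element
fwd.remove()
bwd.remove()
```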


hook_outputs(`modules`:ModuleList, `detach`:bool=`True`, `grad`:bool=`False`) → Hooks

Return Hooks that store activations of all modules in self.stored

Function that creates a Hook for each of the passed modules that simply stores the output of that layer. For example, the (slightly simplified) source code of model_sizes is:

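def model_sizes(m, size):
    hooks = hook_outputs(m)                     # register hooks on m's layers before the forward pass
    m(torch.zeros(1, in_channels(m), *size))    # dummy input fills each hook's stored
    sizes = [o.stored.shape for o in hooks]
    hooks.remove()                              # deregister so the hooks don't stay on the model
    return sizes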


model_sizes(`m`:Module, `size`:tuple=`(64, 64)`) → Tuple[Sizes, Tensor, Hooks]

Pass a dummy input through the model m to get the various sizes of activations.


model_summary(`m`:ModuleList, `n`:int=`70`)

Print a summary of m using an output text width of n characters.
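A rough, hypothetical sketch of a model_summary-like table in plain PyTorch: a forward hook on each direct child records its type, output shape and parameter count during one dummy forward pass, and the rows are padded to a total width of n characters. The column layout, the in_ch argument and the function name are assumptions; fastai's actual summary may look different.

```python
import torch
import torch.nn as nn

def summary_sketch(m, size=(64, 64), in_ch=3, n=70):
    "Hypothetical table of each child's type, output shape and parameter count."
    rows = []
    def hook(mod, inp, out):
        n_params = sum(p.numel() for p in mod.parameters())
        rows.append((type(mod).__name__, str(list(out.shape)), n_params))
    handles = [c.register_forward_hook(hook) for c in m.children()]
    try:
        with torch.no_grad():
            m.eval()(torch.zeros(1, in_ch, *size))  # one dummy forward pass fills rows
    finally:
        for h in handles:
            h.remove()
    w = n // 3            # width of the first two columns
    last = n - 2 * w      # the last column takes the remaining characters
    lines = ["=" * n, f"{'Layer (type)':<{w}}{'Output Shape':<{w}}{'Param #':>{last}}", "=" * n]
    lines += [f"{name:<{w}}{shape:<{w}}{count:>{last},}" for name, shape, count in rows]
    lines.append("=" * n)
    return "\n".join(lines)

body = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 8, 3, padding=1))
print(summary_sketch(body))
```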


num_features_model(`m`:Module) → int

Return the number of output features of m.

model_sizes is useful for getting the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet), but these sizes depend on the size of the input. It calculates them by passing a minimal dummy tensor of the given size through the model.
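The idea of passing a minimal input through the model can be sketched in plain PyTorch as follows. model_sizes_sketch and num_features_sketch are hypothetical names, and unlike fastai (which infers the input channels from the model) this sketch takes in_ch as an argument; the number of features is read from axis 1 of the last recorded shape.

```python
import torch
import torch.nn as nn

def model_sizes_sketch(m, size=(64, 64), in_ch=3):
    "Hypothetical: output shape of each direct child of m for one all-zeros input."
    sizes = []
    def hook(mod, inp, out):
        sizes.append(out.shape)  # record only; returning None leaves outputs untouched
    handles = [c.register_forward_hook(hook) for c in m.children()]
    try:
        with torch.no_grad():
            m.eval()(torch.zeros(1, in_ch, *size))
    finally:
        for h in handles:
            h.remove()
    return sizes

def num_features_sketch(m, size=(64, 64), in_ch=3):
    "Hypothetical: number of channels (axis 1) in the last child's output."
    return model_sizes_sketch(m, size, in_ch)[-1][1]

body = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1))
print(model_sizes_sketch(body))
print(num_features_sketch(body))  # 32
```

Trying a different size shows why the input size matters: with size=(32, 32) the last shape becomes (1, 32, 8, 8) instead of (1, 32, 16, 16).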


dummy_batch(`m`:Module, `size`:tuple=`(64, 64)`) → Tensor

Create a dummy batch to go through m with size.


dummy_eval(`m`:Module, `size`:tuple=`(64, 64)`)

Pass a dummy_batch of the given size through m in evaluation mode.
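A plain-PyTorch sketch of what dummy_batch and dummy_eval describe. The first_in_channels helper, the uniform fill in [-1, 1] and the torch.no_grad context are our assumptions, not necessarily what fastai does; the batch is created on the same device and with the same dtype as the model's parameters so the forward pass also works for a model on a GPU.

```python
import torch
import torch.nn as nn

def first_in_channels(m):
    "Hypothetical helper: in_channels of the first Conv2d found in m."
    for layer in m.modules():
        if isinstance(layer, nn.Conv2d):
            return layer.in_channels
    raise ValueError("no Conv2d layer found")

def dummy_batch_sketch(m, size=(64, 64)):
    "A batch of one random image, on the same device and dtype as m's parameters."
    p = next(m.parameters())
    shape = (1, first_in_channels(m), *size)
    return torch.empty(shape, device=p.device, dtype=p.dtype).uniform_(-1.0, 1.0)

def dummy_eval_sketch(m, size=(64, 64)):
    "Run the dummy batch through m in evaluation mode, without tracking gradients."
    with torch.no_grad():
        return m.eval()(dummy_batch_sketch(m, size))

body = nn.Sequential(nn.Conv2d(1, 4, 3, stride=2, padding=1), nn.ReLU())
print(dummy_eval_sketch(body, size=(28, 28)).shape)  # torch.Size([1, 4, 14, 14])
```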

class HookCallback

HookCallback(`learn`:Learner, `modules`:Sequence[Module]=`None`, `do_remove`:bool=`True`) :: LearnerCallback

Callback that can be used to register hooks on modules. Implement the corresponding function in self.hook.

For every module in modules, it uses a callback to automatically register the method self.hook (which you must define in a subclass) as a forward hook. This method must have the signature:

def hook(self, m:nn.Module, input:Tensors, output:Tensors)

If do_remove is True (the default), the hooks are automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
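To illustrate the subclassing pattern without a Learner, here is a hypothetical MiniHookCallback in plain PyTorch: the base class registers self.hook on the chosen modules (by default, those with a weight) when training begins and removes the hooks at the end if do_remove is set, while a subclass only implements hook. The class names and calling on_train_begin and on_train_end by hand are assumptions for the sketch; in fastai, the Learner calls these methods for you.

```python
import torch
import torch.nn as nn

class MiniHookCallback:
    "Hypothetical stand-in for HookCallback: registers self.hook on modules for one training run."
    def __init__(self, model, modules=None, do_remove=True):
        self.model, self.modules, self.do_remove = model, modules, do_remove

    def on_train_begin(self):
        if not self.modules:  # default: every submodule that has a weight
            self.modules = [m for m in self.model.modules() if hasattr(m, "weight")]
        self.stored = [None] * len(self.modules)
        self.handles = [m.register_forward_hook(self._wrap(k)) for k, m in enumerate(self.modules)]

    def _wrap(self, k):
        def fn(m, i, o):
            self.stored[k] = self.hook(m, i, o.detach())  # store; return None so output is untouched
        return fn

    def on_train_end(self):
        if self.do_remove:
            for h in self.handles:
                h.remove()

    def hook(self, m, i, o):
        raise NotImplementedError("subclasses must define hook")

class OutputShapes(MiniHookCallback):
    "Example subclass: only the hook itself needs to be written."
    def hook(self, m, i, o):
        return tuple(o.shape)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cb = OutputShapes(model)
cb.on_train_begin()
model(torch.randn(3, 4))
cb.on_train_end()
print(cb.stored)  # [(3, 8), (3, 2)]
```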

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.



Register the Hooks on self.modules.



Remove the Hooks.