Implement callbacks using hooks

Hook callbacks

This module provides both standalone classes (Hook and Hooks) and a callback (HookCallback) for registering and automatically deregistering PyTorch hooks, along with some pre-defined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass.

We'll start by looking at the pre-defined hook ActivationStats, then we'll see how to create our own.

class ActivationStats

ActivationStats(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: HookCallback

Callback that records the activations.

ActivationStats saves the mean and standard deviation of the layer activations in self.stats for all modules passed to it. By default (modules=None) it will record them for all modules. For instance:

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)
learn.fit(1)  # train for one epoch so the callback has activations to record
Total time: 00:13
epoch  train loss  valid loss
0      0.077055    0.049985    (00:13)

The saved stats are a FloatTensor of shape (2,num_modules,num_batches). The first axis is (mean,stdev). For this run there are 194 batches and 44 modules:

(194, 44)
torch.Size([2, 44, 194])

So stats[1][-5] holds the standard deviation (axis0==1) of the fifth-from-last layer (axis1==-5) for each batch (axis2).
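To make the layout of the stats tensor concrete, here is a minimal sketch in plain PyTorch. It is not the ActivationStats implementation: the toy model, the `record` function, and the batch count are all hypothetical, chosen only to reproduce the (2,num_modules,num_batches) shape described above.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; not the resnet18 learner used in the example above.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
modules = list(model.children())

per_batch = []   # one entry per batch: [(mean, std), ...] for each module
current = []     # (mean, std) pairs collected during the current batch

def record(module, inputs, output):
    # Store the mean and standard deviation of this module's output.
    current.append((output.mean().item(), output.std().item()))

handles = [m.register_forward_hook(record) for m in modules]

num_batches = 5
with torch.no_grad():
    for _ in range(num_batches):
        current.clear()
        model(torch.randn(32, 8))
        per_batch.append(list(current))

# Deregister the hooks once we are done recording.
for h in handles:
    h.remove()

# (num_batches, num_modules, 2) -> (2, num_modules, num_batches)
stats = torch.tensor(per_batch).permute(2, 1, 0)
print(stats.shape)              # torch.Size([2, 3, 5])

# Standard deviation (axis 0 == 1) of the last module (axis 1 == -1), per batch.
std_last_layer = stats[1][-1]
```

Indexing the first axis with 0 instead of 1 would give the per-batch means in the same way.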


class Hook

Hook(m:Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)

Create a hook.

Registers a PyTorch hook on module m and lets you deregister it manually. Your hook_func will be called automatically whenever the forward or backward pass (depending on is_forward) of m is run, and the return value of that function is placed in self.stored.
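The behavior described here can be sketched with PyTorch's own hook registration. The class below is a hypothetical stand-in (`SimpleHook` is our name, not fastai's code) and assumes a PyTorch version that provides `register_full_backward_hook`; the usage only exercises the forward case.

```python
import torch
import torch.nn as nn

class SimpleHook:
    "Hypothetical Hook-like wrapper; a sketch, not fastai's implementation."
    def __init__(self, m, hook_func, is_forward=True):
        self.hook_func, self.stored, self.removed = hook_func, None, False
        # Pick the forward or backward registration method of the module.
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self._run)

    def _run(self, module, inp, out):
        # Keep whatever hook_func returns so it can be read after the pass.
        self.stored = self.hook_func(module, inp, out)

    def remove(self):
        # Deregister only once, even if remove() is called again.
        if not self.removed:
            self.handle.remove()
            self.removed = True

layer = nn.Linear(4, 2)
hook = SimpleHook(layer, lambda m, i, o: o.detach().shape)
layer(torch.zeros(3, 4))
print(hook.stored)        # torch.Size([3, 2])

hook.remove()
layer(torch.zeros(7, 4))  # the hook no longer fires, so stored is unchanged
```

Note that `_run` returns nothing: a PyTorch forward hook that returns a value replaces the module's output, which is not what we want here.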



Hook.remove() deregisters the hook, if it has not already been removed.

class Hooks

Hooks(ms:ModuleList, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)

Create several hooks.

Acts as a Collection (i.e. len(hooks) and hooks[i]) and an Iterator (i.e. for hook in hooks) of a group of hooks, one for each module in ms, with the ability to remove all as a group. Use stored to get all hook results. hook_func and is_forward behave the same as in Hook. See the source code for HookCallback for a simple example.
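As an illustration of the collection behavior described above, here is a small sketch in plain PyTorch. `OneHook` and `HookGroup` are hypothetical names, and the sketch only handles forward hooks; it is not the Hooks source.

```python
import torch
import torch.nn as nn

class OneHook:
    "Hypothetical single forward hook that stores hook_func's result."
    def __init__(self, m, hook_func):
        self.hook_func, self.stored = hook_func, None
        self.handle = m.register_forward_hook(self._run)
    def _run(self, module, inp, out):
        self.stored = self.hook_func(module, inp, out)
    def remove(self):
        self.handle.remove()

class HookGroup:
    "Hypothetical group of hooks, one per module in ms."
    def __init__(self, ms, hook_func):
        self.hooks = [OneHook(m, hook_func) for m in ms]
    def __len__(self): return len(self.hooks)
    def __getitem__(self, i): return self.hooks[i]
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self):
        # Results of every hook, in the same order as ms.
        return [h.stored for h in self.hooks]
    def remove(self):
        # Deregister all hooks as a group.
        for h in self.hooks: h.remove()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooks = HookGroup(list(model.children()), lambda m, i, o: tuple(o.shape))
model(torch.zeros(5, 4))
print(len(hooks), hooks[0].stored, hooks.stored)
# 3 (5, 8) [(5, 8), (5, 8), (5, 2)]
hooks.remove()
```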



Hooks.remove() deregisters all hooks created by this class, if they have not already been removed.

Convenience functions for hooks


hook_output(module:Module) → Hook

Function that creates a Hook for module that simply stores the output of the layer.


hook_outputs(modules:ModuleList) → Hooks

Function that creates a Hooks object, with one Hook for each passed module, that simply stores the output of the layers. For example, the (slightly simplified) source code of model_sizes is:

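def model_sizes(m, size):
    hooks = hook_outputs(m)  # register the hooks before the forward pass
    x = m(torch.zeros(1, in_channels(m), *size))
    sizes = [o.stored.shape for o in hooks]
    hooks.remove()  # deregister once the shapes have been collected
    return sizes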


model_sizes(m:Module, size:tuple=(64, 64), full:bool=True) → Tuple[Sizes, Tensor, Hooks]

Pass a dummy input through the model to get the various sizes. Returns (res,x,hooks) if full is True.


num_features_model(m:Module) → int

Return the number of output features for model.

It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, these sizes depend on the size of the input. model_sizes calculates the layer sizes by passing a minimal dummy tensor of the given size through the model.
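The idea of measuring layer sizes with a dummy input can be sketched in plain PyTorch as follows. `layer_sizes` and the small convolutional body are hypothetical, and the number of input channels is passed explicitly rather than inferred from the model.

```python
import torch
import torch.nn as nn

def layer_sizes(model, in_channels, size=(64, 64)):
    "Hypothetical helper: output shape of each direct child of `model` for a dummy input."
    shapes = []
    # Each hook appends the output shape of its module; append returns None,
    # so the module outputs are left untouched.
    handles = [m.register_forward_hook(lambda mod, inp, out: shapes.append(tuple(out.shape)))
               for m in model.children()]
    try:
        with torch.no_grad():
            model(torch.zeros(1, in_channels, *size))
    finally:
        # Always deregister, even if the forward pass fails.
        for h in handles:
            h.remove()
    return shapes

body = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)
sizes = layer_sizes(body, in_channels=3)
print(sizes)
# [(1, 8, 32, 32), (1, 8, 32, 32), (1, 16, 16, 16)]
```

In this sketch, the channel dimension of the last entry (16) is one way to read off the number of output features of the model.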

class HookCallback

HookCallback(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: LearnerCallback

Callback that registers given hooks.

For all modules, uses a callback to automatically register a method self.hook (that you must define in a subclass) as a hook. This method must have the signature:

def hook(self, m:Model, input:Tensors, output:Tensors)

If do_remove is True, the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
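The register-at-start, deregister-at-end pattern can be sketched without a Learner. Everything below is a hypothetical stand-in: the class names are ours, and calling `on_train_begin` and `on_train_end` by hand is an assumption used to simulate a training run.

```python
import torch
import torch.nn as nn

class SketchHookCallback:
    "Hypothetical stand-in for a hook-registering callback; not fastai's HookCallback."
    def __init__(self, modules, do_remove=True):
        self.modules, self.do_remove, self.handles = list(modules), do_remove, []

    def on_train_begin(self):
        # Register self.hook as a forward hook on every module.
        self.handles = [m.register_forward_hook(self.hook) for m in self.modules]

    def on_train_end(self):
        if self.do_remove:
            self.remove()

    def remove(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def hook(self, m, input, output):
        raise NotImplementedError("define hook in a subclass")

class MeanRecorder(SketchHookCallback):
    "Subclass that records the mean activation of each hooked module."
    def on_train_begin(self):
        super().on_train_begin()
        self.means = []

    def hook(self, m, input, output):
        self.means.append(output.mean().item())

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
cb = MeanRecorder(model.children())
cb.on_train_begin()
for _ in range(3):
    model(torch.randn(2, 4))
cb.on_train_end()
model(torch.randn(2, 4))   # hooks were removed, so nothing is recorded here
print(len(cb.means))       # 6
```

With do_remove=False the hooks would stay registered after `on_train_end`, and `remove()` would have to be called explicitly.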