Callback and helper functions to add hooks in models
from fastai.test_utils import *

What are hooks?

Hooks are functions you can attach to a particular layer of your model; they are executed during the forward pass (forward hooks) or the backward pass (backward hooks). We begin with an introduction to hooks, but if you want to implement one quickly, jump to HookCallback (and read the example that follows it, ActivationStats).

Forward hooks are functions that take three arguments: the layer they are applied to, the input of that layer (a tuple of the positional inputs, as the output below shows) and the output of that layer.

tst_model = nn.Linear(5,3)
def example_forward_hook(m,i,o): print(m,i,o)
x = torch.randn(4,5)
hook = tst_model.register_forward_hook(example_forward_hook)
y = tst_model(x)
hook.remove()
Linear(in_features=5, out_features=3, bias=True) (tensor([[-1.2270,  0.3393,  0.5718, -0.5177, -1.2104],
        [ 0.0354, -0.1312,  0.4349, -0.9000,  1.1018],
        [-1.5961,  1.1750, -1.1201, -1.5117,  1.9888],
        [ 1.5914,  0.4621,  0.9078, -0.6504, -0.4852]]),) tensor([[ 0.9606,  0.3859,  0.8708],
        [-0.1754, -0.4515,  0.6353],
        [ 0.7625,  0.8009,  1.7342],
        [ 0.2908, -0.3984,  0.3174]], grad_fn=<AddmmBackward>)

Backward hooks are functions that take three arguments: the layer it's applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output.

def example_backward_hook(m,gi,go): print(m,gi,go)
hook = tst_model.register_backward_hook(example_backward_hook)

x = torch.randn(4,5)
y = tst_model(x)
loss = y.pow(2).mean()
loss.backward()
hook.remove()
Linear(in_features=5, out_features=3, bias=True) (tensor([0.4103, 0.0804, 0.4298]), None, tensor([[-0.7980, -0.7526, -0.7505],
        [ 0.1685,  0.2370,  0.1919],
        [-0.4391, -0.4002, -0.4031],
        [-0.1246, -0.0397, -0.2264],
        [-0.0704,  0.0269,  0.0509]])) (tensor([[ 0.0558, -0.0538, -0.0166],
        [ 0.2879,  0.2239,  0.2653],
        [ 0.0227,  0.0017,  0.0994],
        [ 0.0439, -0.0915,  0.0817]]),)
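In the output above, the legacy register_backward_hook appears to report, for nn.Linear, the gradients of the internal addmm operation (bias, input, transposed weight) rather than those of the layer's input alone. Recent PyTorch versions deprecate register_backward_hook in favour of register_full_backward_hook, whose grad_input and grad_output match the module's positional inputs and outputs. A minimal sketch, assuming PyTorch 1.8 or later (the names grad_hook and grads are only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(5, 3)
grads = {}  # illustrative storage for what the hook sees

def grad_hook(module, grad_input, grad_output):
    # both arguments are tuples: one entry per positional input / output
    grads['in'] = grad_input[0]    # d(loss)/d(input of the layer)
    grads['out'] = grad_output[0]  # d(loss)/d(output of the layer)

handle = lin.register_full_backward_hook(grad_hook)
x = torch.randn(4, 5, requires_grad=True)  # so a gradient w.r.t. x is computed
y = lin(x)
loss = y.pow(2).mean()
loss.backward()   # the hook fires during this call
handle.remove()   # always remove hooks you no longer need

# d(loss)/dy = 2*y/N, and d(loss)/dx = (2*y/N) @ W
expected_out = 2 * y.detach() / y.numel()
print(torch.allclose(grads['out'], expected_out))
print(torch.allclose(grads['in'], expected_out @ lin.weight.detach()))
```

Since the hook fires inside loss.backward(), the handle can be removed right after the backward pass.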

Hooks can modify the input/output of a layer or the gradients, or simply print values or shapes. If you want to store something related to these inputs/outputs, it's best to associate your hook with a class so that it can save it in the state of an instance of that class.
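Both ideas can be sketched with plain PyTorch: a forward hook that returns a value replaces the layer's output, and making the hook a callable object gives it somewhere to keep state. ClampAndCount is a hypothetical name used only for this illustration:

```python
import torch
import torch.nn as nn

class ClampAndCount:
    "Hypothetical helper: clamps a layer's output at 0 and keeps some state about the calls."
    def __init__(self):
        self.calls, self.last_max = 0, None

    def __call__(self, module, inp, out):
        self.calls += 1                    # state lives on the instance
        out = out.clamp(min=0)             # a non-None return value replaces the output
        self.last_max = out.max().item()
        return out

lin = nn.Linear(5, 3)
counter = ClampAndCount()
handle = lin.register_forward_hook(counter)
y = lin(torch.randn(4, 5))
handle.remove()
```

After the forward pass, y is the clamped output and counter.calls is 1.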

class Hook

Hook(m, hook_func, is_forward=True, detach=True, cpu=False, gather=False)

Create a hook on m with hook_func.

hook_func is called during the forward pass if is_forward=True, and during the backward pass otherwise. Before passing the input/output of the module (or their gradients) to hook_func, the Hook can optionally detach them, gather them and move them to the CPU. The result of hook_func is stored in the stored attribute of the Hook.
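To make the mechanics concrete, here is a simplified sketch of a Hook-like class built directly on PyTorch's hook registration. MiniHook is a hypothetical name, it is not fastai's actual source, and it leaves out the gather option; backward hooks use register_full_backward_hook here, which assumes PyTorch 1.8 or later:

```python
import torch
import torch.nn as nn

class MiniHook:
    "Hypothetical, simplified stand-in for Hook (not fastai's implementation)."
    def __init__(self, m, hook_func, is_forward=True, detach=True, cpu=False):
        self.hook_func, self.detach, self.cpu, self.stored = hook_func, detach, cpu, None
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self.hook_fn)
        self.removed = False

    def _prep(self, t):
        # apply the detach/cpu options to tensors; leave None and other values untouched
        if isinstance(t, torch.Tensor):
            if self.detach: t = t.detach()
            if self.cpu: t = t.cpu()
        return t

    def hook_fn(self, module, inp, out):
        # inp is always a tuple; out is a tensor (forward) or a tuple (backward)
        inp = tuple(self._prep(t) for t in inp)
        out = tuple(self._prep(t) for t in out) if isinstance(out, tuple) else self._prep(out)
        self.stored = self.hook_func(module, inp, out)  # returning None keeps the real output

    def remove(self):
        if not self.removed:
            self.handle.remove()
            self.removed = True

lin = nn.Linear(5, 3)
h = MiniHook(lin, lambda m, i, o: o)
y = lin(torch.randn(4, 5))
h.remove()
```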

tst_model = nn.Linear(5,3)
hook = Hook(tst_model, lambda m,i,o: o)
y = tst_model(x)
test_eq(hook.stored, y)


Hook.hook_fn(module, input, output)

Applies hook_func to module, input, output.



Hook.remove()

Remove the hook from the model.

tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
hook = Hook(tst_model, example_forward_hook)
test_stdout(lambda: tst_model(x), f"{tst_model} ({x},) {y.detach()}")
hook.remove()
test_stdout(lambda: tst_model(x), "")

Context Manager

Since it's very important to remove your Hook even if your code is interrupted by some bug, a Hook can be used as a context manager.
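The same guarantee can be sketched with the standard library's contextlib: removing the handle in a finally clause means the hook is gone even if the body raises. forward_hook is a hypothetical helper, not part of fastai:

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def forward_hook(module, fn):
    "Hypothetical helper: register `fn` as a forward hook and always remove it on exit."
    handle = module.register_forward_hook(fn)
    try:
        yield handle
    finally:
        handle.remove()  # runs on normal exit and when an exception escapes

calls = []
lin = nn.Linear(5, 3)
x = torch.randn(2, 5)
try:
    with forward_hook(lin, lambda m, i, o: calls.append(o.shape)):
        lin(x)
        raise RuntimeError("simulated bug")
except RuntimeError:
    pass
lin(x)  # the hook is gone, so nothing new is recorded
```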



Hook.__enter__(*args)

Register the hook



Hook.__exit__(*args)

Remove the hook

tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
with Hook(tst_model, example_forward_hook) as h:
    test_stdout(lambda: tst_model(x), f"{tst_model} ({x},) {y.detach()}")
test_stdout(lambda: tst_model(x), "")


hook_output(module, detach=True, cpu=False, grad=False)

Return a Hook that stores activations of module in self.stored

The activations stored are the gradients if grad=True, otherwise the output of module. If detach=True they are detached from their history, and if cpu=True, they're put on the CPU.

tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
with hook_output(tst_model) as h:
    y = tst_model(x)
    test_eq(y, h.stored)
    assert not h.stored.requires_grad
with hook_output(tst_model, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    test_close(2*y / y.numel(), h.stored[0])
if torch.cuda.is_available():  # this check needs a GPU
    with hook_output(tst_model, cpu=True) as h:
        y = tst_model.cuda()(x.cuda())
        test_eq(h.stored.device, torch.device('cpu'))
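The expected gradient 2*y / y.numel() in the grad=True test above comes from differentiating the loss. Writing N for y.numel():

```latex
\mathcal{L} = \frac{1}{N}\sum_{k=1}^{N} y_k^2
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial y_k} = \frac{2\,y_k}{N}
```

With grad=True the hook stores the gradient with respect to the module's output, which is why this quantity is compared with h.stored[0].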

class Hooks

Hooks(ms, hook_func, is_forward=True, detach=True, cpu=False)

Create several hooks on the modules in ms with hook_func.

layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
hooks = Hooks(tst_model, lambda m,i,o: o)
y = tst_model(x)
test_eq(hooks.stored[0], layers[0](x))
test_eq(hooks.stored[1], F.relu(layers[0](x)))
test_eq(hooks.stored[2], y)
hooks.remove()


Hooks.stored

The states saved in each hook.



Hooks.remove()

Remove the hooks from the model.

Context Manager

Like Hook, Hooks can be used as a context manager.



Hooks.__enter__(*args)

Register the hooks



Hooks.__exit__(*args)

Remove the hooks

layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
with Hooks(layers, lambda m,i,o: o) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)


hook_outputs(modules, detach=True, cpu=False, grad=False)

Return Hooks that store activations of all modules in self.stored

The activations stored are the gradients if grad=True, otherwise the output of modules. If detach=True they are detached from their history, and if cpu=True, they're put on the CPU.

layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
x = torch.randn(4,5)
with hook_outputs(layers) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
    for s in h.stored: assert not s.requires_grad
with hook_outputs(layers, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    g = 2*y / y.numel()
    test_close(g, h.stored[2][0])
    g = g @ layers[2].weight.data
    test_close(g, h.stored[1][0])
    g = g * (layers[0](x) > 0).float()
    test_close(g, h.stored[0][0])
if torch.cuda.is_available():  # this check needs a GPU
    with hook_outputs(tst_model, cpu=True) as h:
        y = tst_model.cuda()(x.cuda())
        for s in h.stored: test_eq(s.device, torch.device('cpu'))
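The grad=True checks above walk the chain rule backwards one layer at a time. The sketch below recomputes the same three gradients by hand and compares them with what per-layer backward hooks capture, assuming plain PyTorch with register_full_backward_hook (make_hook and out_grads are illustrative names):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = [nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 3)]
model = nn.Sequential(*layers)
out_grads = {}

def make_hook(idx):
    def hook(module, grad_input, grad_output):
        out_grads[idx] = grad_output[0]  # gradient w.r.t. this layer's output
    return hook

handles = [l.register_full_backward_hook(make_hook(i)) for i, l in enumerate(layers)]
x = torch.randn(4, 5, requires_grad=True)
y = model(x)
y.pow(2).mean().backward()
for h in handles: h.remove()

with torch.no_grad():
    g2 = 2 * y / y.numel()                # d(loss)/d(output of layers[2])
    g1 = g2 @ layers[2].weight            # back through the last Linear
    g0 = g1 * (layers[0](x) > 0).float()  # back through ReLU: zero where the pre-activation is <= 0
```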


dummy_eval(m, size=(64, 64))

Evaluate m on a dummy input of a certain size
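A rough sketch of what evaluating on a dummy input involves, assuming the model takes image-like input of shape (batch, channels, height, width). dummy_eval_sketch is hypothetical and, unlike the real function, takes the number of input channels as an explicit ch_in argument instead of inferring it:

```python
import torch
import torch.nn as nn

def dummy_eval_sketch(m, size=(64, 64), ch_in=3):
    "Hypothetical helper: run `m` in eval mode on one random image-like input."
    p = next(m.parameters())  # match the model's device and dtype
    x = torch.empty(1, ch_in, *size, device=p.device, dtype=p.dtype).uniform_(-1., 1.)
    with torch.no_grad():     # no autograd graph is needed just to look at the output
        return m.eval()(x)

m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
out = dummy_eval_sketch(m)
```

Eval mode makes BatchNorm use its running statistics and turns Dropout off, so a single random sample is enough.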


model_sizes(m, size=(64, 64))

Pass a dummy input through the model m to get the various sizes of activations.

m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))
test_eq(model_sizes(m), [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]])
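model_sizes combines the pieces above: register a forward hook on each child module, run a dummy input through, and read the stored shapes. A minimal sketch with a hypothetical model_sizes_sketch helper (plain nn.Conv2d layers stand in for ConvLayer), which also shows that the number of output features is the channel dimension of the last activation:

```python
import torch
import torch.nn as nn

def model_sizes_sketch(m, size=(64, 64), ch_in=3):
    "Hypothetical helper: record the output shape of each direct child of `m`."
    sizes = []
    hook = lambda mod, inp, out: sizes.append(list(out.shape))
    handles = [c.register_forward_hook(hook) for c in m.children()]
    try:
        with torch.no_grad():
            m.eval()(torch.zeros(1, ch_in, *size))
    finally:
        for h in handles: h.remove()  # never leave the hooks behind
    return sizes

m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.Conv2d(16, 32, 3, stride=2, padding=1))
sizes = model_sizes_sketch(m)
n_features = sizes[-1][1]  # channel count of the last activation
```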



num_features_model(m)

Return the number of output features for m.

m = nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))
test_eq(num_features_model(m), 3)
m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))
test_eq(num_features_model(m), 32)

To make hooks easy to use, we wrapped them in a Callback where you only have to implement a hook function (plus anything else you might need).



has_params(m)

Check if m has at least one parameter

assert has_params(nn.Linear(3,4))
assert has_params(nn.LSTM(4,5,2))
assert not has_params(nn.ReLU())

class HookCallback

HookCallback(modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, hook=None) :: Callback

Callback that can be used to register hooks on modules

You can either subclass and implement a hook function (along with any event you want) or pass a hook function when initializing. Such a function needs to take three arguments: a layer, input and output (for a backward hook, input means the gradient with respect to the inputs and output the gradient with respect to the output), and can either modify them or update the state according to them.

If not provided, modules will default to the layers of self.model that have a weight attribute. Depending on remove_end, the hooks will be properly removed at the end of training (or in case of error). is_forward, detach and cpu are passed to Hooks. If every is set, the hooks are only active on one training batch out of every, rather than for the whole of training.

The function called at each forward (or backward) pass is self.hook and must be implemented when subclassing this callback.
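The lifecycle described above (register in before_fit, call hook on every forward pass, remove in after_fit) can be imitated without a Learner. This is a hypothetical sketch: MiniHookCallback and MeanCallback are illustrative names, and a three-step loop stands in for fastai's training loop:

```python
import torch
import torch.nn as nn

class MiniHookCallback:
    "Hypothetical sketch of the HookCallback lifecycle, without fastai's Learner."
    def __init__(self, modules):
        self.modules, self.handles, self.stored = modules, [], []

    def hook(self, m, i, o):
        raise NotImplementedError  # subclasses implement this

    def _hook_fn(self, idx):
        def fn(m, i, o): self.stored[idx] = self.hook(m, i, o.detach())
        return fn

    def before_fit(self):
        self.stored = [None] * len(self.modules)
        self.handles = [m.register_forward_hook(self._hook_fn(k)) for k, m in enumerate(self.modules)]

    def after_fit(self):
        for h in self.handles: h.remove()
        self.handles = []

class MeanCallback(MiniHookCallback):
    def hook(self, m, i, o): return o.mean().item()

model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
cb = MeanCallback([model[0], model[2]])
opt = torch.optim.SGD(model.parameters(), lr=0.1)
cb.before_fit()
try:
    for _ in range(3):  # a tiny stand-in for the training loop
        x, y = torch.randn(8, 2), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
finally:
    cb.after_fit()  # like remove_end=True: the hooks are removed even on error
```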

class TstCallback(HookCallback):
    def hook(self, m, i, o): return o
    def after_batch(self): test_eq(self.hooks.stored[0], self.pred)
learn = synth_learner(n_trn=5, cbs = TstCallback())
learn.fit(1)
[0, 14.660037994384766, 11.780715942382812, '00:00']
class TstCallback(HookCallback):
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=False):
        super().__init__(modules, None, remove_end, False, detach, cpu)
    def hook(self, m, i, o): return o
    def after_batch(self):
        if self.training:  # no backward pass, so no new gradients, during validation
            test_eq(self.hooks.stored[0][0], 2*(self.pred-self.y)/self.pred.shape[0])
learn = synth_learner(n_trn=5, cbs = TstCallback())
learn.fit(1)
[0, 15.180420875549316, 15.292489051818848, '00:00']



HookCallback.before_fit()

Register the Hooks on self.modules.



HookCallback.after_fit()

Remove the Hooks.

Model summary



total_params(m)

Give the number of parameters of a module and whether it's trainable or not

test_eq(total_params(nn.Linear(10,32)), (32*10+32,True))
test_eq(total_params(nn.Linear(10,32, bias=False)), (32*10,True))
test_eq(total_params(nn.BatchNorm2d(20)), (20*2, True))
test_eq(total_params(nn.BatchNorm2d(20, affine=False)), (0,False))
test_eq(total_params(nn.Conv2d(16, 32, 3)), (16*32*3*3 + 32, True))
test_eq(total_params(nn.Conv2d(16, 32, 3, bias=False)), (16*32*3*3, True))
#First input-hidden matrix is 20->10, the other three (hidden-hidden of layer 1, both of layer 2) are 10->10; each has its own bias, *4 for the four gates
test_eq(total_params(nn.LSTM(20, 10, 2)), (4 * (20*10 + 10) + 3 * 4 * (10*10 + 10), True))
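The expected counts in these tests follow from the usual parameter formulas. A sketch in plain Python (the helper names are hypothetical); for the LSTM it assumes PyTorch's layout of four gates per layer, each with an input-to-hidden matrix, a hidden-to-hidden matrix and two bias vectors (bias_ih and bias_hh):

```python
def linear_params(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def conv2d_params(c_in, c_out, k, bias=True):
    return c_in * c_out * k * k + (c_out if bias else 0)

def batchnorm_params(n, affine=True):
    # weight and bias; running mean/var are buffers, not parameters
    return 2 * n if affine else 0

def lstm_params(n_in, n_hid, n_layers):
    total = 0
    for layer in range(n_layers):
        layer_in = n_in if layer == 0 else n_hid  # later layers read the previous hidden state
        # 4 gates x (W_ih + W_hh + bias_ih + bias_hh)
        total += 4 * (layer_in * n_hid + n_hid * n_hid + 2 * n_hid)
    return total
```

Grouping the terms per layer rather than per matrix gives the same total as the formula in the test.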


layer_info(learn, *xb)

Return layer infos of model on xb (only supports batch-first inputs)

The output of _track is expected to be a tuple made of the layer type, the number of parameters, whether it is trainable, the output shape, and whether or not the size stayed the same (same). There are three potential groups a layer can belong to:

  • A non-activation layer (Linear, Conv, etc)
  • An activation layer
  • A pooling layer

Depending on the group, only part of the output is actually returned; the remaining fields are ''. For non-activation layers everything is returned. Activation layers only return their name and the same flag. Pooling layers return the name, the new shape, and False for same.
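Assuming, as the expected values below suggest, that same means the non-batch dimensions of the output match those of the input, the flag can be computed with a forward hook. same_flag is a hypothetical helper, not fastai's _track:

```python
import torch
import torch.nn as nn

def same_flag(module, x):
    "Hypothetical helper: does `module` keep the non-batch shape of its input?"
    result = {}
    def hook(m, inp, out):
        result['same'] = inp[0].shape[1:] == out.shape[1:]  # ignore the batch dimension
    h = module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            module(x)
    finally:
        h.remove()
    return result['same']

x = torch.randn(16, 1)
```

For example, same_flag(nn.Linear(1, 50), x) is False while same_flag(nn.ReLU(), torch.randn(16, 50)) is True, matching the first two rows of the test below.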

def _m(): return nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))
sample_input = torch.randn((16, 1))
test_eq(layer_info(synth_learner(model=_m()), sample_input), [
    ('Linear', 100, True, [1, 50], False),
    ('ReLU', '', '', '', True),
    ('BatchNorm1d', 100, True, [1, 50], True),
    ('Linear', 51, True, [1, 1], False)
])


module_summary(learn, *xb)

Print a summary of model using xb



Learner.summary()

Print a summary of the model, optimizer and loss function.

learn = synth_learner(model=_m())
learn.summary()
Sequential (Input shape: 16)
Layer (type)         Output Shape         Param #    Trainable 
                     16 x 50             
Linear                                    100        True      
BatchNorm1d                               100        True      
                     16 x 1              
Linear                                    51         True      

Total params: 251
Total trainable params: 251
Total non-trainable params: 0

Optimizer used: functools.partial(<function SGD at 0x7f74e6301320>, mom=0.9)
Loss function: FlattenedLoss of MSELoss()

  - TrainEvalCallback
  - Recorder

Activation graphs

This is an example of a HookCallback that stores the mean, standard deviation, fraction of near-zero values and (optionally) histogram of the activations that go through the network.
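The per-tensor statistics computed by ActivationStats.hook can be tried out on their own. Here they are factored into a standalone function (activation_stats is a hypothetical name) and applied to a small tensor:

```python
import torch

def activation_stats(o, with_hist=False):
    "Hypothetical standalone version of the statistics ActivationStats.hook computes."
    o = o.float()
    res = {'mean': o.mean().item(), 'std': o.std().item(),
           # fraction of activations <= 0.05 (negative values count too)
           'near_zero': (o <= 0.05).long().sum().item() / o.numel()}
    if with_hist:
        res['hist'] = o.histc(40, 0, 10)  # 40 bins over [0, 10]; values outside are ignored
    return res

stats = activation_stats(torch.tensor([-1., 0., 0.04, 0.5, 2.]), with_hist=True)
```

Three of the five values are at most 0.05, so near_zero is 0.6, and the histogram counts only the four values inside [0, 10].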

class ActivationStats

ActivationStats(with_hist=False, modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, hook=None) :: HookCallback

Callback that records the mean and std of activations.

class ActivationStats(HookCallback):
    "Callback that records the mean and std of activations."
    def __init__(self, with_hist=False, **kwargs):
        super().__init__(**kwargs)  # pass modules, every, etc. on to HookCallback
        self.with_hist = with_hist

    def before_fit(self):
        "Initialize stats."
        super().before_fit()  # registers the hooks
        self.stats = L()

    def hook(self, m, i, o):
        if isinstance(o, tuple): return self.hook_multi_ouput(o)
        o = o.float()
        res = {'mean': o.mean().item(), 'std': o.std().item(),
               'near_zero': (o<=0.05).long().sum().item()/o.numel()}
        if self.with_hist: res['hist'] = o.histc(40,0,10)
        return res
    def hook_multi_ouput(self,o_tuple):
        "For outputs of RNN which are [nested] tuples of tensors"
        res = []
        for o in self._flatten_tuple(o_tuple):
            if not(isinstance(o, Tensor)): continue
            res.append(self.hook(None, None, o))
        return res

    def _flatten_tuple(self, o_tuple):
        "Recursively flatten a [nested] tuple"
        res = []
        for it in o_tuple:
            if isinstance(it, tuple): res += self._flatten_tuple(it)
            else: res += [it]
        return tuple(res)

    def after_batch(self):
        "Take the stored results and put them in `self.stats`"
        if self.training and (self.every is None or self.train_iter%self.every == 0):
            self.stats.append(self.hooks.stored)
        super().after_batch()

    def layer_stats(self, idx):
        lstats = self.stats.itemgot(idx)
        return L(lstats.itemgot(o) for o in ('mean','std','near_zero'))

    def hist(self, idx):
        res = self.stats.itemgot(idx).itemgot('hist')
        return torch.stack(tuple(res)).t().float().log1p()

    def color_dim(self, idx, figsize=(10,5), ax=None):
        "The 'colorful dimension' plot"
        res = self.hist(idx)
        if ax is None: ax = subplots(figsize=figsize)[1][0]
        ax.imshow(res, origin='lower')

    def plot_layer_stats(self, idx):
        _,axs = subplots(1, 3, figsize=(12,3))
        for o,ax,title in zip(self.layer_stats(idx),axs,('mean','std','% near zero')):
            ax.plot(o)
            ax.set_title(title)

learn = synth_learner(n_trn=5, cbs = ActivationStats(every=4))
learn.fit(1)
[0, 23.258834838867188, 25.094209671020508, '00:00']
learn.activation_stats.stats
(#2) [[{'mean': -1.286624789237976, 'std': 0.5204624533653259, 'near_zero': 1.0}],[{'mean': -0.8293845057487488, 'std': 0.9002514481544495, 'near_zero': 0.8125}]]

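Each element of stats corresponds to one recorded training batch (two here, since n_trn=5 and every=4) and contains, for each hooked module, a dict with the mean and standard deviation of that module's output, plus near_zero, the fraction of activations that are at most 0.05 (negative values included).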
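Using the two entries printed above, here is how the per-layer series that layer_stats builds with itemgot could be extracted with plain Python lists (the variable names are illustrative):

```python
# The two recorded batches printed above: one list per batch, one dict per hooked module
stats = [
    [{'mean': -1.286624789237976, 'std': 0.5204624533653259, 'near_zero': 1.0}],
    [{'mean': -0.8293845057487488, 'std': 0.9002514481544495, 'near_zero': 0.8125}],
]
layer_idx = 0  # the only hooked module here
# plain-Python equivalent of layer_stats(idx): one series per statistic, one value per batch
means, stds, near_zeros = ([batch[layer_idx][k] for batch in stats]
                           for k in ('mean', 'std', 'near_zero'))
```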

def test_every(n_tr, every):
    "create a learner, fit, then check number of stats collected"
    learn = synth_learner(n_trn=n_tr, cbs=ActivationStats(every=every))
    expected_stats_len = math.ceil(n_tr / every)
    test_eq(expected_stats_len, len(learn.activation_stats.stats))
for n_tr in [11, 12, 13]:
    test_every(n_tr, 4)
    test_every(n_tr, 1)
[0, 13.364368438720703, 8.206254005432129, '00:00']
[0, 28.091310501098633, 20.91598892211914, '00:00']
[0, 12.561519622802734, 11.980316162109375, '00:00']
[0, 14.071392059326172, 15.907712936401367, '00:00']
[0, 19.315099716186523, 16.30086898803711, '00:00']
[0, 19.391584396362305, 12.741905212402344, '00:00']
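The expected count math.ceil(n_tr / every) in test_every is simply the number of multiples of every in range(n_tr). A plain-Python check, assuming the batch counter starts at 0 (recorded_batches is a hypothetical helper):

```python
import math

def recorded_batches(n_batches, every):
    "Hypothetical helper: indices of training batches that would be recorded (0-based counter assumed)."
    return [i for i in range(n_batches) if i % every == 0]

for n_tr in (11, 12, 13):
    for every in (4, 1):
        assert len(recorded_batches(n_tr, every)) == math.ceil(n_tr / every)
```

With n_tr=5 and every=4, as in the ActivationStats example, this gives batches 0 and 4, i.e. the two entries printed there.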