from fastai.test_utils import *
Model hooks
What are hooks?
Hooks are functions you can attach to a particular layer in your model; they are executed during the forward pass (forward hooks) or the backward pass (backward hooks). This section starts with an introduction to hooks, but if you want to implement one quickly, jump to HookCallback (and read the example that follows it, ActivationStats).
Forward hooks are functions that take three arguments: the layer they're applied to, the input of that layer (as a tuple) and the output of that layer.
tst_model = nn.Linear(5,3)
def example_forward_hook(m,i,o): print(m,i,o)
x = torch.randn(4,5)
hook = tst_model.register_forward_hook(example_forward_hook)
y = tst_model(x)
hook.remove()
Linear(in_features=5, out_features=3, bias=True) (tensor([[-0.9811, 0.1455, 0.3667, 0.7821, 1.0376],
[ 0.4916, -0.8581, 0.1134, 0.1752, -0.0595],
[ 0.4517, -0.9027, 1.3693, -0.8399, 1.4931],
[-0.7818, -1.1915, -0.1014, 1.1878, -0.8517]]),) tensor([[-0.1019, -0.4006, -0.3282],
[-0.0551, 0.5754, 0.0726],
[-0.5382, -0.1731, -0.1683],
[-0.3195, 0.7669, 0.3924]], grad_fn=<AddmmBackward0>)
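To make the mechanics concrete, here is a minimal pure-Python sketch of how a module could dispatch forward hooks. ToyLinear and ToyHandle are hypothetical stand-ins, not PyTorch classes; the sketch only mirrors the conventions visible above: the hook receives the module, a tuple of inputs and the output, and registering it returns a handle whose remove() unregisters it. It also follows PyTorch's documented rule that a forward hook returning a non-None value replaces the output.

```python
from typing import Callable, Dict, List, Tuple


class ToyHandle:
    """Hypothetical handle: remove() unregisters the hook it was created for."""

    def __init__(self, registry: Dict[int, Callable], key: int):
        self._registry, self._key = registry, key

    def remove(self) -> None:
        self._registry.pop(self._key, None)


class ToyLinear:
    """Hypothetical 1-D layer computing y = w*x + b, then running its forward hooks."""

    def __init__(self, w: float, b: float):
        self.w, self.b = w, b
        self._forward_hooks: Dict[int, Callable] = {}
        self._next_key = 0

    def register_forward_hook(self, fn: Callable) -> ToyHandle:
        key = self._next_key
        self._next_key += 1
        self._forward_hooks[key] = fn
        return ToyHandle(self._forward_hooks, key)

    def __call__(self, *inputs: float) -> float:
        out = self.w * inputs[0] + self.b
        # Each hook sees (module, tuple of inputs, output); a non-None result replaces the output
        for fn in list(self._forward_hooks.values()):
            res = fn(self, inputs, out)
            if res is not None:
                out = res
        return out


seen: List[Tuple[tuple, float]] = []
layer = ToyLinear(2.0, 1.0)
handle = layer.register_forward_hook(lambda m, i, o: seen.append((i, o)))
layer(3.0)        # the hook records ((3.0,), 7.0)
handle.remove()
layer(4.0)        # no hook runs any more
```

Because list.append returns None, the recording hook leaves the output untouched; a hook such as lambda m, i, o: o * 10 would instead replace it.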
Backward hooks are functions that take three arguments: the layer they're applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output.
def example_backward_hook(m,gi,go): print(m,gi,go)
hook = tst_model.register_backward_hook(example_backward_hook)
x = torch.randn(4,5)
y = tst_model(x)
loss = y.pow(2).mean()
loss.backward()
hook.remove()
Linear(in_features=5, out_features=3, bias=True) (tensor([ 0.0913, 0.3834, -0.0015]), None, tensor([[ 0.1872, 0.1248, -0.2946],
[ 0.1090, -0.3164, -0.2486],
[-0.0468, -0.1728, -0.1686],
[-0.0787, 0.3200, 0.0099],
[-0.0308, -0.1119, 0.0056]])) (tensor([[ 0.0414, 0.1750, 0.0672],
[-0.0252, 0.0636, 0.0592],
[ 0.1243, 0.0364, -0.1118],
[-0.0491, 0.1084, -0.0160]]),)
Running this cell also emits a FutureWarning from PyTorch: using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated, and such a hook will be missing some grad_input. PyTorch recommends register_full_backward_hook to get the documented behavior.
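The grad_output printed above is the gradient of loss = y.pow(2).mean() with respect to y. For N = y.numel() elements that gradient is 2*y/N, the same expression the hook_output test uses further down. The sketch below checks the formula against central finite differences in plain Python; the input values are arbitrary and the helper names are hypothetical.

```python
from typing import Callable, List


def mean_sq(ys: List[float]) -> float:
    """loss = mean(y**2), the loss used in the backward-hook example."""
    return sum(y * y for y in ys) / len(ys)


def analytic_grad(ys: List[float]) -> List[float]:
    # d/dy_k [(1/N) * sum_j y_j**2] = 2 * y_k / N
    n = len(ys)
    return [2 * y / n for y in ys]


def numeric_grad(f: Callable[[List[float]], float], ys: List[float], eps: float = 1e-6) -> List[float]:
    """Central finite differences of f with respect to each entry of ys."""
    grads = []
    for k in range(len(ys)):
        up, down = list(ys), list(ys)
        up[k] += eps
        down[k] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


ys = [0.5, -1.0, 2.0, 0.25]
print(analytic_grad(ys))   # [0.25, -0.5, 1.0, 0.125]
print(numeric_grad(mean_sq, ys))
```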
Hooks can change the input/output of a layer or the gradients, or simply print values or shapes. If you want to store something related to these inputs/outputs, it's best to associate your hook with a class so that it can save it in the state of an instance of that class.
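As an illustration of that advice, the sketch below keeps its state in an instance: because the class defines __call__, an instance can be passed wherever a hook function is expected (PyTorch's register_forward_hook accepts any callable). OutputRecorder is a hypothetical name, and outputs are plain lists of floats so that the example runs without PyTorch; the two calls at the end simulate two forward passes.

```python
from typing import Any, List, Tuple


class OutputRecorder:
    """Hypothetical stateful hook: accumulates statistics about every output it sees."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.count = 0
        self.total = 0.0
        self.sizes: List[int] = []

    def __call__(self, module: Any, inputs: Tuple[Any, ...], output: List[float]) -> None:
        # Update the instance state; returning None leaves the layer's output untouched
        self.n_calls += 1
        self.count += len(output)
        self.total += sum(output)
        self.sizes.append(len(output))

    @property
    def mean(self) -> float:
        return self.total / self.count if self.count else 0.0


rec = OutputRecorder()
# Simulate two forward passes of some layer calling the hook
rec("layer", ([0.0, 0.0],), [1.0, 2.0])
rec("layer", ([0.0],), [3.0])
print(rec.n_calls, rec.sizes, rec.mean)   # 2 [2, 1] 2.0
```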
Hook
Hook (m, hook_func, is_forward=True, detach=True, cpu=False, gather=False)
Create a hook on m with hook_func.
hook_func is called during the forward pass if is_forward=True and during the backward pass otherwise. Before passing the input/output of the module (or their gradients) to hook_func, Hook can optionally detach them, gather them and put them on the CPU. The result of hook_func is stored in the stored attribute of the Hook.
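The sketch below mirrors that flow in plain Python under simplifying assumptions: MiniHook, ToyModule and ToyHandle are hypothetical, only forward hooks are supported, and a preprocess function stands in for the detach/gather/cpu steps, which need real tensors. It is not fastai's implementation, only an illustration of where stored comes from.

```python
from typing import Any, Callable, Dict, Optional


class ToyHandle:
    def __init__(self, registry: Dict[int, Callable], key: int):
        self.registry, self.key = registry, key

    def remove(self) -> None:
        self.registry.pop(self.key, None)


class ToyModule:
    """Hypothetical stand-in for a layer: doubles its input and runs its forward hooks."""

    def __init__(self) -> None:
        self._hooks: Dict[int, Callable] = {}
        self._next_key = 0

    def register_forward_hook(self, fn: Callable) -> ToyHandle:
        key = self._next_key
        self._next_key += 1
        self._hooks[key] = fn
        return ToyHandle(self._hooks, key)

    def __call__(self, x: float) -> float:
        out = 2 * x
        for fn in list(self._hooks.values()):
            fn(self, (x,), out)
        return out


class MiniHook:
    """Hypothetical Hook-like wrapper: preprocesses the hook arguments and keeps hook_func's result."""

    def __init__(self, m: ToyModule, hook_func: Callable,
                 preprocess: Optional[Callable[[Any], Any]] = None) -> None:
        self.hook_func = hook_func
        # Stand-in for detach/gather/cpu, which only make sense for real tensors
        self.preprocess = preprocess or (lambda v: v)
        self.stored: Any = None
        self.removed = False
        self.handle = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module: Any, inp: tuple, out: Any) -> None:
        inp = tuple(self.preprocess(v) for v in inp)
        self.stored = self.hook_func(module, inp, self.preprocess(out))

    def remove(self) -> None:
        if not self.removed:
            self.handle.remove()
            self.removed = True


m = ToyModule()
h = MiniHook(m, lambda mod, i, o: o, preprocess=round)
m(1.26)          # output 2.52, stored as round(2.52) == 3
h.remove()
m(10.0)          # hook removed: h.stored is still 3
```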
tst_model = nn.Linear(5,3)
hook = Hook(tst_model, lambda m,i,o: o)
y = tst_model(x)
test_eq(hook.stored, y)
Hook.hook_fn
Hook.hook_fn (module, input, output)
Applies hook_func to module, input, output.
Hook.remove
Hook.remove ()
Remove the hook from the model.
It's important to properly remove your hooks from your model when you're done, both to avoid them being called again the next time your model is applied to some inputs and to free the memory that goes with their state.
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
hook = Hook(tst_model, example_forward_hook)
test_stdout(lambda: tst_model(x), f"{tst_model} ({x},) {y.detach()}")
hook.remove()
test_stdout(lambda: tst_model(x), "")
Context Manager
Since it's very important to remove your Hook even if your code is interrupted by a bug, Hook can be used as a context manager.
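A context manager helps here because __exit__ runs even when the body of the with statement raises. A minimal sketch with hypothetical ScopedHook and ToyModule classes (not fastai's Hook): the hook is removed on the way out, and __exit__ returns False so the original exception still propagates.

```python
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical layer that keeps its forward hooks in a plain list."""

    def __init__(self) -> None:
        self.hooks: List[Callable] = []

    def __call__(self, x: Any) -> Any:
        for fn in list(self.hooks):
            fn(self, (x,), x)
        return x


class ScopedHook:
    """Hypothetical hook wrapper usable in a `with` statement."""

    def __init__(self, m: ToyModule, fn: Callable) -> None:
        self.m, self.fn = m, fn
        m.hooks.append(fn)

    def remove(self) -> None:
        if self.fn in self.m.hooks:
            self.m.hooks.remove(self.fn)

    def __enter__(self) -> "ScopedHook":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.remove()
        return False  # do not swallow the exception


m = ToyModule()
calls: List[Any] = []
try:
    with ScopedHook(m, lambda mod, i, o: calls.append(o)):
        m(1)
        raise RuntimeError("simulated bug in the middle of training")
except RuntimeError:
    pass
m(2)  # the hook was removed despite the exception, so nothing is recorded
print(calls, m.hooks)  # [1] []
```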
Hook.__enter__
Hook.__enter__ (*args)
Register the hook
Hook.__exit__
Hook.__exit__ (*args)
Remove the hook
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
with Hook(tst_model, example_forward_hook) as h:
    test_stdout(lambda: tst_model(x), f"{tst_model} ({x},) {y.detach()}")
test_stdout(lambda: tst_model(x), "")
hook_output
hook_output (module, detach=True, cpu=False, grad=False)
Return a Hook that stores activations of module in self.stored
The activations stored are the gradients if grad=True, otherwise the output of module. If detach=True they are detached from their history, and if cpu=True, they’re put on the CPU.
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
with hook_output(tst_model) as h:
    y = tst_model(x)
    test_eq(y, h.stored)
    assert not h.stored.requires_grad

with hook_output(tst_model, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    test_close(2*y / y.numel(), h.stored[0])

# Requires a CUDA device
with hook_output(tst_model, cpu=True) as h:
    y = tst_model.cuda()(x.cuda())
    test_eq(h.stored.device, torch.device('cpu'))
Hooks
Hooks (ms, hook_func, is_forward=True, detach=True, cpu=False)
Create several hooks on the modules in ms with hook_func.
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
hooks = Hooks(tst_model, lambda m,i,o: o)
y = tst_model(x)
test_eq(hooks.stored[0], layers[0](x))
test_eq(hooks.stored[1], F.relu(layers[0](x)))
test_eq(hooks.stored[2], y)
hooks.remove()
Hooks.stored
Hooks.stored ()
Hooks.remove
Hooks.remove ()
Remove the hooks from the model.
Context Manager
Like Hook, Hooks can be used as a context manager.
Hooks.__enter__
Hooks.__enter__ (*args)
Register the hooks
Hooks.__exit__
Hooks.__exit__ (*args)
Remove the hooks
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
with Hooks(layers, lambda m,i,o: o) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
hook_outputs
hook_outputs (modules, detach=True, cpu=False, grad=False)
Return Hooks that store activations of all modules in self.stored
The activations stored are the gradients if grad=True, otherwise the output of modules. If detach=True they are detached from their history, and if cpu=True, they’re put on the CPU.
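The grad=True checks in the example below follow the chain rule backwards through Linear, ReLU, Linear: the gradient with respect to the last output is 2*y/N, multiplying by the last layer's weight gives the gradient with respect to the ReLU output, and masking with (pre-activation > 0) gives the gradient with respect to the first Linear's output (each stored entry is a grad_output tuple, hence the [0]). The pure-Python sketch below reproduces those three steps for a single sample with arbitrary toy weights and cross-checks them against finite differences; it does not use fastai or PyTorch, and all helper names are hypothetical.

```python
from typing import Callable, List


def linear(W: List[List[float]], b: List[float], x: List[float]) -> List[float]:
    """y = W @ x + b for a single sample (W is a list of rows, shape (out, in))."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, a) for a in v]


def mean_sq(y: List[float]) -> float:
    return sum(a * a for a in y) / len(y)


def numeric_grad(f: Callable[[List[float]], float], v: List[float], eps: float = 1e-6) -> List[float]:
    """Central finite differences of scalar f with respect to each entry of v."""
    out = []
    for k in range(len(v)):
        up, down = list(v), list(v)
        up[k] += eps
        down[k] -= eps
        out.append((f(up) - f(down)) / (2 * eps))
    return out


# Arbitrary toy parameters: Linear(2,3) -> ReLU -> Linear(3,2), one sample
x = [1.0, -2.0]
W1, b1 = [[0.5, -0.25], [-1.0, 0.5], [0.3, -0.2]], [0.1, 0.0, -0.2]
W2, b2 = [[1.0, 2.0, -1.0], [0.5, -0.5, 0.25]], [0.0, 0.1]

h1 = linear(W1, b1, x)   # output of the first Linear
a = relu(h1)             # output of the ReLU
y = linear(W2, b2, a)    # output of the last Linear

# Chain rule, layer by layer (the grad_output each backward hook would see)
g_y = [2 * yi / len(y) for yi in y]                                            # d loss / d y
g_a = [sum(g_y[i] * W2[i][j] for i in range(len(y))) for j in range(len(a))]  # g_y @ W2
g_h1 = [g * (1.0 if h > 0 else 0.0) for g, h in zip(g_a, h1)]                 # ReLU mask

# Cross-check: finite differences of the loss seen as a function of each activation
num_y = numeric_grad(mean_sq, y)
num_a = numeric_grad(lambda v: mean_sq(linear(W2, b2, v)), a)
num_h1 = numeric_grad(lambda v: mean_sq(linear(W2, b2, relu(v))), h1)
```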
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
x = torch.randn(4,5)
with hook_outputs(layers) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
    for s in h.stored: assert not s.requires_grad

with hook_outputs(layers, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    g = 2*y / y.numel()
    test_close(g, h.stored[2][0])
    g = g @ layers[2].weight.data
    test_close(g, h.stored[1][0])
    g = g * (layers[0](x) > 0).float()
    test_close(g, h.stored[0][0])

# Requires a CUDA device
with hook_outputs(tst_model, cpu=True) as h:
    y = tst_model.cuda()(x.cuda())
    for s in h.stored: test_eq(s.device, torch.device('cpu'))
dummy_eval
dummy_eval (m, size=(64, 64))
Evaluate m on a dummy input of a certain size
model_sizes
model_sizes (m, size=(64, 64))
Pass a dummy input through the model m to get the various sizes of activations.
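For convolutional models, the sizes model_sizes reports follow the usual convolution output-size formula, out = floor((n + 2*padding - dilation*(ks - 1) - 1) / stride) + 1. The sketch below applies it by hand to the example that follows, assuming ConvLayer's default kernel size of 3 with padding 1 and a dummy batch size of 1; conv_out and sizes are hypothetical helpers, not fastai functions.

```python
from typing import List, Sequence, Tuple


def conv_out(n: int, ks: int = 3, stride: int = 1, padding: int = 1, dilation: int = 1) -> int:
    # Standard convolution output-size formula (see the PyTorch Conv2d documentation)
    return (n + 2 * padding - dilation * (ks - 1) - 1) // stride + 1


def sizes(layers: Sequence[Tuple[int, int]], hw: Tuple[int, int] = (64, 64), bs: int = 1) -> List[List[int]]:
    """layers: (out_channels, stride) per conv; assumes ks=3, padding=1 for all of them."""
    h, w = hw
    out = []
    for nf, stride in layers:
        h, w = conv_out(h, stride=stride), conv_out(w, stride=stride)
        out.append([bs, nf, h, w])
    return out


print(sizes([(16, 1), (32, 2), (32, 1)]))
# [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]]
```

Only the stride-2 layer halves the spatial size; with padding 1 and a 3x3 kernel, stride-1 layers keep it unchanged.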
m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))
test_eq(model_sizes(m), [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]])
num_features_model
num_features_model (m)
Return the number of output features for m.
m = nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))
test_eq(num_features_model(m), 3)
m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))
test_eq(num_features_model(m), 32)
To make hooks easy to use, we wrapped a version in a Callback where you only have to implement a hook function (plus any other elements you might need).
has_params
has_params (m)
Check if m has at least one parameter
assert has_params(nn.Linear(3,4))
assert has_params(nn.LSTM(4,5,2))
assert not has_params(nn.ReLU())
HookCallback
HookCallback (modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, include_paramless=False, hook=None)
Callback that can be used to register hooks on modules
You can either subclass and implement a hook function (along with any events you want) or pass a hook function when initializing. Such a function needs to take three arguments: a layer, its input and its output (for a backward hook, input means the gradient with respect to the inputs and output the gradient with respect to the output), and it can either modify them or update the state according to them.
If not provided, modules defaults to the layers of self.model that have a weight attribute (to also include layers without one, e.g. ReLU or Flatten, set include_paramless=True). Depending on remove_end, the hooks are properly removed at the end of training (or in case of error). is_forward, detach and cpu are passed to Hooks.
The function called at each forward (or backward) pass is self.hook and must be implemented when subclassing this callback.
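The lifecycle described above (pick the modules, register hooks in before_fit, remove them in after_fit) can be sketched without fastai. ToyHookCallback, ToyModule and fit are hypothetical; the n_params > 0 filter stands in for the default module selection, and the try/finally in fit is this sketch's way of getting the "removed even in case of error" behavior, not necessarily how fastai's training loop achieves it.

```python
from typing import Any, Callable, List, Optional


class ToyModule:
    """Hypothetical layer: identity forward pass that calls its hooks."""

    def __init__(self, name: str, n_params: int) -> None:
        self.name, self.n_params = name, n_params
        self.hooks: List[Callable] = []

    def __call__(self, x: Any) -> Any:
        for fn in list(self.hooks):
            fn(self, (x,), x)
        return x


class ToyHookCallback:
    """Hypothetical sketch of the HookCallback lifecycle (not fastai's code)."""

    def __init__(self, modules: Optional[List[ToyModule]] = None,
                 remove_end: bool = True, include_paramless: bool = False) -> None:
        self.modules, self.remove_end = modules, remove_end
        self.include_paramless = include_paramless
        self.stored: List[Any] = []

    def hook(self, m: ToyModule, i: tuple, o: Any) -> Any:
        # What a subclass would override
        return (m.name, o)

    def _hook_fn(self, m: ToyModule, i: tuple, o: Any) -> None:
        self.stored.append(self.hook(m, i, o))

    def before_fit(self, model: List[ToyModule]) -> None:
        if self.modules is None:  # default: only modules that have parameters
            self.modules = [m for m in model if self.include_paramless or m.n_params > 0]
        for m in self.modules:
            m.hooks.append(self._hook_fn)

    def after_fit(self) -> None:
        if self.remove_end:
            for m in self.modules or []:
                if self._hook_fn in m.hooks:
                    m.hooks.remove(self._hook_fn)


def fit(model: List[ToyModule], cb: ToyHookCallback, batches: List[Any]) -> None:
    """Hypothetical training loop; try/finally makes sure after_fit runs even on error."""
    cb.before_fit(model)
    try:
        for x in batches:
            for m in model:
                x = m(x)
    finally:
        cb.after_fit()


model = [ToyModule("linear1", 6), ToyModule("relu", 0), ToyModule("linear2", 4)]
cb = ToyHookCallback()
fit(model, cb, [1.0, 2.0])
print([name for name, _ in cb.stored])  # ['linear1', 'linear2', 'linear1', 'linear2']
```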
class TstCallback(HookCallback):
    def hook(self, m, i, o): return o
    def after_batch(self): test_eq(self.hooks.stored[0], self.pred)

learn = synth_learner(n_trn=5, cbs = TstCallback())
learn.fit(1)
[0, 6.587433815002441, 5.402360916137695, '00:00']
UserWarning: You are shadowing an attribute (modules) that exists in the learner. Use `self.learn.modules` to avoid this
class TstCallback(HookCallback):
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=False):
        super().__init__(modules, None, remove_end, False, detach, cpu)
    def hook(self, m, i, o): return o
    def after_batch(self):
        if self.training:
            test_eq(self.hooks.stored[0][0], 2*(self.pred-self.y)/self.pred.shape[0])

learn = synth_learner(n_trn=5, cbs = TstCallback())
learn.fit(1)
[0, 8.743090629577637, 10.072294235229492, '00:00']
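The value expected in after_batch above is the gradient of the loss with respect to the predictions. Assuming, as the pred.shape[0] divisor suggests, that pred and y have shape (N, 1) and that the loss is the mean squared error (the learner summary further down reports FlattenedLoss of MSELoss()), the derivation is one line:

```latex
\mathcal{L} = \frac{1}{N}\sum_{k=1}^{N}\left(\hat{y}_k - y_k\right)^2
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial \hat{y}_k} = \frac{2\left(\hat{y}_k - y_k\right)}{N}
```

With $\hat{y}$ = self.pred and $N$ = self.pred.shape[0], this is exactly 2*(self.pred-self.y)/self.pred.shape[0].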
HookCallback.before_fit
HookCallback.before_fit ()
Register the Hooks on self.modules.
HookCallback.after_fit
HookCallback.after_fit ()
Remove the Hooks.
Model summary
total_params
total_params (m)
Give the number of parameters of a module and if it’s trainable or not
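The expected counts in the tests below follow from the standard parameter layouts of these PyTorch layers. A pure-Python tally that reproduces them (linear_params, conv2d_params, batchnorm_params and lstm_params are hypothetical helpers, not fastai functions), including the 251 parameters that learn.summary() reports for the small _m() model used later:

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    # weight is (n_out, n_in), plus one bias per output
    return n_in * n_out + (n_out if bias else 0)


def conv2d_params(c_in: int, c_out: int, ks: int, bias: bool = True) -> int:
    # weight is (c_out, c_in, ks, ks), plus one bias per output channel
    return c_in * c_out * ks * ks + (c_out if bias else 0)


def batchnorm_params(n: int, affine: bool = True) -> int:
    # weight and bias per channel; running mean/var are buffers, not parameters
    return 2 * n if affine else 0


def lstm_params(n_in: int, n_hid: int, n_layers: int) -> int:
    total = 0
    for layer in range(n_layers):
        layer_in = n_in if layer == 0 else n_hid
        # 4 gates, each with input-to-hidden and hidden-to-hidden weights and two bias vectors
        total += 4 * (layer_in * n_hid + n_hid * n_hid + 2 * n_hid)
    return total


# The _m() model: Linear(1,50) -> ReLU -> BatchNorm1d(50) -> Linear(50,1)
summary_total = linear_params(1, 50) + batchnorm_params(50) + linear_params(50, 1)
print(summary_total)  # 251
```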
test_eq(total_params(nn.Linear(10,32)), (32*10+32,True))
test_eq(total_params(nn.Linear(10,32, bias=False)), (32*10,True))
test_eq(total_params(nn.BatchNorm2d(20)), (20*2, True))
test_eq(total_params(nn.BatchNorm2d(20, affine=False)), (0,False))
test_eq(total_params(nn.Conv2d(16, 32, 3)), (16*32*3*3 + 32, True))
test_eq(total_params(nn.Conv2d(16, 32, 3, bias=False)), (16*32*3*3, True))
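# First LSTM layer: input-to-hidden is 20->10; the other three weight matrices are 10->10.
# Each weight matrix is paired with one bias of size 10, and everything is *4 for the four gates.
test_eq(total_params(nn.LSTM(20, 10, 2)), (4 * (20*10 + 10) + 3 * 4 * (10*10 + 10), True))
layer_info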
layer_info (learn, *xb)
Return layer infos of model on xb (only supports batch-first inputs)
The output of _track is expected to be a tuple of the module name, the number of parameters, whether it is trainable, the output shape, and whether the layer leaves the shape unchanged (same). There are three potential groups of layers that can show up:
- A non-activation layer (Linear, Conv, etc)
- An activation layer
- A pooling layer
Depending on the group, only part of the output is actually returned; the remaining fields are ''. For non-activation layers everything is returned. Activation layers only return a name, the shape and the same flag (True for the ReLU below, whose output shape matches its input). Pooling layers return the name, the new shape, and False for same.
def _m(): return nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))
sample_input = torch.randn((16, 1))
test_eq(layer_info(synth_learner(model=_m()), sample_input), [
    ('Linear', 100, True, [1, 50], False),
    ('ReLU', '', '', [1,50], True),
    ('BatchNorm1d', 100, True, [1, 50], True),
    ('Linear', 51, True, [1, 1], False)
])
module_summary
module_summary (learn, *xb)
Print a summary of model using xb
Learner.summary
Learner.summary ()
Print a summary of the model, optimizer and loss function.
learn = synth_learner(model=_m())
learn.summary()
Sequential (Input shape: 16 x 1)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     16 x 50
Linear                                    100        True
ReLU
BatchNorm1d                               100        True
____________________________________________________________________________
                     16 x 1
Linear                                    51         True
____________________________________________________________________________

Total params: 251
Total trainable params: 251
Total non-trainable params: 0

Optimizer used: functools.partial(<function SGD at 0x78dacd98c7c0>, mom=0.9)
Loss function: FlattenedLoss of MSELoss()

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
Activation graphs
ActivationStats
ActivationStats (with_hist=False, modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, include_paramless=False, hook=None)
Callback that records the mean and std of activations.
learn = synth_learner(n_trn=5, cbs = ActivationStats(every=4))
learn.fit(1)
[0, 7.943600177764893, 8.535039901733398, '00:00']
learn.activation_stats.stats
(#2) [[{'mean': 1.3028467893600464, 'std': 0.32002925872802734, 'near_zero': 0.0}],[{'mean': 1.3026641607284546, 'std': 0.29966112971305847, 'near_zero': 0.0}]]
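Each entry of stats corresponds to one recorded training batch (here one every 4 batches) and holds one dict per hooked module, with the mean and standard deviation of that module's output plus a near_zero value giving the proportion of activations close to zero.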
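Each per-module dict is a simple summary of one batch of outputs. A pure-Python sketch, where activation_summary is a hypothetical helper and the 0.05 cut-off for near_zero is an assumption about what counts as close to zero; std uses the n-1 (sample) denominator, matching torch.std's default.

```python
import statistics
from typing import Dict, List


def activation_summary(acts: List[float], near_zero_thresh: float = 0.05) -> Dict[str, float]:
    """Hypothetical helper summarizing one batch of a module's outputs (flattened)."""
    return {
        "mean": statistics.fmean(acts),
        # Sample standard deviation (n-1 denominator), like torch.std's default
        "std": statistics.stdev(acts),
        # Share of values at or below the assumed threshold (meant for non-negative, post-ReLU outputs)
        "near_zero": sum(a <= near_zero_thresh for a in acts) / len(acts),
    }


# One recorded batch, one hooked module -> a list holding one dict, as in the stats above
batch_stats = [activation_summary([0.0, 0.02, 1.0, 3.0])]
print(batch_stats[0]["near_zero"])  # 0.5
```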
def test_every(n_tr, every):
    "create a learner, fit, then check number of stats collected"
    learn = synth_learner(n_trn=n_tr, cbs=ActivationStats(every=every))
    learn.fit(1)
    expected_stats_len = math.ceil(n_tr / every)
    test_eq(expected_stats_len, len(learn.activation_stats.stats))

for n_tr in [11, 12, 13]:
    test_every(n_tr, 4)
    test_every(n_tr, 1)
[0, 7.132676601409912, 6.505333423614502, '00:00']
[0, 30.60495376586914, 29.395254135131836, '00:00']
[0, 14.507355690002441, 10.65038013458252, '00:00']
[0, 12.470440864562988, 7.216660499572754, '00:00']
[0, 30.247482299804688, 25.165172576904297, '00:00']
[0, 6.672229290008545, 5.598482131958008, '00:00']
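The ceil(n_tr / every) expectation holds if stats are recorded on the training iterations whose index is a multiple of every, counting from 0. That scheduling rule is an assumption consistent with the test (and with the 2 entries recorded for n_trn=5, every=4 earlier), sketched here with a hypothetical recorded_iters helper:

```python
import math
from typing import List


def recorded_iters(n_batches: int, every: int) -> List[int]:
    """Hypothetical schedule: record on training iterations 0, every, 2*every, ..."""
    return [i for i in range(n_batches) if i % every == 0]


for n_tr in (11, 12, 13):
    for every in (4, 1):
        # Same expectation as test_every above
        assert len(recorded_iters(n_tr, every)) == math.ceil(n_tr / every)

print(recorded_iters(5, 4))  # [0, 4]
```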