This provides both a standalone class and a callback for registering and automatically deregistering PyTorch hooks, along with some pre-defined hooks. Hooks can be attached to any
nn.Module, for either the forward or the backward pass.
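All of this builds on PyTorch's own hook API. As a minimal, fastai-free sketch of attaching and later removing a forward hook on an nn.Module (the names here are illustrative):

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 2)   # any nn.Module can be hooked
captured = []

def save_output(module, inp, out):
    # Forward hooks receive (module, input, output).
    captured.append(out.detach())

handle = lin.register_forward_hook(save_output)
_ = lin(torch.randn(3, 4))
handle.remove()  # always deregister when you are done

print(len(captured), captured[0].shape)  # 1 torch.Size([3, 2])
```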
We'll start by looking at the pre-defined hook ActivationStats; then we'll see how to create our own.

ActivationStats saves the layer activations in self.stats for all modules passed to it. By default it will save activations for all modules. For instance:
```python
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
#learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)
learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)
learn.fit(1)
```
stats is a FloatTensor of shape (2, num_modules, num_batches): the first axis holds the mean and the standard deviation, the second indexes the hooked modules, and the third indexes the batches. For the example above, the saved tensor has shape:

```
torch.Size([2, 3, 193])
```

So stats[1][-2] shows the standard deviation (axis0==1) of the second-to-last layer (axis1==-2) for each batch (axis2).
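To make the indexing concrete, here is a small sketch using a dummy tensor with the same layout (the real tensor is built up by ActivationStats during training; the sizes below match the run above):

```python
import torch

num_modules, num_batches = 3, 193
# Stand-in for the saved stats: axis 0 is (mean, stdev),
# axis 1 the module, axis 2 the batch.
stats = torch.randn(2, num_modules, num_batches)

stdev_2nd_last = stats[1][-2]  # stdev (axis0==1) of the second-to-last layer (axis1==-2)
print(stdev_2nd_last.shape)    # torch.Size([193])
```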
Registers and manually deregisters a PyTorch hook. Your hook_func will be called automatically when the forward (or backward, depending on is_forward) pass for your module m is run, and the result of that function is placed in self.stored.
Deregister the hook, if not called already.
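To illustrate the mechanism, a simplified Hook-style wrapper over PyTorch's handle API might look like the following (a sketch only, not fastai's actual implementation):

```python
import torch
import torch.nn as nn

class SimpleHook:
    """Registers hook_func on module m; the result lands in self.stored."""
    def __init__(self, m, hook_func, is_forward=True):
        self.hook_func, self.stored, self.removed = hook_func, None, False
        register = (m.register_forward_hook if is_forward
                    else m.register_full_backward_hook)
        self.handle = register(self._hook)

    def _hook(self, module, inp, out):
        self.stored = self.hook_func(module, inp, out)

    def remove(self):
        # Deregister the hook, if not called already.
        if not self.removed:
            self.handle.remove()
            self.removed = True

lin = nn.Linear(4, 2)
h = SimpleHook(lin, lambda m, i, o: o.detach().mean())
_ = lin(torch.randn(5, 4))
h.remove()
print(h.stored)  # a 0-dim tensor holding the mean activation
```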
Acts as a sequence (i.e. len(hooks) and hooks[i]) and an iterator (i.e. for hook in hooks) over a group of hooks, one for each module in ms, with the ability to remove them all as a group. Use stored to get all hook results. The hook_func and is_forward behavior is the same as for Hook. See the source code for HookCallback for a simple example.
Deregister all hooks created by this class, if not previously called.
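The group behavior can be sketched in plain PyTorch (an illustration of the idea, not fastai's implementation): one forward hook per module, with sequence/iterator access and a single group-wide remove.

```python
import torch
import torch.nn as nn

class SimpleHooks:
    """Forward hooks on each module in ms, removable as a group."""
    def __init__(self, ms, hook_func):
        ms = list(ms)
        self.stored = [None] * len(ms)   # one slot per hooked module
        def make_fn(i):
            def fn(module, inp, out):
                self.stored[i] = hook_func(module, inp, out)
            return fn
        self.handles = [m.register_forward_hook(make_fn(i))
                        for i, m in enumerate(ms)]

    def __len__(self): return len(self.handles)      # sequence behavior
    def __getitem__(self, i): return self.stored[i]
    def __iter__(self): return iter(self.stored)     # iterator behavior

    def remove(self):
        # Deregister all hooks created by this class.
        for h in self.handles:
            h.remove()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooks = SimpleHooks(model, lambda m, i, o: o.shape)
_ = model(torch.randn(3, 4))
hooks.remove()
print(list(hooks))  # [torch.Size([3, 8]), torch.Size([3, 8]), torch.Size([3, 2])]
```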
It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, these sizes depend on the size of the input. This function calculates the layer sizes by passing a minimal dummy tensor of the requested size through the model.
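The technique can be sketched in plain PyTorch (the function name and defaults here are illustrative, not fastai's API): register a forward hook per child module, run a tiny dummy batch through, and collect the output shapes.

```python
import torch
import torch.nn as nn

def layer_sizes(model, size=(64, 64)):
    """Record each child module's output shape using a minimal dummy batch."""
    sizes, handles = [], []
    for m in model.children():
        handles.append(m.register_forward_hook(
            lambda module, inp, out: sizes.append(tuple(out.shape))))
    try:
        with torch.no_grad():
            model(torch.zeros(1, 3, *size))  # batch of 1, 3 channels
    finally:
        for h in handles:
            h.remove()  # always clean up the hooks
    return sizes

cnn = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(16, 32, 3, stride=2, padding=1))
print(layer_sizes(cnn, size=(64, 64)))
# [(1, 16, 32, 32), (1, 16, 32, 32), (1, 32, 16, 16)]
```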
For all modules, uses a callback to automatically register a method self.hook (that you must define in an inherited class) as a hook. This method must have the signature:

```python
def hook(self, m:Model, input:Tensors, output:Tensors)
```

If do_remove, then the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
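The pattern can be sketched outside fastai's training loop (class names here are illustrative): a base class registers the subclass's self.hook on every module, and remove() plays the role of the do_remove cleanup.

```python
import torch
import torch.nn as nn

class SimpleHookCallback:
    """Registers self.hook (defined in a subclass) on each module.
    Illustrative sketch of the pattern, not fastai's HookCallback."""
    def __init__(self, modules):
        self.handles = [m.register_forward_hook(self.hook) for m in modules]

    def hook(self, m, input, output):
        raise NotImplementedError  # subclasses must define this

    def remove(self):
        for h in self.handles:
            h.remove()

class MeanStats(SimpleHookCallback):
    def __init__(self, modules):
        self.stats = []
        super().__init__(modules)

    def hook(self, m, input, output):
        # Matches the required signature: (self, m, input, output).
        self.stats.append(output.detach().mean().item())

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
cb = MeanStats(list(model))
_ = model(torch.randn(3, 4))
cb.remove()
print(len(cb.stats))  # 2
```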