Layers
Basic manipulations and resize
module
module (*flds, **defaults)

Decorator to create an nn.Module using f as forward method
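A minimal sketch of how the decorator can be used (the same pattern Identity and Lambda below are built with): names passed in *flds become stored attributes and the decorated function becomes the forward. TimesScale is a hypothetical example layer, not part of fastai.

@module('scale')
def TimesScale(self, x):
    "Hypothetical example layer: multiply the input by a stored scale"
    return x * self.scale

tst = TimesScale(3.)
test_eq(tst(torch.ones(2)), torch.ones(2) * 3)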
Identity
Identity ()
Do nothing at all

test_eq(Identity()(1), 1)
Lambda
Lambda (func)
An easy way to create a pytorch layer for a simple func
def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)
tst
Lambda(func=<function _add2>)
PartialLambda
PartialLambda (func)
Layer that applies partial(func, **kwargs)
def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)
Flatten
Flatten (full=False)
Flatten x to a single dimension, e.g. at end of a model. full for rank-1 tensor

tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])
ToTensorBase
ToTensorBase (tensor_cls=<class 'fastai.torch_core.TensorBase'>)
Convert x to TensorBase class
ttb = ToTensorBase()
timg = TensorImage(torch.rand(1,3,32,32))
test_eq(type(ttb(timg)), TensorBase)
View
View (*size)
Reshape x to size

tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])
ResizeBatch
ResizeBatch (*size)
Reshape x to size, keeping batch dim the same size

tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])
Debugger
Debugger ()
A module to debug inside a model.
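A hedged usage sketch: dropping a Debugger into an nn.Sequential pauses execution in the Python debugger when a batch reaches that point of the forward pass.

model = nn.Sequential(
    nn.Linear(10, 20),
    Debugger(),   # opens a pdb prompt here during forward
    nn.Linear(20, 2))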
sigmoid_range
sigmoid_range (x, low, high)
Sigmoid function with range (low, high)
test = tensor([-10.,0.,10.])
assert torch.allclose(sigmoid_range(test, -1, 2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, 2, 4), tensor([2., 3., 4.]), atol=1e-4, rtol=1e-4)
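The mapping itself is just a scaled and shifted sigmoid; a sketch of the formula, consistent with the tests above:

# sigmoid_range(x, low, high) == torch.sigmoid(x) * (high - low) + low
test_close(sigmoid_range(test, -1, 2), torch.sigmoid(test) * 3 - 1, eps=1e-6)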
SigmoidRange
SigmoidRange (low, high)
Sigmoid module with range (low, high)
tst = SigmoidRange(-1, 2)
assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
Pooling layers
AdaptiveConcatPool1d
AdaptiveConcatPool1d (size=None)
Layer that concats AdaptiveAvgPool1d and AdaptiveMaxPool1d
AdaptiveConcatPool2d
AdaptiveConcatPool2d (size=None)
Layer that concats AdaptiveAvgPool2d and AdaptiveMaxPool2d

If the input is bs x nf x h x h, the output will be bs x 2*nf x 1 x 1 if no size is passed, or bs x 2*nf x size x size otherwise.

tst = AdaptiveConcatPool2d()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,10,1,1])
max1 = torch.max(x, dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])
PoolType
PoolType ()
Initialize self. See help(type(self)) for accurate signature.
adaptive_pool
adaptive_pool (pool_type)
PoolFlatten
PoolFlatten (pool_type='Avg')
Combine nn.AdaptiveAvgPool2d and Flatten.

tst = PoolFlatten()
test_eq(tst(x).shape, [10,5])
test_eq(tst(x), x.mean(dim=[2,3]))
BatchNorm layers
BatchNorm
BatchNorm (nf, ndim=2, norm_type=<NormType.Batch: 1>, eps:float=1e-05, momentum:Optional[float]=0.1, affine:bool=True, track_running_stats:bool=True, device=None, dtype=None)
BatchNorm layer with nf features and ndim initialized depending on norm_type.
InstanceNorm
InstanceNorm (nf, ndim=2, norm_type=<NormType.Instance: 5>, affine=True, eps:float=1e-05, momentum:float=0.1, track_running_stats:bool=False, device=None, dtype=None)
InstanceNorm layer with nf features and ndim initialized depending on norm_type.

kwargs are passed to nn.BatchNorm and can be eps, momentum, affine and track_running_stats.
tst = BatchNorm(15)
assert isinstance(tst, nn.BatchNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = BatchNorm(15, norm_type=NormType.BatchZero)
test_eq(tst.weight, torch.zeros(15))
tst = BatchNorm(15, ndim=1)
assert isinstance(tst, nn.BatchNorm1d)
tst = BatchNorm(15, ndim=3)
assert isinstance(tst, nn.BatchNorm3d)

tst = InstanceNorm(15)
assert isinstance(tst, nn.InstanceNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = InstanceNorm(15, norm_type=NormType.InstanceZero)
test_eq(tst.weight, torch.zeros(15))
tst = InstanceNorm(15, ndim=1)
assert isinstance(tst, nn.InstanceNorm1d)
tst = InstanceNorm(15, ndim=3)
assert isinstance(tst, nn.InstanceNorm3d)
If affine is false the weight should be None

test_eq(BatchNorm(15, affine=False).weight, None)
test_eq(InstanceNorm(15, affine=False).weight, None)
BatchNorm1dFlat
BatchNorm1dFlat (num_features:int, eps:float=1e-05, momentum:Optional[float]=0.1, affine:bool=True, track_running_stats:bool=True, device=None, dtype=None)
nn.BatchNorm1d, but first flattens leading dimensions

tst = BatchNorm1dFlat(15)
x = torch.randn(32, 64, 15)
y = tst(x)
mean = x.mean(dim=[0,1])
test_close(tst.running_mean, 0*0.9 + mean*0.1)
var = (x-mean).pow(2).mean(dim=[0,1])
test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)
test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)
LinBnDrop
LinBnDrop (n_in, n_out, bn=True, p=0.0, act=None, lin_first=False)
Module grouping BatchNorm1d, Dropout and Linear layers

The BatchNorm layer is skipped if bn=False, as is the dropout if p=0. Optionally, you can add an activation after the linear layer with act.

tst = LinBnDrop(10, 20)
mods = list(tst.children())
test_eq(len(mods), 2)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

tst = LinBnDrop(10, 20, p=0.1)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
assert isinstance(mods[2], nn.Linear)

tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.Linear)
assert isinstance(mods[1], nn.ReLU)
assert isinstance(mods[2], nn.BatchNorm1d)

tst = LinBnDrop(10, 20, bn=False)
mods = list(tst.children())
test_eq(len(mods), 1)
assert isinstance(mods[0], nn.Linear)
Inits
sigmoid
sigmoid (input, eps=1e-07)
Same as torch.sigmoid, plus clamping to (eps,1-eps)
sigmoid_
sigmoid_ (input, eps=1e-07)
Same as torch.sigmoid_, plus clamping to (eps,1-eps)
vleaky_relu
vleaky_relu (input, inplace=True)
F.leaky_relu with 0.3 slope
init_default
init_default (m, func=<function kaiming_normal_>)
Initialize m weights with func and set bias to 0.
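A small sketch of what this does in practice (the Conv2d here is chosen just for illustration):

conv = init_default(nn.Conv2d(3, 16, 3), nn.init.kaiming_normal_)
test_eq(conv.bias, torch.zeros(16))   # bias zeroed, weights re-initialized with the given func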
init_linear
init_linear (m, act_func=None, init='auto', bias_std=0.01)
Convolutions
ConvLayer
ConvLayer (ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
Create a sequence of convolutional (ni to nf), ReLU (if use_activ) and norm_type layers.

The convolution uses ks (kernel size), stride, padding and bias. padding will default to the appropriate value ((ks-1)//2 if it's not a transposed conv) and bias will default to True if norm_type is Spectral or Weight, and to False if it's Batch or BatchZero. Note that if you don't want any normalization, you should pass norm_type=None.

This defines a conv layer with ndim (1, 2 or 3) that will be a ConvTranspose if transpose=True. act_cls is the class of the activation function to use (instantiated inside). Pass act=None if you don't want an activation function. If you quickly want to change your default activation, you can change the value of defaults.activation.

init is used to initialize the weights (the biases are initialized to 0) and xtra is an optional layer to add at the end.
tst = ConvLayer(16, 32)
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq(mods[1].weight, torch.ones(32))
test_eq(mods[0].padding, (1,1))
x = torch.randn(64, 16, 8, 8)#.cuda()
#Padding is selected to make the shape the same if stride=1
test_eq(tst(x).shape, [64,32,8,8])
#Padding is selected to make the shape half if stride=2
tst = ConvLayer(16, 32, stride=2)
test_eq(tst(x).shape, [64,32,4,4])
#But you can always pass your own padding if you want
tst = ConvLayer(16, 32, padding=0)
test_eq(tst(x).shape, [64,32,6,6])
#No bias by default for Batch NormType
assert mods[0].bias is None
#But can be overridden with `bias=True`
tst = ConvLayer(16, 32, bias=True)
assert first(tst.children()).bias is not None
#For no norm, or spectral/weight, bias is True by default
for t in [None, NormType.Spectral, NormType.Weight]:
    tst = ConvLayer(16, 32, norm_type=t)
    assert first(tst.children()).bias is not None
#Various ndim/transpose
tst = ConvLayer(16, 32, ndim=3)
assert isinstance(list(tst.children())[0], nn.Conv3d)
tst = ConvLayer(16, 32, ndim=1, transpose=True)
assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)
#No activation/leaky
tst = ConvLayer(16, 32, ndim=3, act_cls=None)
mods = list(tst.children())
test_eq(len(mods), 2)
tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[2], nn.LeakyReLU)
# #export
# def linear(in_features, out_features, bias=True, act_cls=None, init='auto'):
# "Linear layer followed by optional activation, with optional auto-init"
# res = nn.Linear(in_features, out_features, bias=bias)
# if act_cls: act_cls = act_cls()
# init_linear(res, act_cls, init=init)
# if act_cls: res = nn.Sequential(res, act_cls)
# return res
# #export
# @delegates(ConvLayer)
# def conv1d(ni, nf, ks, stride=1, ndim=1, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# #export
# @delegates(ConvLayer)
# def conv2d(ni, nf, ks, stride=1, ndim=2, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# #export
# @delegates(ConvLayer)
# def conv3d(ni, nf, ks, stride=1, ndim=3, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
AdaptiveAvgPool
AdaptiveAvgPool (sz=1, ndim=2)
nn.AdaptiveAvgPool layer for ndim
MaxPool
MaxPool (ks=2, stride=None, padding=0, ndim=2, ceil_mode=False)
nn.MaxPool layer for ndim
AvgPool
AvgPool (ks=2, stride=None, padding=0, ndim=2, ceil_mode=False)
nn.AvgPool layer for ndim
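A hedged sanity check of the idea (assuming each factory returns the matching nn pooling layer for the given ndim):

assert isinstance(AdaptiveAvgPool(sz=1, ndim=3), nn.AdaptiveAvgPool3d)
assert isinstance(MaxPool(ks=2, ndim=1), nn.MaxPool1d)
assert isinstance(AvgPool(ks=2, ndim=2), nn.AvgPool2d)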
Embeddings
trunc_normal_
trunc_normal_ (x, mean=0.0, std=1.0)
Truncated normal initialization (approximation)
Embedding
Embedding (ni, nf, std=0.01)
Embedding layer with truncated normal initialization
Truncated normal initialization bounds the distribution to avoid large values. For a given standard deviation std, the bounds are roughly -2*std, 2*std.

std = 0.02
tst = Embedding(10, 30, std)
assert tst.weight.min() > -2*std
assert tst.weight.max() < 2*std
test_close(tst.weight.mean(), 0, 1e-2)
test_close(tst.weight.std(), std, 0.1)
Self attention
SelfAttention
SelfAttention (n_channels)
Self attention layer for n_channels.

Self-attention layer as introduced in Self-Attention Generative Adversarial Networks.

Initially, no change is done to the input. This is controlled by a trainable parameter named gamma as we return x + gamma * out.

tst = SelfAttention(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x), x)
Then, during training, gamma will probably change since it's a trainable parameter. Let's see what happens when it gets a nonzero value.

tst.gamma.data.fill_(1.)
y = tst(x)
test_eq(y.shape, [32,16,8,8])
The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). As in the paper, we call the results of those multiplications f, g and h.

q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data
test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])
f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])
test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])
The key part of the attention layer is to compute attention weights for each location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We take the product of f and the transpose of g (to get something of size bs by 64 by 64) then apply a softmax on the first dimension (to get the positive numbers that sum to 1). The result can then be multiplied with h transposed to get an output of size bs by channels by 64, which can then be viewed as an output the same size as the original input.

The final result is then x + gamma * out as we saw before.

beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)
test_eq(beta.shape, [32, 64, 64])
out = torch.bmm(h.transpose(1,2), beta)
test_eq(out.shape, [32, 16, 64])
test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)
PooledSelfAttention2d
PooledSelfAttention2d (n_channels)
Pooled self attention layer for 2d.
Self-attention layer used in the Big GAN paper.
It uses the same attention as in SelfAttention but adds a max pooling of stride 2 before computing the matrices g and h: the attention is computed on the 2x2 max-pooled feature map, not the whole feature map. There is also a final matrix product added to the output at the end, before returning gamma * out + x.
SimpleSelfAttention
SimpleSelfAttention (n_in:int, ks=1, sym=False)
Same as nn.Module, but no need for subclasses to call super().__init__
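A hedged usage sketch: like SelfAttention above, the layer is applied to a (bs, n_in, h, w) feature map and keeps its shape.

sa = SimpleSelfAttention(16, ks=1)
x = torch.randn(8, 16, 4, 4)
test_eq(sa(x).shape, x.shape)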
PixelShuffle
PixelShuffle was introduced in this article to avoid checkerboard artifacts when upsampling images. If we want an output with ch_out filters, we use a convolution with ch_out * (r**2) filters, where r is the upsampling factor. Those filters are then reorganized so that each group of r**2 channels fills an r by r window of the output (see the figure in the article).
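As an illustration with plain PyTorch (not fastai-specific code): a convolution producing ch_out * r**2 channels followed by nn.PixelShuffle(r) upsamples the feature map by r.

r, ch_in, ch_out = 2, 32, 16
m = nn.Sequential(nn.Conv2d(ch_in, ch_out*r**2, 3, padding=1), nn.PixelShuffle(r))
x = torch.randn(8, ch_in, 8, 8)
test_eq(m(x).shape, [8, ch_out, 16, 16])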
icnr_init
icnr_init (x, scale=2, init=<function kaiming_normal_>)
ICNR init of x, with scale and init function

ICNR init was introduced in this article. It suggests initializing the convolution that will be used in PixelShuffle so that each of the r**2 channels gets the same weight (so that, in the figure from the article, the 9 colors in a 3 by 3 window are initially the same).

This is done on the first dimension because PyTorch stores the weights of a convolutional layer in this format: ch_out x ch_in x ks x ks.

tst = torch.randn(16*4, 32, 1, 1)
tst = icnr_init(tst)
for i in range(0,16*4,4):
    test_eq(tst[i],tst[i+1])
    test_eq(tst[i],tst[i+2])
    test_eq(tst[i],tst[i+3])
PixelShuffle_ICNR
PixelShuffle_ICNR (ni, nf=None, scale=2, blur=False, norm_type=<NormType.Weight: 3>, act_cls=<class 'torch.nn.modules.activation.ReLU'>)
Upsample by scale from ni filters to nf (default ni), using nn.PixelShuffle.

The convolutional layer is initialized with icnr_init and passed act_cls and norm_type (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments).

The blur option comes from Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts where the authors add a little bit of blur to completely get rid of checkerboard artifacts.
psfl = PixelShuffle_ICNR(16)
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])
psfl = PixelShuffle_ICNR(16, norm_type=None)
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])
psfl = PixelShuffle_ICNR(16, norm_type=NormType.Spectral)
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])
Sequential extensions
sequential
sequential (*args)
Create an nn.Sequential
, wrapping items with Lambda
if needed
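A hedged sketch: plain callables are wrapped in Lambda so they can live inside the resulting nn.Sequential. The _double helper is just for illustration.

def _double(x): return x*2
m = sequential(nn.Identity(), _double)
assert isinstance(m, nn.Sequential)
assert isinstance(m[1], Lambda)
test_eq(m(torch.ones(3)), torch.ones(3)*2)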
SequentialEx
SequentialEx (*layers)
Like nn.Sequential, but with ModuleList semantics, and can access module input

This is useful for writing layers that need to remember the input (like a resnet block) in a sequential way.
MergeLayer
MergeLayer (dense:bool=False)
Merge a shortcut with the result of the module by adding them or concatenating them if dense=True.
res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))
res_block.append(MergeLayer()) # just to test append - normally it would be in init params
x = torch.randn(32, 16, 8, 8)
y = res_block(x)
test_eq(y.shape, [32, 16, 8, 8])
test_eq(y, x + res_block[1](res_block[0](x)))

x = TensorBase(torch.randn(32, 16, 8, 8))
y = res_block(x)
test_is(y.orig, None)
Concat
Equivalent to keras.layers.Concatenate, it will concat the outputs of a ModuleList over a given dimension (default the filter dimension)
Cat
Cat (layers, dim=1)
Concatenate layers outputs over a given dim
layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)]
x = torch.rand(1,2,8,8)
cat = Cat(layers)
test_eq(cat(x).shape, [1,12,8,8])
test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))
Ready-to-go models
SimpleCNN
SimpleCNN (filters, kernel_szs=None, strides=None, bn=True)
Create a simple CNN with filters.

The model is a succession of convolutional layers from (filters[0],filters[1]) to (filters[n-2],filters[n-1]) (if n is the length of the filters list) followed by a PoolFlatten. kernel_szs and strides default to a list of 3s and a list of 2s. If bn=True the convolutional layers are successions of conv-relu-batchnorm, otherwise conv-relu.
tst = SimpleCNN([8,16,32])
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq([[m[0].in_channels, m[0].out_channels] for m in mods[:2]], [[8,16], [16,32]])
Test kernel sizes
tst = SimpleCNN([8,16,32], kernel_szs=[1,3])
mods = list(tst.children())
test_eq([m[0].kernel_size for m in mods[:2]], [(1,1), (3,3)])
Test strides
tst = SimpleCNN([8,16,32], strides=[1,2])
mods = list(tst.children())
test_eq([m[0].stride for m in mods[:2]], [(1,1),(2,2)])
ProdLayer
ProdLayer ()
Merge a shortcut with the result of the module by multiplying them.
SEModule
SEModule (ch, reduction, act_cls=<class 'torch.nn.modules.activation.ReLU'>)
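A hedged shape check: a squeeze-and-excitation module rescales channels, so the output shape should match the input.

se = SEModule(32, reduction=16)
x = torch.randn(8, 32, 4, 4)
test_eq(se(x).shape, x.shape)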
ResBlock
ResBlock (expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1, sa=False, sym=False, norm_type=<NormType.Batch: 1>, act_cls=<class 'torch.nn.modules.activation.ReLU'>, ndim=2, ks=3, pool=<function AvgPool>, pool_first=True, padding=None, bias=None, bn_1st=True, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, padding_mode:str='zeros', device=None, dtype=None)
Resnet block from ni to nh with stride

This is a resnet block (normal or bottleneck depending on expansion, 1 for the normal block and 4 for the traditional bottleneck) that implements the tweaks from Bag of Tricks for Image Classification with Convolutional Neural Networks. In particular, the last batchnorm layer (if that is the selected norm_type) is initialized with a weight (or gamma) of zero to facilitate the flow from the beginning to the end of the network. It also implements optional Squeeze and Excitation and grouped convs for ResNeXT and similar models (use dw=True for depthwise convs).

The kwargs are passed to ConvLayer along with norm_type.
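A hedged shape sketch (assuming the usual xresnet conventions, where with expansion=1 the block outputs nf channels and stride controls the spatial downsampling):

blk = ResBlock(1, 16, 16)
x = torch.randn(8, 16, 8, 8)
test_eq(blk(x).shape, [8, 16, 8, 8])
blk = ResBlock(1, 16, 32, stride=2)
test_eq(blk(x).shape, [8, 32, 4, 4])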
SEBlock
SEBlock (expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs)
SEResNeXtBlock
SEResNeXtBlock (expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs)
SeparableBlock
SeparableBlock (expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs)
Time Distributed Layer
Equivalent to the Keras TimeDistributed layer; enables applying a pytorch Module over an axis.
bs, seq_len = 2, 5
x, y = torch.rand(bs,seq_len,3,2,2), torch.rand(bs,seq_len,3,2,2)

tconv = TimeDistributed(nn.Conv2d(3,4,1))
test_eq(tconv(x).shape, (2,5,4,2,2))
tconv.low_mem=True
test_eq(tconv(x).shape, (2,5,4,2,2))
class Mod(Module):
    def __init__(self):
        self.conv = nn.Conv2d(3,4,1)
    def forward(self, x, y):
        return self.conv(x) + self.conv(y)
tmod = TimeDistributed(Mod())
out = tmod(x,y)
test_eq(out.shape, (2,5,4,2,2))
tmod.low_mem=True
out_low_mem = tmod(x,y)
test_eq(out_low_mem.shape, (2,5,4,2,2))
test_eq(out, out_low_mem)
class Mod2(Module):
    def __init__(self):
        self.conv = nn.Conv2d(3,4,1)
    def forward(self, x, y):
        return self.conv(x), self.conv(y)
tmod2 = TimeDistributed(Mod2())
out = tmod2(x,y)
test_eq(len(out), 2)
test_eq(out[0].shape, (2,5,4,2,2))
tmod2.low_mem=True
out_low_mem = tmod2(x,y)
test_eq(out_low_mem[0].shape, (2,5,4,2,2))
test_eq(out, out_low_mem)
TimeDistributed
TimeDistributed (module, low_mem=False, tdim=1)
Applies module over tdim identically for each step; use low_mem to compute one step at a time.

This module is equivalent to the Keras TimeDistributed layer. This wrapper allows applying a layer to every temporal slice of an input. By default it is assumed the time axis (tdim) is the 1st one (the one after the batch size). A typical usage would be to encode a sequence of images using an image encoder.

The forward function of TimeDistributed supports *args and **kwargs, but only args will be split and passed to the underlying module independently for each timestep; kwargs will be passed as they are. This is useful when you have a module that takes multiple arguments as inputs: this way, you can put all tensors that need splitting in args and the other arguments that don't need splitting in kwargs.

This module is heavy on memory, as it will try to pass multiple timesteps at the same time on the batch dimension; if you get out of memory errors, try first reducing your batch size by the number of timesteps.
from fastai.vision.all import *
encoder = create_body(resnet18())
A resnet18 will encode a feature map of 512 channels. Height and Width will be divided by 32.
time_resnet = TimeDistributed(encoder)

A synthetic batch of 2 image sequences of length 5, with shape (bs, seq_len, ch, w, h):

image_sequence = torch.rand(2, 5, 3, 64, 64)
time_resnet(image_sequence).shape
torch.Size([2, 5, 512, 2, 2])
This way, one can encode a sequence of images in feature space. There is also a low_mem_forward that will pass images one at a time to reduce GPU memory consumption.
time_resnet.low_mem_forward(image_sequence).shape
torch.Size([2, 5, 512, 2, 2])
Swish and Mish
swish
swish (x, inplace=False)
SwishJit
SwishJit ()
Same as nn.Module, but no need for subclasses to call super().__init__
MishJitAutoFn
MishJitAutoFn (*args, **kwargs)
Base class to create custom autograd.Function.

To create a custom autograd.Function, subclass this class and implement the forward and backward static methods. Then, to use your custom op in the forward pass, call the class method [apply](https://docs.fast.ai/torch_core.html#apply). Do not call forward directly.

To ensure correctness and best performance, make sure you are calling the correct methods on ctx and validating your backward function using torch.autograd.gradcheck.

See the PyTorch notes on extending autograd for more details on how to use this class.
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Exp(Function):
>>> @staticmethod
>>> def forward(ctx, i):
>>> result = i.exp()
>>> ctx.save_for_backward(result)
>>> return result
>>>
>>> @staticmethod
>>> def backward(ctx, grad_output):
>>> result, = ctx.saved_tensors
>>> return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> # xdoctest: +SKIP
>>> output = Exp.apply(input)
mish
mish (x, inplace=False)
MishJit
MishJit ()
Same as nn.Module, but no need for subclasses to call super().__init__
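For reference, the standard definitions these activations follow (shown as a sketch, not the exact fastai implementation): swish(x) = x * sigmoid(x) and mish(x) = x * tanh(softplus(x)).

x = torch.randn(10)
test_close(swish(x), x * torch.sigmoid(x), eps=1e-6)
test_close(mish(x), x * torch.tanh(F.softplus(x)), eps=1e-6)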
Helper functions for submodules
It’s easy to get the list of all parameters of a given model. For when you want all submodules (like linear/conv layers) without forgetting lone parameters, the following class wraps those in fake modules.
ParameterModule
ParameterModule (p)
Register a lone parameter p in a module.
children_and_parameters
children_and_parameters (m)
Return the children of m and its direct parameters not registered in modules.
class TstModule(Module):
    def __init__(self): self.a,self.lin = nn.Parameter(torch.randn(1)),nn.Linear(5,10)

tst = TstModule()
children = children_and_parameters(tst)
test_eq(len(children), 2)
test_eq(children[0], tst.lin)
assert isinstance(children[1], ParameterModule)
test_eq(children[1].val, tst.a)
has_children
has_children (m)
class A(Module): pass
assert not has_children(A())
assert has_children(TstModule())
flatten_model
flatten_model (m)
Return the list of all submodules and parameters of m
tst = nn.Sequential(TstModule(), TstModule())
children = flatten_model(tst)
test_eq(len(children), 4)
assert isinstance(children[1], ParameterModule)
assert isinstance(children[3], ParameterModule)
NoneReduce
NoneReduce (loss_func)
A context manager to evaluate loss_func with none reduce.

x,y = torch.randn(5),torch.randn(5)
loss_fn = nn.MSELoss()
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn.reduction, 'mean')

loss_fn = F.mse_loss
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn, F.mse_loss)
in_channels
in_channels (m)
Return the shape of the first weight layer in m.

test_eq(in_channels(nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))), 5)
test_eq(in_channels(nn.Sequential(nn.AvgPool2d(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(BatchNorm(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(InstanceNorm(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(InstanceNorm(4, affine=False), nn.Conv2d(4,3,3))), 4)
test_fail(lambda : in_channels(nn.Sequential(nn.AvgPool2d(4))))