# Identity is expected to return its input unchanged.
test_eq(Identity()(1), 1)
# Named helper wrapped by Lambda below (presumably a named function rather than
# a lambda so that the pickle round-trip works - confirm).
def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
# Round-trip the layer through pickle and check it still computes x+2.
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)
# Bare expression: presumably displays the layer's repr when run as a notebook cell.
tst
# PartialLambda is built with b=5, so calling it on `x` (created in the
# previous cell) is expected to give x+5 instead of the default x+2.
def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)
# Flatten on a (10,5,4) input: the default keeps the first dimension ...
tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
# ... while full=True collapses everything into a single dimension.
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])
# View is given the complete target shape, first dimension included.
tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])
# ResizeBatch is given only the trailing dimensions; the first dimension (10) is kept.
tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])
# Inputs far below zero, at zero and far above zero are expected to map to the
# lower bound, the midpoint and the upper bound of the requested range.
test = tensor([-10.,0.,10.])
_tol = dict(atol=1e-4, rtol=1e-4)
for (low, high), expected in (
    ((-1, 2), [-1.,0.5, 2.]),
    ((-5, -1), [-5.,-3.,-1.]),
    ((2, 4), [2., 3., 4.]),
):
    assert torch.allclose(sigmoid_range(test, low, high), tensor(expected), **_tol)
# The module form is expected to give the same result as the function.
tst = SigmoidRange(-1, 2)
assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), **_tol)
If the input is `bs x nf x h x h`, the output will be `bs x 2*nf x 1 x 1` if no size is passed or `bs x 2*nf x size x size`.
# Default size: the output has twice the input channels and a 1x1 spatial size.
tst = AdaptiveConcatPool2d()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,10,1,1])
# Max over both spatial dimensions, reduced one dimension at a time.
max1 = torch.max(x, dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
# The first 5 output channels are expected to be the max pool, the last 5 the average pool.
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
# With size 2 the spatial output is 2x2.
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])
# PoolFlatten is expected to average over the spatial dimensions and drop them.
tst = PoolFlatten()
test_eq(tst(x).shape, [10,5])
test_eq(tst(x), x.mean(dim=[2,3]))
`kwargs` are passed to `nn.BatchNorm` and can be `eps`, `momentum`, `affine` and `track_running_stats`.
# Both factories are checked the same way: the default is the 2d layer with
# weights of 1, the "Zero" norm type starts the weights at 0, and `ndim`
# selects the 1d/3d variants.
for make_norm, zero_type, layer_classes in (
    (BatchNorm, NormType.BatchZero, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)),
    (InstanceNorm, NormType.InstanceZero, (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)),
):
    tst = make_norm(15)
    assert isinstance(tst, layer_classes[1])
    test_eq(tst.weight, torch.ones(15))
    tst = make_norm(15, norm_type=zero_type)
    test_eq(tst.weight, torch.zeros(15))
    for ndim in (1, 3):
        tst = make_norm(15, ndim=ndim)
        assert isinstance(tst, layer_classes[ndim-1])
If `affine` is false the weight should be `None`.
# With affine=False neither norm layer is expected to carry a weight.
test_eq(BatchNorm(15, affine=False).weight, None)
test_eq(InstanceNorm(15, affine=False).weight, None)
# Input is (32, 64, 15): the reference statistics below are taken over the
# first two dimensions, leaving one value per feature (15).
tst = BatchNorm1dFlat(15)
x = torch.randn(32, 64, 15)
y = tst(x)
mean = x.mean(dim=[0,1])
# After one forward pass the running stats are compared with an update of
# weight 0.1 starting from 0 (mean) and 1 (variance).
test_close(tst.running_mean, 0*0.9 + mean*0.1)
var = (x-mean).pow(2).mean(dim=[0,1])
test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)
# The output is expected to be the normalized input (eps 1e-5 inside the
# square root), scaled by the layer's weight and shifted by its bias.
test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)
The `BatchNorm` layer is skipped if `bn=False`, as is the dropout if `p=0.`. Optionally, you can add an activation after the linear layer with `act`.
# Each case pairs the extra constructor arguments with the child layer types
# expected, in order, inside the resulting block.
for extra_args, expected_types in (
    ({}, (nn.BatchNorm1d, nn.Linear)),
    (dict(p=0.1), (nn.BatchNorm1d, nn.Dropout, nn.Linear)),
    (dict(act=nn.ReLU(), lin_first=True), (nn.Linear, nn.ReLU, nn.BatchNorm1d)),
    (dict(bn=False), (nn.Linear,)),
):
    tst = LinBnDrop(10, 20, **extra_args)
    mods = list(tst.children())
    test_eq(len(mods), len(expected_types))
    for child, child_type in zip(mods, expected_types):
        assert isinstance(child, child_type)
The convolution uses `ks` (kernel size), `stride`, `padding` and `bias`. `padding` will default to the appropriate value (`(ks-1)//2` if it's not a transposed conv) and `bias` will default to `True` if the `norm_type` is `Spectral` or `Weight`, `False` if it's `Batch` or `BatchZero`. Note that if you don't want any normalization, you should pass `norm_type=None`.
This defines a conv layer with `ndim` (1, 2 or 3) that will be a ConvTranspose if `transpose=True`. `act_cls` is the class of the activation function to use (instantiated inside). Pass `act_cls=None` if you don't want an activation function. If you quickly want to change your default activation, you can change the value of `defaults.activation`.
`init` is used to initialize the weights (the biases are initialized to 0) and `xtra` is an optional layer to add at the end.
# Default ConvLayer: three children, the second one carrying a weight
# initialized to ones, and a conv padding of (1,1).
tst = ConvLayer(16, 32)
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq(mods[1].weight, torch.ones(32))
test_eq(mods[0].padding, (1,1))
x = torch.randn(64, 16, 8, 8)#.cuda()
# With the default padding and stride the spatial size is preserved
test_eq(tst(x).shape, [64,32,8,8])
# With stride=2 the spatial size is halved
tst = ConvLayer(16, 32, stride=2)
test_eq(tst(x).shape, [64,32,4,4])
# An explicit padding overrides the default
tst = ConvLayer(16, 32, padding=0)
test_eq(tst(x).shape, [64,32,6,6])
# `mods` still refers to the default ConvLayer(16, 32) built at the top of
# this cell: its conv has no bias
assert mods[0].bias is None
#But can be overridden with `bias=True`
tst = ConvLayer(16, 32, bias=True)
assert first(tst.children()).bias is not None
#For no norm, or spectral/weight, bias is True by default
for t in [None, NormType.Spectral, NormType.Weight]:
    tst = ConvLayer(16, 32, norm_type=t)
    assert first(tst.children()).bias is not None
# `ndim` and `transpose` select the underlying conv class
tst = ConvLayer(16, 32, ndim=3)
assert isinstance(list(tst.children())[0], nn.Conv3d)
tst = ConvLayer(16, 32, ndim=1, transpose=True)
assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)
# act_cls=None leaves only two children (no activation layer)
tst = ConvLayer(16, 32, ndim=3, act_cls=None)
mods = list(tst.children())
test_eq(len(mods), 2)
# A custom activation class ends up as the third child
tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[2], nn.LeakyReLU)
# def linear(in_features, out_features, bias=True, act_cls=None, init='auto'):
# "Linear layer followed by optional activation, with optional auto-init"
# res = nn.Linear(in_features, out_features, bias=bias)
# if act_cls: act_cls = act_cls()
# init_linear(res, act_cls, init=init)
# if act_cls: res = nn.Sequential(res, act_cls)
# return res
# @delegates(ConvLayer)
# def conv1d(ni, nf, ks, stride=1, ndim=1, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# @delegates(ConvLayer)
# def conv2d(ni, nf, ks, stride=1, ndim=2, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# @delegates(ConvLayer)
# def conv3d(ni, nf, ks, stride=1, ndim=3, norm_type=None, **kwargs):
# "Convolutional layer followed by optional activation, with optional auto-init"
# return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
Truncated normal initialization bounds the distribution to avoid large values. For a given standard deviation `std`, the bounds are roughly `-2*std`, `2*std`.
# Embedding weights initialized with std=0.02 are expected to stay strictly
# inside (-2*std, 2*std) ...
std = 0.02
tst = Embedding(10, 30, std)
assert tst.weight.min() > -2*std
assert tst.weight.max() < 2*std
# ... with a mean close to 0 and a standard deviation close to `std`
# (loose tolerances, since the values are random).
test_close(tst.weight.mean(), 0, 1e-2)
test_close(tst.weight.std(), std, 0.1)
Self-attention layer as introduced in Self-Attention Generative Adversarial Networks.
Initially, no change is done to the input. This is controlled by a trainable parameter named `gamma` as we return `x + gamma * out`.
# A freshly created layer is expected to return its input unchanged.
tst = SelfAttention(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x),x)
Then during training `gamma` will probably change since it's a trainable parameter. Let's see what's happening when it gets a nonzero value.
# Set gamma to 1 in place so the attention branch contributes to the output;
# the output shape must still match the input shape.
tst.gamma.data.fill_(1.)
y = tst(x)
test_eq(y.shape, [32,16,8,8])
The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). As in the paper, we denote by `f`, `g` and `h` the results of those multiplications.
# Weights of the three kernel-size-1 convs: query and key project the 16
# channels down to 2, value keeps 16.
q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data
test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])
# Apply each projection by hand: flatten the 8x8 map into 64 positions, move
# channels last, then multiply by the transposed weight matrix.
f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])
test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])
The key part of the attention layer is to compute attention weights for each location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We make the product of `f` and the transpose of `g` (to get something of size bs by 64 by 64) then apply a softmax on the first dimension (to get the positive numbers that sum up to 1). The result can then be multiplied with `h` transposed to get an output of size bs by channels by 64, which can then be viewed as an output the same size as the original input.
The final result is then `x + gamma * out` as we saw before.
# Attention weights: product of f with g transposed, normalized with a softmax over dim 1.
beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)
test_eq(beta.shape, [32, 64, 64])
# Weighted combination of the value projections, back in (bs, channels, positions) layout.
out = torch.bmm(h.transpose(1,2), beta)
test_eq(out.shape, [32, 16, 64])
# gamma was set to 1 earlier, so the layer output is expected to be x + out.
test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)
Self-attention layer used in the Big GAN paper.
It uses the same attention as in `SelfAttention` but adds a max pooling of stride 2 before computing the matrices `g` and `h`: the attention is ported on one of the 2x2 max-pooled windows, not the whole feature map. There is also a final matrix product added at the end to the output, before returning `gamma * out + x`.
PixelShuffle was introduced in this article to avoid checkerboard artifacts when upsampling images. If we want an output with `ch_out` filters, we use a convolution with `ch_out * (r**2)` filters, where `r` is the upsampling factor. Then we reorganize those filters like in the picture below:
ICNR init was introduced in this article. It suggests initializing the convolution that will be used in PixelShuffle so that each of the `r**2` channels gets the same weight (so that in the picture above, the 9 colors in a 3 by 3 window are initially the same).
ch_out x ch_in x ks x ks
# Weight tensor with 16*4 output channels, 32 input channels and a 1x1 kernel.
tst = torch.randn(16*4, 32, 1, 1)
tst = icnr_init(tst)
# After ICNR init, every consecutive group of 4 output channels is expected to
# hold identical weights.
for i in range(0,16*4,4):
    test_eq(tst[i],tst[i+1])
    test_eq(tst[i],tst[i+2])
    test_eq(tst[i],tst[i+3])
The convolutional layer is initialized with `icnr_init` and passed `act_cls` and `norm_type` (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments).
The `blur` option comes from Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts where the authors add a little bit of blur to completely get rid of checkerboard artifacts.
psfl = PixelShuffle_ICNR(16, norm_type=None) #Deactivate weight norm as it changes the weight
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
# The 8x8 input is upsampled to 16x16 with the same number of channels.
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i ,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])
This is useful to write layers that need to remember their input (like a resnet block) in a sequential way.
# Two conv layers followed by a MergeLayer.
res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))
res_block.append(MergeLayer()) # just to test append - normally it would be in init params
x = torch.randn(32, 16, 8, 8)
y = res_block(x)
test_eq(y.shape, [32, 16, 8, 8])
# The output is expected to be the input plus the result of the two conv layers.
test_eq(y, x + res_block[1](res_block[0](x)))
# Same call with a TensorBase input: the result's `orig` attribute is expected to be None.
x = TensorBase(torch.randn(32, 16, 8, 8))
y = res_block(x)
test_is(y.orig, None)
Equivalent to keras.layers.Concatenate, it will concat the outputs of a ModuleList over a given dimension (default the filter dimension)
# Three conv layers with 4 output channels each: concatenating their outputs
# gives 12 channels.
layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)]
x = torch.rand(1,2,8,8)
cat = Cat(layers)
test_eq(cat(x).shape, [1,12,8,8])
# Reference result: apply each layer separately and concatenate on dim 1.
test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))
The model is a succession of convolutional layers from `(filters[0],filters[1])` to `(filters[n-2],filters[n-1])` (if `n` is the length of the `filters` list) followed by a `PoolFlatten`. `kernel_szs` and `strides` default to a list of 3s and a list of 2s. If `bn=True` the convolutional layers are successions of conv-relu-batchnorm, otherwise conv-relu.
# Three filter sizes give three children; the first two are indexable blocks
# whose first element is a conv going 8->16 and 16->32.
tst = SimpleCNN([8,16,32])
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq([[m[0].in_channels, m[0].out_channels] for m in mods[:2]], [[8,16], [16,32]])
Test kernel sizes
# Each entry of kernel_szs is expected to set the kernel size of the matching conv.
tst = SimpleCNN([8,16,32], kernel_szs=[1,3])
mods = list(tst.children())
test_eq([m[0].kernel_size for m in mods[:2]], [(1,1), (3,3)])
Test strides
# Each entry of strides is expected to set the stride of the matching conv.
tst = SimpleCNN([8,16,32], strides=[1,2])
mods = list(tst.children())
test_eq([m[0].stride for m in mods[:2]], [(1,1),(2,2)])
This is a resnet block (normal or bottleneck depending on `expansion`, 1 for the normal block and 4 for the traditional bottleneck) that implements the tweaks from Bag of Tricks for Image Classification with Convolutional Neural Networks. In particular, the last batchnorm layer (if that is the selected `norm_type`) is initialized with a weight (or gamma) of zero to facilitate the flow from the beginning to the end of the network. It also implements optional Squeeze and Excitation and grouped convs for ResNeXT and similar models (use `dw=True` for depthwise convs).
The `kwargs` are passed to `ConvLayer` along with `norm_type`.
It's easy to get the list of all parameters of a given model. For when you want all submodules (like linear/conv layers) without forgetting lone parameters, the following class wraps those in fake modules.
class TstModule(Module):
    "Toy module holding one lone parameter `a` and one `nn.Linear` submodule `lin`"
    # NOTE(review): no explicit super().__init__() here; presumably the
    # project's `Module` base class takes care of it - confirm.
    def __init__(self): self.a,self.lin = nn.Parameter(torch.randn(1)),nn.Linear(5,10)

tst = TstModule()
children = children_and_parameters(tst)
# Two entries are expected: the linear submodule first, then the lone
# parameter wrapped in a ParameterModule that exposes it as `val`.
test_eq(len(children), 2)
test_eq(children[0], tst.lin)
assert isinstance(children[1], ParameterModule)
test_eq(children[1].val, tst.a)
# An empty module has no children, whereas TstModule (defined above, holding a
# linear layer) does.
class A(Module): pass
assert not has_children(A())
assert has_children(TstModule())
# Two TstModule instances flatten to four entries; the second entry of each
# pair is expected to be the wrapped lone parameter.
tst = nn.Sequential(TstModule(), TstModule())
children = flatten_model(tst)
test_eq(len(children), 4)
assert isinstance(children[1], ParameterModule)
assert isinstance(children[3], ParameterModule)
x,y = torch.randn(5),torch.randn(5)
# Module-style loss: inside the context the loss is expected to be returned
# per element (shape [5]) ...
loss_fn = nn.MSELoss()
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
# ... and once the context has exited the module's reduction is 'mean'.
test_eq(loss_fn.reduction, 'mean')

# Function-style loss: same per-element result, and the original function
# object is left untouched.
loss_fn = F.mse_loss
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn, F.mse_loss)
# Each case pairs the layers of a small sequential model with the number of
# input channels that in_channels is expected to report for it.
for model_layers, expected_channels in (
    ((nn.Conv2d(5,4,3), nn.Conv2d(4,3,3)), 5),
    ((nn.AvgPool2d(4), nn.Conv2d(4,3,3)), 4),
    ((BatchNorm(4), nn.Conv2d(4,3,3)), 4),
    ((InstanceNorm(4), nn.Conv2d(4,3,3)), 4),
    ((InstanceNorm(4, affine=False), nn.Conv2d(4,3,3)), 4),
):
    test_eq(in_channels(nn.Sequential(*model_layers)), expected_channels)
# A model made only of a pooling layer is expected to make in_channels fail.
test_fail(lambda : in_channels(nn.Sequential(nn.AvgPool2d(4))))