AWD-LSTM

AWD-LSTM from Smerity et al.

Basic NLP modules

On top of the PyTorch and fastai layers, the language models use some custom layers specific to NLP.



dropout_mask


def dropout_mask(
    x:Tensor, # Source tensor, output will be of the same type as `x`
    sz:list, # Size of the dropout mask as `int`s
    p:float, # Dropout probability
)->Tensor: # Multiplicative dropout mask

Return a dropout mask of the same type as x, size sz, with probability p to cancel an element.

t = dropout_mask(torch.randn(3,4), [4,3], 0.25)
test_eq(t.shape, [4,3])
assert ((t == 4/3) + (t==0)).all()
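A minimal sketch of how such a mask can be built in plain PyTorch, assuming Bernoulli sampling with keep probability 1-p and inverted-dropout scaling, so surviving entries equal 1/(1-p) (the 4/3 checked above for p=0.25). The name `make_dropout_mask` is hypothetical and this is not presented as fastai's exact code.

```python
import torch

def make_dropout_mask(x: torch.Tensor, sz: list, p: float) -> torch.Tensor:
    "Hypothetical helper: Bernoulli keep-mask of shape `sz`, scaled by 1/(1-p)."
    # new_empty keeps x's dtype and device; bernoulli_ writes 1 (keep) with probability 1-p
    return x.new_empty(*sz).bernoulli_(1 - p).div_(1 - p)

x = torch.randn(3, 4)
m = make_dropout_mask(x, [4, 3], 0.25)  # the mask size need not match x's shape
print(m.shape)
```

Multiplying activations by this mask zeroes roughly a fraction p of them and rescales the rest, so the expected value of each activation is unchanged.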


RNNDropout


def RNNDropout(
    p:float=0.5
):

Dropout with probability p that is consistent on the seq_len dimension.

dp = RNNDropout(0.3)
tst_inp = torch.randn(4,3,7)
tst_out = dp(tst_inp)
for i in range(4):
    for j in range(7):
        if tst_out[i,0,j] == 0: assert (tst_out[i,:,j] == 0).all()
        else: test_close(tst_out[i,:,j], tst_inp[i,:,j]/(1-0.3))

It also supports dropout over a sequence of images where the time dimension is axis 1 (right after the batch axis), here 10 images of 3 channels, each 32 by 32.

_ = dp(torch.rand(4,10,3,32,32))
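The consistency over the sequence can be sketched in plain PyTorch by giving the mask size 1 along the time axis, so broadcasting reuses the same mask at every timestep. This assumes, as in both examples above, that the batch is axis 0 and time is axis 1; `SeqConsistentDropout` is a hypothetical name, not fastai's `RNNDropout`.

```python
import torch
import torch.nn as nn

class SeqConsistentDropout(nn.Module):
    "Hypothetical sketch: one mask per (sample, feature), shared across axis 1 (time)."
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x  # identity at inference time
        # Mask shape (bs, 1, *rest): size 1 on the time axis so it broadcasts over timesteps
        mask_shape = (x.size(0), 1, *x.shape[2:])
        mask = x.new_empty(*mask_shape).bernoulli_(1 - self.p).div_(1 - self.p)
        return x * mask

dp = SeqConsistentDropout(0.3)
seq = torch.randn(4, 3, 7)                 # (bs, seq_len, features)
out = dp(seq)
imgs = dp(torch.rand(4, 10, 3, 32, 32))    # (bs, time, channels, height, width)
```

Because only the trailing dimensions are copied into the mask shape, the same module handles plain feature vectors and image sequences.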


WeightDropout


def WeightDropout(
    module:nn.Module, # Wrapped module
    weight_p:float, # Weight dropout probability
    layer_names:str | MutableSequence='weight_hh_l0', # Name(s) of the parameters to apply dropout to
):

A module that wraps another layer in which some weights will be replaced by 0 during training.

module = nn.LSTM(5,7)
dp_module = WeightDropout(module, 0.4)
wgts = dp_module.module.weight_hh_l0
tst_inp = torch.randn(10,20,5)
h = torch.zeros(1,20,7), torch.zeros(1,20,7)
dp_module.reset()
x,h = dp_module(tst_inp,h)
loss = x.sum()
loss.backward()
new_wgts = getattr(dp_module.module, 'weight_hh_l0')
test_eq(wgts, getattr(dp_module, 'weight_hh_l0_raw'))
assert 0.2 <= (new_wgts==0).sum().float()/new_wgts.numel() <= 0.6
assert dp_module.weight_hh_l0_raw.requires_grad
assert dp_module.weight_hh_l0_raw.grad is not None
assert ((dp_module.weight_hh_l0_raw.grad == 0.) & (new_wgts == 0.)).any()
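`WeightDropout` keeps the trainable weight under a `_raw` name (`weight_hh_l0_raw` in the test above) and uses a dropped copy in the forward pass. The sketch below illustrates the same idea (weight dropout, also known as DropConnect) on an `nn.Linear`, computing the layer functionally instead of patching a wrapped module. `LinearWeightDrop` and its attributes are hypothetical; this is a simplified stand-in, not fastai's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearWeightDrop(nn.Module):
    """Hypothetical sketch of weight dropout on a Linear layer.

    The raw weight is the trainable parameter; each training-mode forward pass
    uses a freshly dropped copy of it, so dropped entries receive zero gradient.
    """
    def __init__(self, in_features: int, out_features: int, weight_p: float):
        super().__init__()
        self.weight_p = weight_p
        base = nn.Linear(in_features, out_features)
        self.weight_raw = nn.Parameter(base.weight.detach().clone())
        self.bias = nn.Parameter(base.bias.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.dropout zeroes weights with probability weight_p and rescales the rest
        # by 1/(1-weight_p); it is the identity when self.training is False
        w = F.dropout(self.weight_raw, p=self.weight_p, training=self.training)
        self._last_weight = w.detach()  # stored only so the example can inspect it
        return F.linear(x, w, self.bias)

torch.manual_seed(0)
layer = LinearWeightDrop(5, 7, weight_p=0.4)
x = torch.randn(20, 5)
layer(x).sum().backward()
```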


EmbeddingDropout


def EmbeddingDropout(
    emb:nn.Embedding, # Wrapped embedding layer
    embed_p:float, # Embedding layer dropout probability
):

Apply dropout with probability embed_p to an embedding layer emb.

enc = nn.Embedding(10, 7, padding_idx=1)
enc_dp = EmbeddingDropout(enc, 0.5)
tst_inp = torch.randint(0,10,(8,))
tst_out = enc_dp(tst_inp)
for i in range(8):
    assert (tst_out[i]==0).all() or torch.allclose(tst_out[i], 2*enc.weight[tst_inp[i]])
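The check above (each output row is either all zeros or twice the original embedding, since 1/(1-0.5) = 2) is what you get by dropping whole rows of the embedding matrix, so every occurrence of a dropped token maps to zeros. A minimal sketch under that assumption, using `torch.nn.functional.embedding`; the helper name `embedding_with_row_dropout` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def embedding_with_row_dropout(emb: nn.Embedding, words: torch.Tensor,
                               embed_p: float, training: bool = True) -> torch.Tensor:
    "Hypothetical sketch: drop entire vocabulary rows of the embedding matrix."
    w = emb.weight
    if training and embed_p > 0:
        # One Bernoulli draw per vocabulary row, broadcast over the emb_sz columns
        mask = w.new_empty(w.size(0), 1).bernoulli_(1 - embed_p) / (1 - embed_p)
        w = w * mask
    return F.embedding(words, w, padding_idx=emb.padding_idx)

torch.manual_seed(0)
enc = nn.Embedding(10, 7, padding_idx=1)
words = torch.randint(0, 10, (8,))
out = embedding_with_row_dropout(enc, words, 0.5)
```

Dropping rows rather than individual entries means a dropped word disappears from the whole sequence, instead of losing only some of its embedding features.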


AWD_LSTM


def AWD_LSTM(
    vocab_sz:int, # Size of the vocabulary
    emb_sz:int, # Size of embedding vector
    n_hid:int, # Number of features in hidden state
    n_layers:int, # Number of LSTM layers
    pad_token:int=1, # Padding token id
    hidden_p:float=0.2, # Dropout probability for hidden state between layers
    input_p:float=0.6, # Dropout probability for LSTM stack input
    embed_p:float=0.1, # Embedding layer dropout probability
    weight_p:float=0.5, # Hidden-to-hidden weight dropout probability for LSTM layers
    bidir:bool=False, # If set to `True` uses bidirectional LSTM layers
):

AWD-LSTM inspired by https://arxiv.org/abs/1708.02182

This is the core of an AWD-LSTM model: an embedding layer of size vocab_sz by emb_sz, followed by a stack of n_layers LSTMs (bidirectional if bidir is True). The first LSTM goes from emb_sz to n_hid, the last one from n_hid to emb_sz, and all the inner ones from n_hid to n_hid. pad_token is passed to the PyTorch embedding layer as its padding index. The dropouts are applied as follows:

  • the embeddings are wrapped in EmbeddingDropout of probability embed_p;
  • the result of this embedding layer goes through an RNNDropout of probability input_p;
  • each LSTM has WeightDropout applied with probability weight_p;
  • between two consecutive LSTMs, an RNNDropout is applied with probability hidden_p.
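The layer sizes and the order of the dropouts listed above can be sketched in plain PyTorch. This is a hypothetical, simplified stack (`TinyAWDStack` and `seq_dropout` are illustrative names, not fastai's API): it omits weight dropout and does not carry hidden states between batches, so it only illustrates the data flow.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def seq_dropout(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    "Same mask for every timestep (axis 1), in the spirit of RNNDropout."
    if not training or p == 0:
        return x
    mask = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - p) / (1 - p)
    return x * mask

class TinyAWDStack(nn.Module):
    "Hypothetical simplified stack; weight dropout on the LSTMs is omitted."
    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token=1,
                 hidden_p=0.2, input_p=0.6, embed_p=0.1):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        # first: emb_sz -> n_hid, inner: n_hid -> n_hid, last: n_hid -> emb_sz
        sizes = [emb_sz] + [n_hid] * (n_layers - 1) + [emb_sz]
        self.rnns = nn.ModuleList([nn.LSTM(sizes[l], sizes[l + 1], batch_first=True)
                                   for l in range(n_layers)])
        self.hidden_p, self.input_p, self.embed_p = hidden_p, input_p, embed_p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.encoder.weight
        if self.training and self.embed_p > 0:  # 1. embedding dropout (whole rows)
            keep = w.new_empty(w.size(0), 1).bernoulli_(1 - self.embed_p)
            w = w * keep / (1 - self.embed_p)
        out = F.embedding(x, w, padding_idx=self.encoder.padding_idx)
        out = seq_dropout(out, self.input_p, self.training)  # 2. input dropout
        for l, rnn in enumerate(self.rnns):  # 3. weight dropout would wrap each rnn here
            out, _ = rnn(out)
            if l != len(self.rnns) - 1:  # 4. hidden dropout between consecutive LSTMs
                out = seq_dropout(out, self.hidden_p, self.training)
        return out

model = TinyAWDStack(100, 20, 10, 2)
tokens = torch.randint(0, 100, (10, 5))  # (bs, seq_len)
y = model(tokens)
```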

The module returns the output of the last LSTM, a tensor of shape (bs, seq_len, emb_sz). No hidden_p dropout is applied to this last output, so the raw output is also the one that should be fed to a decoder (in the case of a language model).

tst = AWD_LSTM(100, 20, 10, 2, hidden_p=0.2, embed_p=0.02, input_p=0.1, weight_p=0.2)
x = torch.randint(0, 100, (10,5))
r = tst(x)
test_eq(tst.bs, 10)
test_eq(len(tst.hidden), 2)
test_eq([h_.shape for h_ in tst.hidden[0]], [[1,10,10], [1,10,10]])
test_eq([h_.shape for h_ in tst.hidden[1]], [[1,10,20], [1,10,20]])

test_eq(r.shape, [10,5,20])
test_eq(r[:,-1], tst.hidden[-1][0][0]) #hidden state is the last timestep in raw outputs

tst.eval()
tst.reset()
tst(x);
tst(x);
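The check `test_eq(r[:,-1], tst.hidden[-1][0][0])` above relies on a property of unidirectional, batch-first `nn.LSTM` layers: the final hidden state `h_n` equals the output at the last timestep. A standalone check in plain PyTorch, with sizes assumed to mirror the example (batch 10, 5 timesteps, 20 output features):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x = torch.randn(10, 5, 10)      # (bs, seq_len, n_in)
out, (h_n, c_n) = lstm(x)       # out: (bs, seq_len, 20); h_n and c_n: (1, bs, 20)
print(out.shape, h_n.shape)
```

For a bidirectional layer this no longer holds as written, since the backward direction's final state corresponds to the first timestep.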


awd_lstm_lm_split


def awd_lstm_lm_split(
    model
):

Split an RNN language model into groups for differential learning rates.
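In plain PyTorch, groups like these become differential learning rates by giving each one its own entry in the optimizer's `param_groups`. The toy model and its grouping below (each LSTM on its own, embedding and decoder together, lowest rate first) are assumptions for illustration, not necessarily what `awd_lstm_lm_split` returns.

```python
import torch
import torch.nn as nn

# Hypothetical toy language model pieces: embedding, two LSTMs, linear decoder
emb = nn.Embedding(100, 20)
rnns = nn.ModuleList([nn.LSTM(20, 10, batch_first=True), nn.LSTM(10, 20, batch_first=True)])
dec = nn.Linear(20, 100)

# One parameter group per LSTM, plus one group for the embedding and decoder
groups = [list(r.parameters()) for r in rnns] + [list(emb.parameters()) + list(dec.parameters())]
lrs = [1e-4, 1e-3, 1e-2]
opt = torch.optim.SGD([{"params": g, "lr": lr} for g, lr in zip(groups, lrs)], lr=1e-2)
```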



awd_lstm_clas_split


def awd_lstm_clas_split(
    model
):

Split an RNN classifier model into groups for differential learning rates.
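The groups produced by a split are usually given learning rates that increase from the embedding side to the classifier head. A sketch of geometric spacing between a lowest and a highest rate, the same idea fastai applies when a `slice(lo, hi)` of learning rates is passed to a fit method; `spread_lrs` is a hypothetical name, and the five-group layout in the comment is an assumption.

```python
import math

def spread_lrs(lo: float, hi: float, n: int) -> list[float]:
    "Hypothetical helper: n learning rates spaced geometrically from lo to hi."
    if n == 1:
        return [hi]
    step = (math.log(hi) - math.log(lo)) / (n - 1)  # equal steps in log space
    return [math.exp(math.log(lo) + i * step) for i in range(n)]

# e.g. assumed classifier groups: embedding, LSTM 1, LSTM 2, LSTM 3, head
lrs = spread_lrs(1e-5, 1e-3, 5)
```

Geometric spacing keeps the ratio between neighbouring groups constant, so the earliest layers, which hold the most general features, change the least.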