Demystifying Neural Language Models

Exploring the temporal nature of language

Language is an inherently temporal phenomenon. Spoken language is a sequence of acoustic events over time, and we comprehend and produce both spoken and written language as a continuous stream. Human thought persists from moment to moment, building on previous understanding. Traditional feedforward neural networks lack this ability, which makes it difficult for them to reason about sequences of events. Recurrent neural networks (RNNs) are a family of neural networks for processing sequential data. Their recurrent connections deal directly with the sequential nature of language, so they can handle it without relying on arbitrary fixed-size windows. These connections offer a new way to represent prior context, allowing the model’s decisions to depend, in principle, on information from hundreds of words in the past.

Recurrent Neural Networks

A recurrent neural network (RNN) is any network that contains a cycle within its connections, meaning that the value of some unit depends, directly or indirectly, on its own earlier outputs.

Figure 1: Recurrent Neural Network Architecture.

Figure 1 shows the original RNN architecture. An RNN maintains a hidden state, and that state is recurrent: each hidden state is computed from the one before it. The current hidden state $h_t$ is a function of the previous hidden state $h_{t-1}$ and the current input $x_t$.

Figure 2: Closer view of the repeating module in a standard RNN.

Looking more closely at the repeating module (Figure 2), the function $f_w$ multiplies the previous hidden state $h_{t-1}$ by the weight matrix $W_{hh}$ and the input $x_t$ by the weight matrix $W_{xh}$, adds the two products and a bias $b_h$, and passes the sum through a $\tanh$ function. The vanilla RNN uses $\tanh$ to introduce nonlinearity. The weights $W_{hh}$ and $W_{xh}$ (and the bias $b_h$) are shared across all time steps.

\[\begin{equation} h_t = \text{tanh}(W_{hh}h_{t-1} + W_{xh}x_t + b_h) \end{equation}\]

To make a prediction, we multiply $h_t$ by another weight matrix $W_{hy}$, add a bias $b_y$, and apply a softmax to obtain a probability distribution over the outputs:

\[\begin{equation} y_t = \text{softmax}(W_{hy}h_t + b_y) \end{equation}\]
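The two equations above can be sketched in NumPy as a loop over time steps. This is a minimal illustration rather than a library API: the helper names `softmax` and `rnn_forward`, the toy sizes, and the random initialization are assumptions made for the example.

```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(z - np.max(z))
    return e / e.sum()

def rnn_forward(xs, W_xh, W_hh, W_hy, b_h, b_y, h0):
    """Apply the hidden-state update and the softmax output at every step."""
    h = h0
    hs, ys = [], []
    for x in xs:
        # The same W_hh, W_xh and b_h are reused at every time step
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        ys.append(softmax(W_hy @ h + b_y))
        hs.append(h)
    return hs, ys

# Hypothetical toy sizes: input dim 3, hidden dim 4, 5 output classes, 6 steps
rng = np.random.default_rng(0)
d, n_h, k, T = 3, 4, 5, 6
W_xh = rng.normal(scale=0.1, size=(n_h, d))
W_hh = rng.normal(scale=0.1, size=(n_h, n_h))
W_hy = rng.normal(scale=0.1, size=(k, n_h))
b_h, b_y = np.zeros(n_h), np.zeros(k)
xs = [rng.normal(size=d) for _ in range(T)]
hs, ys = rnn_forward(xs, W_xh, W_hh, W_hy, b_h, b_y, h0=np.zeros(n_h))
```

Each `ys[t]` is a probability distribution over the 5 hypothetical classes, and each `hs[t]` stays inside $(-1, 1)$ because of the $\tanh$.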

Training RNN

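As with feedforward networks, training an RNN involves a forward pass to compute the loss, followed by backpropagation to compute gradients and a gradient step to update the weights. Because the same weights are applied at every time step, the network is unrolled over the sequence and the gradient contributions from all time steps are summed; this procedure is called backpropagation through time (BPTT).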
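A minimal sketch of backpropagation through time for the vanilla RNN, written by hand in NumPy. To keep it short, it assumes a hypothetical toy loss on the final hidden state only, $\frac{1}{2}\|h_T\|^2$, and omits the output layer; the function names are illustrative. The key detail is the `+=`: because the weights are shared, the gradient contributions from every time step are summed.

```python
import numpy as np

def forward(xs, W_xh, W_hh, b_h, h0):
    # hs[0] = h0 and hs[t + 1] = tanh(W_hh hs[t] + W_xh xs[t] + b_h)
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(W_hh @ hs[-1] + W_xh @ x + b_h))
    return hs

def toy_loss(hs):
    # Hypothetical loss on the final hidden state, chosen for simplicity
    return 0.5 * np.sum(hs[-1] ** 2)

def bptt(xs, W_xh, W_hh, b_h, h0):
    """Gradients of toy_loss w.r.t. the shared weights via BPTT."""
    hs = forward(xs, W_xh, W_hh, b_h, h0)
    dW_xh, dW_hh, db_h = np.zeros_like(W_xh), np.zeros_like(W_hh), np.zeros_like(b_h)
    dh = hs[-1].copy()  # dL/dh_T for the toy loss
    for t in reversed(range(len(xs))):
        da = dh * (1.0 - hs[t + 1] ** 2)  # back through the tanh
        dW_hh += np.outer(da, hs[t])      # shared weights: sum over time steps
        dW_xh += np.outer(da, xs[t])
        db_h += da
        dh = W_hh.T @ da                  # gradient flowing to the previous state
    return dW_xh, dW_hh, db_h

rng = np.random.default_rng(1)
d, n_h, T = 3, 4, 5
W_xh = rng.normal(scale=0.5, size=(n_h, d))
W_hh = rng.normal(scale=0.5, size=(n_h, n_h))
b_h = rng.normal(scale=0.1, size=n_h)
h0 = np.zeros(n_h)
xs = [rng.normal(size=d) for _ in range(T)]
dW_xh, dW_hh, db_h = bptt(xs, W_xh, W_hh, b_h, h0)

# One gradient step with a hypothetical learning rate
lr = 0.1
W_hh_new = W_hh - lr * dW_hh
```

A finite-difference check (perturbing each entry of `W_hh` slightly and re-running `forward`) is a simple way to confirm that hand-written gradients like these are correct.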

Figure 3: Backpropagation in a standard RNN.

For long sequences, computing the gradient of the final hidden state $h_n$ with respect to the initial hidden state $h_0$ requires multiplying by the recurrent weight matrix $W_{hh}$ (together with the derivatives of $\tanh$) once per time step. If the largest singular value of $W_{hh}$ is smaller than one, the gradient vanishes exponentially with sequence length; conversely, if it exceeds one, the gradient can grow exponentially over long sequences. Gradient clipping offers a remedy for exploding gradients. Because weights are typically initialized close to zero, the largest singular value is often less than one, which makes vanishing gradients the more common problem in vanilla RNNs. Long Short-Term Memory (LSTM) networks offer an effective way to mitigate it.
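To see the effect of the largest singular value, the sketch below assumes a linear RNN (the $\tanh$ is dropped), in which case $\partial h_n / \partial h_0 = W_{hh}^n$ exactly. Building $W_{hh}$ as a scaled random orthogonal matrix makes every singular value equal to the chosen scale, so the Jacobian norm is exactly that scale raised to the power $n$. The sizes, seed, and helper name are arbitrary choices for illustration.

```python
import numpy as np

# Assumption: a linear RNN, so dh_n/dh_0 = W_hh ** n (matrix power)
rng = np.random.default_rng(0)
n_h, n_steps = 8, 50
# A random orthogonal matrix Q has all singular values equal to 1
Q, _ = np.linalg.qr(rng.normal(size=(n_h, n_h)))

def jacobian_norm(scale):
    """Spectral norm of (scale * Q)^n_steps, i.e. of dh_n/dh_0."""
    W_hh = scale * Q
    J = np.linalg.matrix_power(W_hh, n_steps)
    return np.linalg.norm(J, ord=2)

for scale in (0.9, 1.0, 1.1):
    print(f"largest singular value {scale}: ||dh_n/dh_0|| = {jacobian_norm(scale):.3e}")
```

With 50 steps the three norms come out at about $0.9^{50} \approx 0.005$, $1$, and $1.1^{50} \approx 117$. With $\tanh$ included, each step also multiplies by a diagonal matrix of derivatives no larger than 1, which can only push the product further toward zero.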

Gradient clipping

In general, when solving an optimization problem, we update the model parameters, represented as a vector $\mathbf{x}$, in the direction of the negative gradient $\mathbf{g}$ computed on a minibatch. With a learning rate $\eta > 0$, one iteration updates $\mathbf{x}$ to $\mathbf{x} - \eta\mathbf{g}$. If the objective function $f$ is well behaved, for example Lipschitz continuous with constant $L$, then for any $\mathbf{a}$ and $\mathbf{b}$:

\[|f(\mathbf{a})-f(\mathbf{b})| \leq L\|\mathbf{a}-\mathbf{b}\|\]

It follows that if we update the parameter vector by $\eta\mathbf{g}$, then:

\[|f(\mathbf{x})-f(\mathbf{x}-\eta \mathbf{g})| \leq L \eta\|\mathbf{g}\|\]

which means that the objective will not change by more than $L\eta\|\mathbf{g}\|$. This is both a curse and a blessing. On the curse side, it limits the speed of progress; on the blessing side, it limits how far things can go wrong if we move in the wrong direction.

Sometimes the gradients are very large and the optimization algorithm may fail to converge. We could address this by reducing the learning rate $\eta$, but if large gradients occur only rarely, slowing down every step to guard against them is hard to justify. A popular alternative is to clip the gradient $\mathbf{g}$ by projecting it back onto a ball of a given radius $\theta$:

\[\mathbf{g} \leftarrow \min \left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}\]
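The clipping rule above takes only a few lines. This is a sketch with a hypothetical helper name, `clip_by_norm`; it rescales the gradient only when its norm exceeds $\theta$ and otherwise leaves it unchanged, so the direction of $\mathbf{g}$ is always preserved.

```python
import numpy as np

def clip_by_norm(g, theta):
    """Project g onto the ball of radius theta: g <- min(1, theta/||g||) g."""
    norm = np.linalg.norm(g)
    if norm == 0.0:
        return g  # nothing to clip; also avoids division by zero
    return min(1.0, theta / norm) * g

g_big = np.array([3.0, 4.0])    # ||g|| = 5
g_small = np.array([0.3, 0.4])  # ||g|| = 0.5
print(clip_by_norm(g_big, theta=1.0))    # rescaled to norm 1, same direction
print(clip_by_norm(g_small, theta=1.0))  # already inside the ball: unchanged
```

In PyTorch, the same idea applied to the combined norm of all parameter gradients is available as `torch.nn.utils.clip_grad_norm_`.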

Long Short Term Memory (LSTM)

Long short-term memory (LSTM) was developed to address the vanishing gradient problem encountered when training traditional RNNs. The advantage of an LSTM cell over a vanilla recurrent unit is its memory cell. The cell state can forget part of its previously stored memory and add part of the new information.

Figure 4: Inside a single LSTM Unit.

Input Gate, Forget Gate, and Output Gate

The inputs to the LSTM gates are the input at the current time step and the hidden state of the previous time step. Three fully connected layers with sigmoid activation functions compute the values of the input, forget, and output gates, so every gate value lies between 0 and 1. The input gate decides how much of the input node’s value is added to the internal state of the current memory cell. The forget gate determines how much of the current memory value to retain or discard. Finally, the output gate determines how much the memory cell influences the output at the current time step.

Mathematically, suppose that there are $h$ hidden units, the batch size is $n$, and the number of inputs is $d$. Thus, the input is $\mathbf{X_t} \in \mathbb{R}^{n \times d}$ and the hidden state of the previous time step is $\mathbf{h_{t-1}} \in \mathbb{R}^{n \times h}$. Correspondingly, the gates at time step $t$ are defined as follows: the input gate is $\mathbf{i_t} \in \mathbb{R}^{n \times h}$, the forget gate is $\mathbf{f_t} \in \mathbb{R}^{n \times h}$, and the output gate is $\mathbf{o_t} \in \mathbb{R}^{n \times h}$. They are calculated as follows:

\[\begin{align*} i_t = \ &\sigma(\mathbf{X_t}\textcolor{red}{\mathbf{W_{ii}}} + \textcolor{red}{\mathbf{b_{ii}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{hi}}} + \textcolor{green}{\mathbf{b_{hi}}})& \\ f_t = \ &\sigma(\mathbf{X_t}\textcolor{red}{\mathbf{W_{if}}} + \textcolor{red}{\mathbf{b_{if}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{hf}}} + \textcolor{green}{\mathbf{b_{hf}}})& \\ o_t = \ &\sigma(\mathbf{X_t}\textcolor{red}{\mathbf{W_{io}}} + \textcolor{red}{\mathbf{b_{io}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{ho}}} + \textcolor{green}{\mathbf{b_{ho}}})& \end{align*}\]

where the weight parameters are $\textcolor{red}{\mathbf{W_{ii}}}$, $\textcolor{red}{\mathbf{W_{if}}}$, $\textcolor{red}{\mathbf{W_{io}}} \in \mathbb{R}^{d \times h}$ and $\textcolor{green}{\mathbf{W_{hi}}}$, $\textcolor{green}{\mathbf{W_{hf}}}$, $\textcolor{green}{\mathbf{W_{ho}}} \in \mathbb{R}^{h \times h}$, and $\textcolor{red}{\mathbf{b_{ii}}}$, $\textcolor{red}{\mathbf{b_{if}}}$, $\textcolor{red}{\mathbf{b_{io}}}$, $\textcolor{green}{\mathbf{b_{hi}}}$, $\textcolor{green}{\mathbf{b_{hf}}}$, $\textcolor{green}{\mathbf{b_{ho}}} \in \mathbb{R}^{1 \times h}$ are the bias parameters.

Candidate memory state

In an LSTM, the candidate memory $g_t$ is computed like the gates, from the input at the current time step and the previous hidden state, but with a non-linear function such as the hyperbolic tangent instead of the sigmoid, so its values lie in $(-1, 1)$. How much of this candidate actually enters the memory cell is controlled by the input gate in the memory cell update. The candidate at time step $t$ is:

\[\begin{equation} g_t = \text{tanh}(\mathbf{X_t}\textcolor{red}{\mathbf{W_{ig}}} + \textcolor{red}{\mathbf{b_{ig}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{hg}}} + \textcolor{green}{\mathbf{b_{hg}}}) \end{equation}\]

Memory cell

In LSTMs, the input gate $i_t$ governs how much we take new data into account via $g_t$ and the forget gate $f_t$ addresses how much of the old cell internal state $\mathbf{C_{t-1}} \in \mathbb{R}^{n \times h}$ we retain. Using the Hadamard (elementwise) product operator $\odot$ we arrive at the following update equation:

\[\begin{equation} C_t = f_t \odot C_{t-1} + i_t \odot g_t \end{equation}\]

If the forget gate is always 1 and the input gate is always 0, the memory cell internal state $\mathbf{C_{t-1}}$ will remain constant forever, passing unchanged to each subsequent time step. The input and forget gates let the model learn whether to keep a value as it is or to modify it in response to subsequent inputs. Because the cell state is updated additively rather than through repeated matrix multiplication, gradients can flow across many time steps, which greatly mitigates the vanishing gradient problem and makes LSTMs easier to train, especially on datasets with long sequences.

Hidden state

In LSTMs, we first apply $tanh$ to the memory cell internal state and then apply another point-wise multiplication, this time with the output gate.

\[\begin{equation} h_t = o_t \odot \text{tanh}(C_t) \end{equation}\]

The output gate in LSTM can allow or prevent the memory cell internal state from impacting subsequent layers. When the output gate value is close to 1, the impact is uninhibited, and when it’s close to 0, the impact is prevented. This allows the memory cell to accumulate information over many time steps without impacting the network until the output gate flips to values close to 1 at a subsequent time step.
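Putting the gate, candidate memory, memory cell, and hidden state equations together gives one LSTM time step. The NumPy sketch below follows the row-vector convention used above ($\mathbf{X_t}\mathbf{W}$) and the same parameter names; the dictionary layout, the helper names `lstm_step` and `init_params`, and the toy sizes are assumptions for illustration. The candidate memory $g_t$ uses $\tanh$ while the three gates use the sigmoid.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(X_t, h_prev, C_prev, P):
    """One LSTM step. Shapes: W_i*: (d, h), W_h*: (h, h), b_*: (1, h)."""
    def affine(gate):
        return (X_t @ P["W_i" + gate] + P["b_i" + gate]
                + h_prev @ P["W_h" + gate] + P["b_h" + gate])

    i_t = sigmoid(affine("i"))      # input gate
    f_t = sigmoid(affine("f"))      # forget gate
    o_t = sigmoid(affine("o"))      # output gate
    g_t = np.tanh(affine("g"))      # candidate memory
    C_t = f_t * C_prev + i_t * g_t  # memory cell update (Hadamard products)
    h_t = o_t * np.tanh(C_t)        # hidden state
    return h_t, C_t

def init_params(d, h, rng):
    P = {}
    for gate in "ifgo":
        P["W_i" + gate] = rng.normal(scale=0.1, size=(d, h))
        P["W_h" + gate] = rng.normal(scale=0.1, size=(h, h))
        P["b_i" + gate] = np.zeros((1, h))
        P["b_h" + gate] = np.zeros((1, h))
    return P

rng = np.random.default_rng(0)
n, d, h = 2, 3, 4  # batch size, number of inputs, hidden units
P = init_params(d, h, rng)
h_t, C_t = np.zeros((n, h)), np.zeros((n, h))
for _ in range(5):
    h_t, C_t = lstm_step(rng.normal(size=(n, d)), h_t, C_t, P)
```

Setting the forget-gate bias to a large positive value and the input-gate bias to a large negative value pushes $f_t$ toward 1 and $i_t$ toward 0, and the memory cell then passes through a step essentially unchanged, as described above.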

Gated Recurrent Unit (GRU)

In the 2010s, researchers experimented with simplified RNN architectures that kept the gating mechanisms of LSTMs but were faster to compute. The gated recurrent unit (GRU), introduced by Cho et al. (2014), is a streamlined version of the LSTM that merges the memory cell into the hidden state and often performs comparably while computing faster.

Figure 5: Inside a single Gated Recurrent Unit.

The GRU mitigates the vanishing gradient problem of traditional RNNs by using two gates, an update gate and a reset gate (one fewer than the LSTM). These gates are vectors that determine what information is passed on to the output. What makes them useful is that they can be trained to retain information from the distant past without it being washed out, and to discard information that is irrelevant to the prediction.

Reset and update gates

The reset gate decides how much of the previous hidden state should be ignored while computing the current hidden state. This allows the model to selectively forget information that is no longer relevant. Reset gates help capture short-term dependencies in sequences.

The update gate lets the model decide how much past information (from earlier time steps) is carried forward. This is a powerful capability: the model can choose to copy past information forward almost unchanged, which greatly reduces the risk of vanishing gradients.

Mathematically, for a given time step $t$, suppose that the input is a minibatch $\mathbf{X_t} \in \mathbb{R}^{n \times d}$ (number of examples: $n$, number of inputs: $d$) and the hidden state of the previous time step is $\mathbf{h_{t-1}} \in \mathbb{R}^{n \times h}$ (number of hidden units: $h$ ).

Then, the reset gate $\mathbf{r_t} \in \mathbb{R}^{n \times h}$ and update gate $\mathbf{z_t} \in \mathbb{R}^{n \times h}$ are calculated as follows:

\[\begin{align*} r_t = \ &\sigma(\mathbf{X_t}\textcolor{red}{\mathbf{W_{ir}}} + \textcolor{red}{\mathbf{b_{ir}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{hr}}} + \textcolor{green}{\mathbf{b_{hr}}})& \\ z_t = \ &\sigma(\mathbf{X_t}\textcolor{red}{\mathbf{W_{iz}}} + \textcolor{red}{\mathbf{b_{iz}}} + \mathbf{h_{t-1}}\textcolor{green}{\mathbf{W_{hz}}} + \textcolor{green}{\mathbf{b_{hz}}})& \end{align*}\]

Again, $\textcolor{red}{\mathbf{W_{ir}}}$, $\textcolor{red}{\mathbf{W_{iz}}} \in \mathbb{R}^{d \times h}$ and $\textcolor{green}{\mathbf{W_{hr}}}$, $\textcolor{green}{\mathbf{W_{hz}}} \in \mathbb{R}^{h \times h}$ are weight parameters, whereas $\textcolor{red}{\mathbf{b_{ir}}}$, $\textcolor{red}{\mathbf{b_{iz}}}$, $\textcolor{green}{\mathbf{b_{hr}}}$, and $\textcolor{green}{\mathbf{b_{hz}}} \in \mathbb{R}^{1 \times h}$ are bias parameters.

Candidate hidden unit

We can calculate the candidate hidden state $\mathbf{n_t} \in \mathbb{R}^{n \times h}$ at time step $t$ by incorporating the reset gate $\mathbf{r_t}$ into the regular vanilla RNN update for $h_t$:

\[\begin{equation} n_t = \text{tanh}(\mathbf{X_t}\textcolor{red}{\mathbf{W_{in}}} + \textcolor{red}{\mathbf{b_{in}}} + (\mathbf{r_t}\odot \mathbf{h_{t-1}})\textcolor{green}{\mathbf{W_{hn}}} + \textcolor{green}{\mathbf{b_{hn}}}) \end{equation}\]

where $\textcolor{red}{\mathbf{W_{in}}} \in \mathbb{R}^{d \times h}$ and $\textcolor{green}{\mathbf{W_{hn}}} \in \mathbb{R}^{h \times h}$ are weight parameters, $\textcolor{red}{\mathbf{b_{in}}}$, $\textcolor{green}{\mathbf{b_{hn}}} \in \mathbb{R}^{1 \times h}$ are bias parameters, and $\odot$ is the Hadamard (elementwise) product operator. Compared with the vanilla RNN update, the influence of the previous state can now be reduced through the elementwise multiplication of $\mathbf{r_t}$ and $\mathbf{h_{t-1}}$. Whenever the entries of the reset gate $\mathbf{r_t}$ are close to 1, we recover a vanilla RNN. For all entries of $\mathbf{r_t}$ that are close to 0, the candidate hidden state is the output of an MLP with $\mathbf{X_t}$ as input, so any pre-existing hidden state is effectively reset to defaults.

Hidden state

Finally, we need to incorporate the effect of the update gate $\mathbf{z_t}$. This determines the extent to which the new hidden state $\mathbf{h_t} \in \mathbb{R}^{n \times h}$ matches the old state $\mathbf{h_{t-1}}$ versus how much it resembles the new candidate state $\mathbf{n_t}$, simply by taking elementwise convex combinations of $\mathbf{h_{t-1}}$ and $\mathbf{n_t}$. This leads to the final update equation for the GRU:

\[\begin{equation} h_t = z_t \odot h_{t-1} + (1 - z_t) \odot n_t \end{equation}\]

Whenever the update gate $\mathbf{z_t}$ is close to 1, we simply retain the old state. In this case the information from $\mathbf{X_t}$ is ignored, effectively skipping time step $t$ in the dependency chain. In contrast, whenever $\mathbf{z_t}$ is close to 0, the new hidden state $\mathbf{h_t}$ approaches the candidate hidden state $\mathbf{n_t}$.
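The GRU equations translate into a similarly short step function. This is a NumPy sketch that follows the equations above (row-vector convention, reset gate applied to $\mathbf{h_{t-1}}$ before multiplying by $\mathbf{W_{hn}}$); the helper names and sizes are hypothetical. Note that PyTorch's `nn.GRU` documentation applies the reset gate after the $\mathbf{W_{hn}}$ product instead, a slightly different variant.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(X_t, h_prev, P):
    """One GRU step following the reset/update/candidate/hidden equations."""
    r_t = sigmoid(X_t @ P["W_ir"] + P["b_ir"] + h_prev @ P["W_hr"] + P["b_hr"])
    z_t = sigmoid(X_t @ P["W_iz"] + P["b_iz"] + h_prev @ P["W_hz"] + P["b_hz"])
    # The reset gate scales h_{t-1} before it enters the candidate state
    n_t = np.tanh(X_t @ P["W_in"] + P["b_in"] + (r_t * h_prev) @ P["W_hn"] + P["b_hn"])
    # Elementwise convex combination of the old state and the candidate
    return z_t * h_prev + (1.0 - z_t) * n_t

def init_params(d, h, rng):
    P = {}
    for gate in "rzn":
        P["W_i" + gate] = rng.normal(scale=0.1, size=(d, h))
        P["W_h" + gate] = rng.normal(scale=0.1, size=(h, h))
        P["b_i" + gate] = np.zeros((1, h))
        P["b_h" + gate] = np.zeros((1, h))
    return P

rng = np.random.default_rng(0)
n, d, h = 2, 3, 4  # batch size, number of inputs, hidden units
P = init_params(d, h, rng)
h_t = np.zeros((n, h))
for _ in range(5):
    h_t = gru_step(rng.normal(size=(n, d)), h_t, P)
```

Pushing $\mathbf{z_t}$ toward 1 (for example via a large positive `b_iz`) makes the step return $\mathbf{h_{t-1}}$ almost exactly, which is the skip-this-time-step behavior described above.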

Deep RNNs

So far, we have focused on networks with a sequence input, a single hidden RNN layer, and an output layer. Despite having only one hidden layer, these networks are deep in time, since inputs from the first time step can influence the outputs at the final time step. However, we also want to express complex relationships between the inputs and outputs at the same time step, so we build RNNs that are deep both in time and in the input-to-output direction, much as MLPs stack layers.

Stacked RNNs

In their 2013 study, Graves, Mohamed, and Hinton implemented a deep recurrent neural network for speech recognition, where ‘depth’ refers to stacking more than one hidden layer. The hidden vectors are computed iteratively across all layers and time steps, with the hidden sequence of each layer serving as the input sequence of the layer above.
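The stacking scheme can be sketched as follows: each layer runs over the whole sequence, and its sequence of hidden states becomes the input sequence of the layer above. For brevity this assumes plain $\tanh$ RNN layers rather than the LSTM layers Graves et al. used; the helper names and sizes are illustrative.

```python
import numpy as np

def rnn_layer(xs, W_xh, W_hh, b_h):
    """Run one tanh RNN layer over a list of input vectors; return all states."""
    h = np.zeros(W_hh.shape[0])
    hs = []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        hs.append(h)
    return hs

def stacked_rnn(xs, layers):
    """Layer l's hidden sequence is the input sequence of layer l + 1."""
    seq = xs
    for W_xh, W_hh, b_h in layers:
        seq = rnn_layer(seq, W_xh, W_hh, b_h)
    return seq  # hidden states of the top layer

rng = np.random.default_rng(0)
d, n_h, n_layers, T = 3, 4, 3, 6
layers = []
for l in range(n_layers):
    in_dim = d if l == 0 else n_h  # only the first layer sees the raw inputs
    layers.append((rng.normal(scale=0.5, size=(n_h, in_dim)),
                   rng.normal(scale=0.5, size=(n_h, n_h)),
                   np.zeros(n_h)))
xs = [rng.normal(size=d) for _ in range(T)]
top = stacked_rnn(xs, layers)
```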

The researchers chose the bidirectional LSTM (BiLSTM) as their hidden layer function. Unlike a regular LSTM, a BiLSTM processes the input in both the original and the reversed order, using two separate hidden layers, one for the forward and one for the backward direction, whose outputs are combined at each time step. This is especially useful in speech recognition, where the pronunciation of a phoneme depends on both the preceding and the succeeding phonemes, and it lets BiLSTMs capture long-range dependencies in both directions.
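A bidirectional layer can be sketched by running one RNN over the sequence as given and a second, separately parameterized RNN over the reversed sequence, then concatenating the two hidden states at each time step. This is a sketch under stated assumptions: plain $\tanh$ cells stand in for LSTM cells, concatenation is one common way to combine the two directions, and the helper names are hypothetical.

```python
import numpy as np

def rnn_layer(xs, W_xh, W_hh, b_h):
    """One tanh RNN layer; returns the hidden state at every time step."""
    h = np.zeros(W_hh.shape[0])
    hs = []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x + b_h)
        hs.append(h)
    return hs

def bidirectional_layer(xs, fwd_params, bwd_params):
    """Concatenate forward-direction and backward-direction states per step."""
    hs_fwd = rnn_layer(xs, *fwd_params)
    # Read the sequence in reverse, then flip the outputs back into time order
    hs_bwd = rnn_layer(xs[::-1], *bwd_params)[::-1]
    return [np.concatenate([f, b]) for f, b in zip(hs_fwd, hs_bwd)]

def init_layer(in_dim, n_h, rng):
    return (rng.normal(scale=0.5, size=(n_h, in_dim)),
            rng.normal(scale=0.5, size=(n_h, n_h)),
            np.zeros(n_h))

rng = np.random.default_rng(0)
d, n_h, T = 3, 4, 6
fwd, bwd = init_layer(d, n_h, rng), init_layer(d, n_h, rng)
xs = [rng.normal(size=d) for _ in range(T)]
outs = bidirectional_layer(xs, fwd, bwd)
```

At every position, the first half of `outs[t]` summarizes the inputs up to $t$ and the second half summarizes the inputs from $t$ to the end.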

Deep Transition RNNs

Pascanu et al. (2013) proposed deepening RNNs by inserting one or more intermediate nonlinear layers into the state transitions, arguing that this better captures temporal structure and summarizes the input more efficiently. They introduced models with a deep hidden-to-output function (DO-RNN) and a deep hidden-to-hidden transition (DT-RNN), which help build compact hidden states and make it easier to add new information to the summary of previous steps. Because deeper transitions lengthen the path between time steps and can make long-term dependencies harder to learn, they also proposed shortcut connections, resulting in DT(S)-RNNs and DOT(S)-RNNs. These models were evaluated on tasks such as polyphonic music prediction and language modelling, where the deep variants generally outperformed shallow RNNs in terms of negative log-likelihood and perplexity.

```python
import torch
import torch.nn as nn

class DeepTransitionRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, num_transitions=1):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_transitions = num_transitions

        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # Extra nonlinear layers applied to the RNN outputs at every time step
        self.transitions = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_transitions)])
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size); the initial hidden state is all zeros,
        # created with the same dtype and device as x
        h0 = x.new_zeros((self.num_layers, x.size(0), self.hidden_size))

        out, _ = self.rnn(x, h0)  # out: (batch, seq_len, hidden_size)

        # Note: these layers are not fed back into the recurrence, so they deepen
        # the hidden-to-output path (closer to a DO-RNN) rather than h_{t-1} -> h_t
        for layer in self.transitions:
            out = torch.tanh(layer(out))

        # Predict from the last time step only
        out = self.fc(out[:, -1, :])

        return out
```
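The `DeepTransitionRNN` class above places its extra layers after the recurrence. For contrast, here is a minimal NumPy sketch in the spirit of Pascanu et al.'s DT-RNN, where the intermediate nonlinear layers sit inside the hidden-to-hidden transition, so their result is the state carried to the next time step. The function name and the exact layer arrangement (input entering at the first layer, shortcut connections omitted) are assumptions for illustration.

```python
import numpy as np

def dt_rnn(xs, W_xh, W_hh, b_h, transition_layers):
    """Deep-transition RNN sketch: extra tanh layers inside the h_{t-1} -> h_t step."""
    h = np.zeros(W_hh.shape[0])
    hs = []
    for x in xs:
        # First layer of the transition combines the previous state and the input
        z = np.tanh(W_hh @ h + W_xh @ x + b_h)
        # Intermediate nonlinear layers, still within a single time step
        for W, b in transition_layers:
            z = np.tanh(W @ z + b)
        h = z  # the deepened result is the state carried to the next step
        hs.append(h)
    return hs

rng = np.random.default_rng(0)
d, n_h, T, n_extra = 3, 4, 6, 2
W_xh = rng.normal(scale=0.5, size=(n_h, d))
W_hh = rng.normal(scale=0.5, size=(n_h, n_h))
b_h = np.zeros(n_h)
extra = [(rng.normal(scale=0.5, size=(n_h, n_h)), np.zeros(n_h)) for _ in range(n_extra)]
xs = [rng.normal(size=d) for _ in range(T)]
hs = dt_rnn(xs, W_xh, W_hh, b_h, extra)
```

With an empty `transition_layers` list, this reduces to the vanilla RNN update.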

Encoder-Decoder Architecture

Training

In natural language processing (NLP), the task of converting variable-length input sequences into variable-length output sequences is called sequence-to-sequence, or seq2seq, learning. As described by Cho et al. (2014), a seq2seq model has two parts, illustrated as follows:

Figure 6: Encoder-Decoder architecture. Source

The first component, the encoder, is an RNN that reads the input sequence and compresses it into a fixed-dimension summary vector $\boldsymbol{c}$, commonly called the context. This vector is typically a function of the encoder’s hidden states; Sutskever, Vinyals, & Le (2014) use the final encoder hidden state as the context, so that $\boldsymbol{c} = h_{(T)}^{e}$. The second component, the decoder, is another RNN that generates predictions given the context $\boldsymbol{c}$ and all previous outputs. Unlike in a simple RNN, the decoder hidden state $h_{(t)}^{d}$ is conditioned on the previous output $y_{(t-1)}$, the previous hidden state $h_{(t-1)}^{d}$, and the summary vector $\boldsymbol{c}$ from the encoder. Hence, the conditional distribution of the one-step prediction is:

\[\begin{equation} p(y_{(t)} \mid y_{(1)}, y_{(2)}, \ldots, y_{(t-1)}, \boldsymbol{c}) = f(h_{(t)}^{d}, y_{(t-1)}, \boldsymbol{c}) \end{equation}\]

Both parts are trained jointly to maximize the conditional log-likelihood $\frac{1}{N} \sum_{n=1}^{N} \log p_{\theta}(y_n \mid x_n)$, where $\theta$ denotes the set of model parameters and $(x_n, y_n)$ is an (input sequence, output sequence) pair from a training set of size $N$.
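To make the training objective concrete, here is a NumPy sketch that scores a target sequence under a tiny RNN encoder-decoder with teacher forcing (the true $y_{(t-1)}$ is fed back at each step). It uses $\boldsymbol{c} = h_{(T)}^{e}$ and conditions every decoder step on $\boldsymbol{c}$; the parameter names, the embedding table `E`, the beginning-of-sequence id `0`, and the toy data are all hypothetical. In practice the parameters would be updated by gradient ascent on this objective (or descent on its negative).

```python
import numpy as np

def log_softmax(z):
    z = z - np.max(z)
    return z - np.log(np.sum(np.exp(z)))

def encode(xs, P):
    """Encoder RNN; the context c is its final hidden state, c = h_T^e."""
    h = np.zeros(P["W_hh_e"].shape[0])
    for x in xs:
        h = np.tanh(P["W_hh_e"] @ h + P["W_xh_e"] @ x)
    return h

def log_likelihood(xs, ys, P, bos=0):
    """log p(y | x) = sum_t log p(y_t | y_<t, c), feeding back the true y_{t-1}."""
    c = encode(xs, P)
    h = np.zeros_like(c)
    prev = bos  # hypothetical beginning-of-sequence token id
    total = 0.0
    for y in ys:
        # The decoder state depends on h_{t-1}^d, the previous token y_{t-1}, and c
        h = np.tanh(P["W_hh_d"] @ h + P["E"][prev] + P["W_ch"] @ c)
        total += log_softmax(P["W_hy"] @ h + P["W_cy"] @ c)[y]
        prev = y
    return total

rng = np.random.default_rng(0)
d, n_h, V = 3, 4, 6  # input dim, hidden size, output vocabulary size
shapes = {"W_xh_e": (n_h, d), "W_hh_e": (n_h, n_h), "W_hh_d": (n_h, n_h),
          "E": (V, n_h), "W_ch": (n_h, n_h), "W_hy": (V, n_h), "W_cy": (V, n_h)}
P = {name: rng.normal(scale=0.3, size=s) for name, s in shapes.items()}

# Training objective: average log-likelihood over (input, output) pairs
data = [([rng.normal(size=d) for _ in range(5)], [1, 4, 2]),
        ([rng.normal(size=d) for _ in range(3)], [3, 5])]
objective = np.mean([log_likelihood(x, y, P) for x, y in data])
```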

Multi-task seq2seq Learning

Luong et al. (2015) extended the encoder-decoder architecture by incorporating multi-task learning (MTL) into seq2seq models. MTL aims to improve performance by training on related tasks that complement one another. They examined three configurations: a) one-to-many: a shared encoder feeds separate decoders for different tasks, such as translation and syntactic parsing; b) many-to-one: separate encoders for different tasks, such as translation and image captioning, share a single decoder; c) many-to-many: multiple encoders and decoders, which they used to combine translation with unsupervised objectives such as autoencoders that learn representations from monolingual data.

Figure 7: Multi-task learning. Source

Recurrent neural networks are powerful tools for processing sequential data, but their limitations spurred advances such as gated units and encoder-decoder architectures. As NLP continues to grow, further developments such as transfer learning and attention mechanisms have emerged, building on RNNs and their extensions. More on them in a future post.