Recurrent Neural Networks and Long Short-Term Memory Networks
January 16, 2018
What is a Recurrent Neural Net?
Recurrent neural networks (RNNs) are neural networks designed to analyze sequences of inputs, a kind of data that traditional feed-forward neural networks are not good at modeling. Text comprehension, translation, motor control, and speech are all instances of problems where either the input or the output (or both) is a sequence. Depending on the problem, the sequences can be of arbitrary length and have complex time dependencies.
RNNs are built on the same computational unit as feed-forward nets (the neuron), but differ in how those neurons are connected to one another. In feed-forward neural networks, information flows unidirectionally from input units to output units; in particular, there are no cycles in the network. In RNNs, cycles are allowed, and neurons may even be connected to themselves. Since each time step’s vector of activities is computed from the previous one, RNNs are able to retain a memory of past events and use that memory in making decisions.
Feed-forward NN (left) vs Recurrent NN (right). In RNNs, cyclic connections are allowed, including self connections.
Illustration of RNN simulating a sequence
Now that we understand how an RNN is structured, let’s see how it is able to simulate a sequence of events. Suppose we’d like our RNN to act like a timer module (a classic example designed by Herbert Jaeger, original manuscript can be found here).
We have two inputs. The input u1 is a binary switch that spikes to 1 when the RNN is supposed to start the timer. The input u2 is a discrete variable taking values between 0.1 and 1.0 inclusive, in steps of 0.1, and it specifies how long the output should stay on if the timer is started at that instant. u2 assumes a new random value every time u1 triggers the timer, and the specification requires the RNN to turn on the output y1 (to a value of 0.5) for a duration of 1000·u2 time steps.
Example of a timer we would like our RNN to simulate
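To make the specification concrete, here is a minimal sketch in Python/NumPy that generates training sequences for this task. The function name and the trigger probability are our own choices, and we assume the timer cannot be re-triggered while it is running, a detail the specification above leaves open:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_timer_example(length=5000, trigger_prob=0.005):
    """One (inputs, target) sequence for the timer task."""
    u1 = np.zeros(length)              # binary start switch
    u2 = np.full(length, 0.5)          # duration signal; holds its value between triggers
    y1 = np.zeros(length)              # target: 0.5 while the timer runs, else 0
    t = 0
    while t < length:
        if rng.random() < trigger_prob:
            u1[t] = 1.0
            d = rng.integers(1, 11) / 10.0     # new random duration in {0.1, ..., 1.0}
            u2[t:] = d
            y1[t:t + int(1000 * d)] = 0.5      # output stays on for 1000*u2 time steps
            t += int(1000 * d)                 # assume no re-trigger while running
        else:
            t += 1
    return np.stack([u1, u2], axis=1), y1
```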
But how exactly would a neural net achieve this calculation? First, the RNN has all of its hidden activities initialized to some pre-determined state (usually all zeros). Then at each time step (t = 1, 2, …), every neuron sends its current activity through all of its outgoing connections. The input neurons’ current activity is simply the input at that time step. Each neuron then recalculates its new activity by computing the weighted sum of its inputs from other neurons and applying its activation function (sigmoid, tanh, linear, etc.).
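Written out, a common (though not the only) form of this update is h_t = tanh(W_xh·x_t + W_hh·h_(t-1) + b_h) with a linear read-out. A minimal sketch, with weight names of our own choosing:

```python
import numpy as np

def rnn_forward(X, h0, W_xh, W_hh, W_hy, b_h, b_y):
    """Run the RNN over an input sequence X of shape (T, n_inputs)."""
    h, ys = h0, []
    for x in X:
        # weighted sum of inputs from other neurons, then the activation function
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
        ys.append(W_hy @ h + b_y)      # linear output units
    return np.array(ys), h
```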
An RNN is unlikely to perfectly emulate a timer as per the problem specification above, since its hidden states and outputs are real values (as opposed to discrete values). However, after training on hundreds or thousands of examples, it is reasonable to expect its output (orange) to follow the ground truth (blue) pretty closely for this example.
An example of how a well-trained RNN might approximate the output of a test case
The above was a toy example. In practice, the sequence to estimate would come from a real system. For example, one could build a dataset to train an RNN to transcribe audio into text, or to train an autonomous robot to perform some task.
How an RNN might be used in practice to train an autonomous robot
Training an RNN: Backpropagation Through Time
Now that we understand what an RNN is, let’s look at how we can train it. Specifically, how do we determine the weights on each of the connections, and how do we choose the initial activities of all the hidden units? Our first instinct might be to use backpropagation, the algorithm we used for training feed-forward neural nets.
The problem with using backpropagation directly is that we have cyclic dependencies. In feed-forward nets, when we calculated the error derivatives with respect to the weights in one layer, we could express them completely in terms of the error derivatives from the layer above. In a recurrent neural network we don’t have this nice layering, because the neurons do not form an acyclic graph: backpropagating through an RNN would require us to express an error derivative in terms of itself.
In order to apply the backpropagation algorithm on RNNs, we employ a clever transformation where we convert our RNN into a new structure that’s essentially a feed-forward neural network! We call this strategy unrolling the RNN through time. An example can be seen in the figure below (with only one input/output per time step to simplify the illustration):
An example of “unrolling” an RNN through time to use backpropagation
We take the RNN’s inputs, outputs, and hidden units and replicate them for every time step. These replications correspond to layers in our new feed-forward neural network. We then connect the hidden units as follows: if the original RNN has a connection of weight ω from neuron i to neuron j, then in our feed-forward net we draw a connection of weight ω from neuron i in layer k to neuron j in layer k+1, for every layer k. Conceptually, we’re treating each neuron at each time step as a separate neuron. The gradient of each connection weight ω in the RNN is then the sum of the gradients of all the connections in the feed-forward net that correspond to it.
Thus, to train our RNN, we randomly initialize the weights, “unroll” it into a feed-forward neural net, and backpropagate to determine the optimal weights! To determine the initializations for the hidden states at time 0, we can treat the initial activities as parameters fed into the feed-forward network at the lowest layer, and backpropagate to determine their optimal values as well.
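Here is a sketch of this procedure for the rnn_forward network above, assuming a squared-error loss. For brevity it returns only the gradient of the shared recurrent weight matrix, accumulated as a sum over all its copies as described above, and the gradient of the initial state; the gradients for the other weights are computed analogously:

```python
import numpy as np

def bptt_grads(X, Y, h0, W_xh, W_hh, W_hy, b_h, b_y):
    """Gradients of the squared error w.r.t. W_hh and h0 for rnn_forward above."""
    hs, h = [h0], h0
    for x in X:                                     # forward pass, storing every state
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
        hs.append(h)
    dW_hh = np.zeros_like(W_hh)
    dh_next = np.zeros_like(h0)
    for t in reversed(range(len(X))):               # backward pass, layer by layer
        err = (W_hy @ hs[t + 1] + b_y) - Y[t]       # output error at this time step
        dh = W_hy.T @ err + dh_next                 # error from above plus from the future
        dpre = (1.0 - hs[t + 1] ** 2) * dh          # back through the tanh
        dW_hh += np.outer(dpre, hs[t])              # sum the gradients of all weight copies
        dh_next = W_hh.T @ dpre                     # pass the error one time step down
    return dW_hh, dh_next                           # dh_next is the gradient for h0
```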
The Problems with Deep Backpropagation
Since the number of time steps for a particular problem can be quite large (say 1,000 or 10,000), our unrolled feed-forward nets can be enormously deep. This gives rise to a serious practical issue: vanishing and exploding gradients. Because the same transformation (the RNN’s weights) is applied at every time step, the gradients are multiplied by it over and over, and across many layers they either vanish or explode.
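A quick numeric sketch of this effect (ignoring the bounded activation-function derivative, which only makes vanishing worse): backpropagating through T layers multiplies the error signal by roughly the same matrix T times, so its norm scales like the T-th power of that matrix’s largest singular value.

```python
import numpy as np

rng = np.random.default_rng(0)
Q = np.linalg.qr(rng.normal(size=(50, 50)))[0]  # orthogonal: singular values all 1
for scale in (0.9, 1.1):                        # slightly contracting vs. expanding W
    g = np.ones(50)                             # error signal arriving at the top layer
    for _ in range(1000):                       # 1000 unrolled time steps
        g = (scale * Q).T @ g                   # repeated multiplication by the same W
    print(scale, np.linalg.norm(g))             # ~1e-45 (vanished) vs ~1e42 (exploded)
```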
Below is a toy example to illustrate the point. Say we have an RNN with a single hidden unit and a bias term, with the unit connected to itself and to a single output. We want this network to output a fixed target value, say 0.7, after 50 time steps. The cost function is the squared error, which we can plot as a surface over the weight and the bias:
The error surface for our simple RNN (source: Pascanu et al.)
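A sketch of how such a surface can be computed: run the one-unit network for 50 steps from a zero start state at each point of a (w, b) grid and record the squared error. The sigmoid activation and the grid ranges here are our assumptions; Pascanu et al. analyze a similar single-unit setup:

```python
import numpy as np

def error_surface(ws, bs, target=0.7, steps=50):
    """Squared error of the one-unit RNN after `steps` time steps, on a (w, b) grid."""
    E = np.zeros((len(ws), len(bs)))
    for i, w in enumerate(ws):
        for j, b in enumerate(bs):
            h = 0.0                                     # hidden state starts at zero
            for _ in range(steps):
                h = 1.0 / (1.0 + np.exp(-(w * h + b)))  # sigmoid unit (our assumption)
            E[i, j] = (h - target) ** 2
    return E

E = error_surface(np.linspace(-2, 8, 200), np.linspace(-6, 4, 200))
```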
Now, say we started at the red star (a random initialization of the weights). As we run gradient descent, we get closer and closer to the local minimum on the surface. But when we slightly overshoot the valley and hit the cliff, we are presented with a massive gradient in the opposite direction, which bounces us extremely far away from the local minimum. And once we’re in that no man’s land, we find that the gradients are so vanishingly small that getting close again takes a seemingly endless amount of time.
For an in-depth mathematical treatment of this issue, check out this paper.
Long Short-Term Memory networks
To address these problems, we use a modified architecture for recurrent neural networks that helps bridge long time lags between relevant inputs and the appropriate responses, and protects against exploding gradients. The architecture forces constant error flow (thus neither exploding nor vanishing) through the internal state of special memory units.
Long Short-Term Memory (LSTM) networks are comprised of LSTM units. Each LSTM unit has the following structure:
Structure of a Long Short-Term Memory (LSTM) unit. The points (black dots) where the gate activity meets the other connections are multiplication operations.
The LSTM unit is built around a memory cell, which attempts to store information for extended periods of time. Access to the memory cell is protected by specialized gate neurons, the keep, write, and read gates, all of which are logistic units. The memory cell itself is a linear neuron with a connection to itself.
When the keep gate has an activity of 1 (is turned on), the self-connection has weight one and the memory cell writes its contents into itself. When the keep gate outputs a 0, the memory cell’s value gets multiplied by zero, so it forgets its previous contents. The write gate allows the rest of the neural net to write into the memory cell when it outputs a 1, while the read gate allows the rest of the neural net to read from the memory cell when it outputs a 1.
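One time step of this cell, written out as a minimal sketch. It follows the keep/write/read formulation above (in the modern LSTM literature these gates are usually called forget, input, and output); computing the gate activities themselves from the rest of the network is omitted:

```python
def lstm_cell_step(x, c_prev, keep, write, read):
    """One step of the simplified memory cell described above.

    keep, write, read are logistic gate activities in [0, 1];
    x is the value the rest of the network is trying to store.
    """
    c = keep * c_prev + write * x   # keep gate guards the self-connection,
                                    # write gate guards what enters the cell
    out = read * c                  # read gate guards what the rest of the net sees
    return c, out
```

With keep ≈ 1 and write ≈ 0, the cell simply copies its contents forward unchanged, which is exactly the behavior the unrolled diagram below relies on.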
LSTM: Step by step example
So how exactly does this force a constant error flow through time to locally protect against exploding and vanishing gradients? To visualize this, let’s unroll the LSTM unit through time:
Unrolling the LSTM unit through the time domain
At first, the keep gate is set to 0 and the write gate is set to 1, which places 4.2 into the memory cell. This value is retained in the memory cell by a subsequent keep value of 1, and is protected from reading and writing by gate values of 0. Finally, the cell is read and then cleared. Now follow the backpropagation from the point where 4.2 is read out of the cell back to the point where it was written in. Due to the linear nature of the memory neuron, the error derivative we receive at the read point backpropagates to the write point with negligible change, because the self-connections carrying the memory cell across the time layers all have weights approximately equal to 1 (approximately, because the keep gate is a logistic unit and never outputs exactly 1). As a result, we can locally preserve error derivatives over hundreds of time steps without having to worry about exploding or vanishing gradients.
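Since the stored value evolves as c_t = keep_t · c_(t-1) + write_t · x_t, the derivative of a later cell state with respect to an earlier one is just the product of the intervening keep-gate activities. A two-line comparison, with illustrative numbers of our own choosing:

```python
keep = 0.999           # keep gate saturated near 1 (logistic, so never exactly 1)
plain = 0.9            # a typical per-step shrink factor in an ordinary RNN
steps = 300
print(keep ** steps)   # ~0.74  -> the error derivative survives 300 steps nearly intact
print(plain ** steps)  # ~2e-14 -> an ordinary recurrent weight would wipe it out
```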
Example: LSTMs applied to hand-writing recognition
In the video below, you can see LSTMs in action, successfully reading cursive handwriting (research work done and video created by Alex Graves).
Explanation
- Row 1: Shows when the letters are recognized
- Row 2: Shows the states of some of the memory cells (Notice how they get reset when a character is recognized!)
- Row 3: Shows the writing as it’s being analyzed by the LSTM RNN
- Row 4: Shows the gradient backpropagated to the inputs from the most active character of the upper softmax layer, which is helpful for analyzing which parts of the input most influence the current character decision.
The LSTM RNN does quite well, and it’s been applied in lots of other places as well. Deep architectures for LSTM RNNs have also been used for data compression. If you are interested in diving deeper, check out this paper.
Other approaches to improve RNN training
Another approach to improving the training of RNNs is to use optimizers that can deal with exploding and vanishing gradients (second-order, or Hessian-free, optimizers). They try to detect directions with a small gradient but even smaller curvature.
A third approach involves a very careful initialization of the weights, in the hope of avoiding the problem of exploding and vanishing gradients in the first place (e.g., echo state networks and momentum-based approaches).
Conclusion
The LSTM RNN architecture is one of the most successful neural network architectures for various applications in natural language processing and speech transcription.