Welcome back to another exciting article in the NLP Tutorials series. It’s time to fully delve into Deep Learning for NLP! In NLP, a model needs to remember what it has already seen and keep track of the context. Humans, for example, read progressively, word by word and sentence by sentence, the way you are doing now. We understand text by reading and thinking as we go, and we need a memory to retain what we have read and to maintain the context for the task at hand.
The Artificial Neural Networks (ANNs) we have seen so far can’t do this: given an input, they produce output logits to predict or classify, but they treat each input independently. Can they take the flow of information into account? If we want to predict the next word of a sentence, we need to know which words the network has already seen. Enter Recurrent Neural Networks (RNNs), which can do exactly this. They laid the foundation for the architectures that followed: first RNNs, then LSTMs, GRUs, Encoder-Decoder (Seq2Seq) models, Attention and, finally, Transformers! So it makes sense to learn this fundamental architecture first. Let’s get started!
Recurrent Neural Networks
RNNs are similar to ANNs but take sequential input, i.e. a sequence of vectors (each vector representing a word or a sentence, for that matter). RNNs can retain information and pass it on to the next time step. They do this through an extra component called the internal (hidden) state, which acts as memory. Consider a sequence x(1), x(2), x(3), …, x(N), where x(T) is the input vector at time step T; the memory after reading x(T) is the hidden state h(T). The question now is: how is this modelled in the architecture, and how is it trained?
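To make the idea of a "sequence of vectors" concrete, here is a minimal Python sketch that turns a sentence into x(1), …, x(N) using one-hot vectors. The tiny vocabulary and the helper name `one_hot_sequence` are hypothetical, chosen only for illustration; in practice, learned word embeddings are commonly used instead of one-hot vectors.

```python
# A minimal sketch: turn a sentence into the sequence x(1), ..., x(N)
# using one-hot vectors over a tiny, hypothetical vocabulary.
def one_hot_sequence(sentence, vocab):
    """Return one one-hot vector (a list of floats) per word, in order."""
    index = {word: i for i, word in enumerate(vocab)}
    sequence = []
    for word in sentence.lower().split():
        x = [0.0] * len(vocab)
        x[index[word]] = 1.0  # mark the position of this word in the vocabulary
        sequence.append(x)
    return sequence

vocab = ["the", "cat", "sat", "on", "mat"]
xs = one_hot_sequence("The cat sat on the mat", vocab)
print(len(xs))  # 6 -> six time steps, one vector per word
print(xs[1])    # [0.0, 1.0, 0.0, 0.0, 0.0] -> x(2), the vector for "cat"
```

Each word becomes one time step, so an RNN would read six vectors for this six-word sentence, in order. In this sketch, a word missing from the vocabulary raises a `KeyError`.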
Network Architecture & Working Principle
As we can see in the diagram, x(0), x(1), …, x(T) are the sequential inputs and h(0), h(1), …, h(T) are the internal states. The network takes the first input x(0) and stores h(0) as the internal state (memory) for that input. At the next step, it combines the new input x(1) with the previous state h(0) to compute h(1). This is how it carries information throughout the sequence.
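The hidden state is updated as h(t) = tanh(W(h)·h(t−1) + W(x)·x(t)), where h denotes the hidden state, W(h) is the weight matrix applied to the previous hidden state, W(x) is the weight matrix applied to the current input and tanh is the activation function. Both W matrices are learned during training and are shared across all time steps; learning them is what lets the RNN decide what to carry forward. The output at each step is y(t) = W(hy)·h(t).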
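The update h(t) = tanh(W(h)·h(t−1) + W(x)·x(t)) and the output y(t) = W(hy)·h(t) can be sketched in plain Python as a loop over the sequence. This is a minimal illustration, not a library implementation: the helper names (`matvec`, `rnn_forward`, `rand_matrix`), the random weights and the toy dimensions are all assumptions, and the bias terms that many implementations add are left out to match the formula.

```python
import math
import random

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_forward(xs, W_h, W_x, W_hy, h0):
    """Run a vanilla RNN (no biases) over the sequence xs.

    For each step t:
        h(t) = tanh(W_h . h(t-1) + W_x . x(t))
        y(t) = W_hy . h(t)
    The same three matrices are reused at every time step.
    """
    h = h0
    hs, ys = [], []
    for x in xs:
        pre = [a + b for a, b in zip(matvec(W_h, h), matvec(W_x, x))]
        h = [math.tanh(p) for p in pre]  # new memory mixes the old state and the new input
        hs.append(h)
        ys.append(matvec(W_hy, h))       # output read off the hidden state
    return hs, ys

# Toy sizes (assumed for illustration): 3-dim inputs, 4-dim hidden state, 2-dim outputs.
random.seed(0)

def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

W_h, W_x, W_hy = rand_matrix(4, 4), rand_matrix(4, 3), rand_matrix(2, 4)
xs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
hs, ys = rnn_forward(xs, W_h, W_x, W_hy, h0=[0.0] * 4)
print(len(hs), len(hs[0]), len(ys[0]))  # 3 4 2
```

Notice that W_h, W_x and W_hy are created once and reused at every step; only the hidden state changes as the loop moves through the sequence.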
It is trained much like a regular feed-forward network, but with an extra path that carries the hidden state (memory) h(t) from one time step to the next. Let me decode it with a diagram.
Backpropagation Through Time (BPTT)
In a typical RNN unrolled over ‘t’ time steps, every hidden state h(t) and every input x(t) contributes to the final output, so backpropagation is carried out through all of the time steps. Recall the tanh update, which contains two matrices, W(h) and W(x). Both are optimized to minimize the loss over the ‘t’ steps, and because the same matrices are used at every step, their gradients are summed over all of them. That is why it is called backpropagation through time: the error flows back through all t steps, which is like going back in time over that particular input sequence.
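To see BPTT in action, here is a minimal sketch for a scalar RNN (a single number for the hidden state and for each input), written in plain Python. It walks backwards from the last time step, sums the gradient contributions of the shared weights w_h and w_x over all steps, and checks the result against a finite-difference estimate. The scalar setup, the squared-error loss on the final hidden state and all function names are assumptions made for illustration; the output layer W(hy) is omitted for brevity.

```python
import math

def forward(xs, w_h, w_x, h0=0.0):
    """Scalar RNN: h(t) = tanh(w_h * h(t-1) + w_x * x(t)). Returns [h(0), ..., h(T)]."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w_h * hs[-1] + w_x * x))
    return hs

def loss(xs, target, w_h, w_x):
    """Assumed loss: squared error between the final hidden state h(T) and a target."""
    return 0.5 * (forward(xs, w_h, w_x)[-1] - target) ** 2

def bptt(xs, target, w_h, w_x):
    """Backpropagation through time: gradients of the loss w.r.t. w_h and w_x."""
    hs = forward(xs, w_h, w_x)
    dh = hs[-1] - target              # dL/dh(T)
    grad_h = grad_x = 0.0
    for t in range(len(xs), 0, -1):   # go back in time: T, T-1, ..., 1
        da = dh * (1.0 - hs[t] ** 2)  # through tanh, since tanh'(a) = 1 - tanh(a)^2
        grad_h += da * hs[t - 1]      # w_h is shared, so every step adds to its gradient
        grad_x += da * xs[t - 1]      # likewise for w_x (xs[t - 1] holds x(t))
        dh = da * w_h                 # error handed back to h(t-1)
    return grad_h, grad_x

# Toy sequence and weights, chosen arbitrarily for illustration.
xs, target, w_h, w_x = [0.5, -1.0, 0.8, 0.3], 0.2, 0.9, 0.7
g_h, g_x = bptt(xs, target, w_h, w_x)

# Sanity check with central finite differences.
eps = 1e-6
num_h = (loss(xs, target, w_h + eps, w_x) - loss(xs, target, w_h - eps, w_x)) / (2 * eps)
num_x = (loss(xs, target, w_h, w_x + eps) - loss(xs, target, w_h, w_x - eps)) / (2 * eps)
print(g_h, num_h)
print(g_x, num_x)
```

The two numbers on each printed line should agree closely, which is a quick way to confirm that summing the per-step contributions is the right gradient for weights shared across time.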
Are RNNs good for long sentences? The reality is: not really, at least not for longer contexts. Why? They suffer from a problem called vanishing gradients. The longer the sequence the network is unrolled over, the higher the chance of vanishing gradients. With BPTT, the error is propagated back one time step at a time, and at each step the gradient is multiplied again by the recurrent weights and the derivative of tanh, forming a chain of multiplications all the way back to the initial time step. When this chain is long and its factors are small, the gradient shrinks towards zero, so the early time steps receive negligible weight updates and training stalls. This is one of the main disadvantages of RNNs, and it makes it tricky to get very good results on the data we have nowadays, which is rich, contextual and often lengthy!
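The shrinking chain of multiplications can be observed numerically. This sketch uses a hypothetical scalar RNN with arbitrarily chosen weights and a constant input, and computes how strongly the final state h(T) responds to the initial state h(0). By the chain rule, that sensitivity is the product of one factor w_h · tanh′ per time step; the function name and all parameter values are assumptions for illustration.

```python
import math

def gradient_through_time(T, w_h=0.5, w_x=1.0, x=0.5):
    """d h(T) / d h(0) for a scalar RNN fed the same input x for T steps.

    By the chain rule this is the product over t = 1..T of
    w_h * tanh'(a(t)) = w_h * (1 - h(t)^2): one factor per time step.
    """
    h, grad = 0.0, 1.0
    for _ in range(T):
        h = math.tanh(w_h * h + w_x * x)
        grad *= w_h * (1.0 - h ** 2)  # each factor is at most |w_h| = 0.5 here
    return grad

for T in (5, 20, 50):
    print(T, gradient_through_time(T))
```

Because every factor here is smaller than 1, the gradient decays exponentially with T; with these toy settings it is already vanishingly small after 50 steps, which is exactly why the early time steps stop receiving useful updates.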
Advantages & Disadvantages of RNNs
- A network designed to model sequential data (sequence modelling), which opened up new avenues for time-series modelling!
- Excellent for short-term dependencies
- Often suffers from the vanishing gradient problem, especially on long sequences
- Not good at capturing long-term dependencies
RNNs are effective for sequence modelling but struggle with sequences that have longer dependencies, because the vanishing gradient problem persists as sequences get longer. This was largely addressed by adding gates that control the flow of information and help prevent the gradients from vanishing; the resulting architecture is called LSTM (Long Short-Term Memory). Excited to learn more about LSTMs, which enabled us to train on longer sequences with excellent results? I hope you now understand the basic idea of sequences and RNNs, which are fundamental to the next architecture we will look at in the NLP Tutorials series: LSTMs.