RNN Language Models
When predictions are based on a fixed window of context (i.e. the \(n\) previous words), the window may not be large enough to capture the relevant context. For instance, consider an article that discusses the history of Spain and France; somewhere later in the text it reads “The two countries went to war”. Clearly, this sentence alone does not contain enough information to identify the two countries.
Of the many neural architectures available, we use recurrent neural networks (RNNs) to provide the required long-range memory (up to a point), as shown next.
RNN Language Model. Note that the figure uses a different notation, so the following replacements must be made: \(W_h \rightarrow W\), \(W_e \rightarrow U\), \(U \rightarrow V\).
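A minimal sketch of the RNN language model's step in plain Python, using the figure's replaced notation (\(W\) recurrent weights, \(U\) input weights, \(V\) output weights). Several details are assumptions, not taken from the figure: the activation is \(\tanh\), \(U\) is applied directly to the one-hot input (equivalent to an embedding lookup, so no separate embedding matrix appears), and the sizes and random initialization are hypothetical.

```python
import math
import random

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(M, v):
    """Multiply a matrix M (list of rows) by a vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def rnn_lm_step(x_t, h_prev, W, U, V, b_h, b_y):
    """One RNN language-model step in the figure's notation (tanh assumed):
         h_t     = tanh(W h_{t-1} + U x_t + b_h)
         y_hat_t = softmax(V h_t + b_y)
    """
    pre = [a + b + c for a, b, c in zip(matvec(W, h_prev), matvec(U, x_t), b_h)]
    h_t = [math.tanh(p) for p in pre]
    y_hat_t = softmax([a + b for a, b in zip(matvec(V, h_t), b_y)])
    return h_t, y_hat_t

def random_matrix(rows, cols, rng, scale=0.1):
    """Small uniform random initialization (hypothetical choice)."""
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]

# Hypothetical sizes: vocabulary |V| = 4, hidden state of size 3.
rng = random.Random(0)
vocab_size, hidden_size = 4, 3
W = random_matrix(hidden_size, hidden_size, rng)
U = random_matrix(hidden_size, vocab_size, rng)
V = random_matrix(vocab_size, hidden_size, rng)
b_h, b_y = [0.0] * hidden_size, [0.0] * vocab_size

h = [0.0] * hidden_size               # initial hidden state h_0
for token in [0, 2, 1]:               # feed three token indices one at a time
    x = [0.0] * vocab_size
    x[token] = 1.0                    # one-hot input x_t
    h, y_hat = rnn_lm_step(x, h, W, U, V, b_h, b_y)

print([round(p, 3) for p in y_hat])   # distribution over the next token
```

Because `h` is carried from one call to the next, the final distribution depends on all three inputs rather than on a fixed window.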
To train an RNN language model:
- Start with a large corpus of text, i.e. a sequence of tokens \(\mathbf x_1, \dots, \mathbf x_{T}\), where \(T\) is the number of words/tokens in the corpus.
- At every time step \(t\), feed one word to the RNN and compute the output probability distribution \(\mathbf{\hat y}_t\). By construction, this is the conditional probability distribution over every word in the vocabulary given the words seen so far.
- The loss at time step \(t\) is the cross-entropy (CE) between the predicted distribution \(\mathbf{\hat y}_t\) and the one-hot encoded next token \(\mathbf y_t\) (the one-hot vector of \(\mathbf x_{t+1}\)). Because \(\mathbf y_t\) is one-hot, only the term for the true next token survives; below, the index \(\mathbf x_{t+1}\) denotes that token's position in the vocabulary:
\[J_t(\theta) = CE(\mathbf{\hat y}_t, \mathbf{y}_t) = - \sum_{j=1}^{|V|} \mathbf{y}_{t,j} \log \mathbf{\hat y}_{t,j} = - \log \mathbf{\hat y}_{t,\mathbf x_{t+1}}\]
- Average the per-step losses over all \(T\) time steps:
\[J(\theta) = \frac{1}{T} \sum_{t=1}^{T} J_t(\theta)\]
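The per-step loss and its average can be computed directly from the predicted distributions. In this small sketch with hypothetical numbers, `full_cross_entropy` evaluates the sum over the vocabulary and `cross_entropy` uses the one-hot shortcut \(-\log \mathbf{\hat y}_{t,\mathbf x_{t+1}}\); the two agree.

```python
import math

def full_cross_entropy(y_hat_t, y_t):
    """J_t = -sum_j y_{t,j} * log(y_hat_{t,j}); zero-target terms contribute 0."""
    return -sum(y * math.log(p) for y, p in zip(y_t, y_hat_t) if y > 0)

def cross_entropy(y_hat_t, target_ix):
    """One-hot shortcut: only the true next token's probability matters."""
    return -math.log(y_hat_t[target_ix])

def average_loss(y_hats, target_ixs):
    """J(theta) = (1/T) * sum_t J_t(theta) over T prediction steps."""
    losses = [cross_entropy(p, k) for p, k in zip(y_hats, target_ixs)]
    return sum(losses) / len(losses)

# Hypothetical predictions over a 4-token vocabulary for T = 2 steps.
y_hats = [[0.7, 0.1, 0.1, 0.1],
          [0.25, 0.25, 0.25, 0.25]]
targets = [0, 3]                       # vocabulary indices of the true next tokens
one_hot = [[1, 0, 0, 0], [0, 0, 0, 1]]

# The full sum and the shortcut give the same per-step loss.
assert all(abs(full_cross_entropy(p, y) - cross_entropy(p, k)) < 1e-12
           for p, y, k in zip(y_hats, one_hot, targets))
print(round(average_loss(y_hats, targets), 4))  # 0.8715
```

A confident correct prediction (0.7) costs about 0.357, while a uniform guess over four tokens costs \(\log 4 \approx 1.386\).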
The next figure illustrates this training loss for a hypothetical sequence of words.
RNN Language Model Training Loss. For each input word (at step \(t\)), the RNN predicts the next word and is penalized with a loss \(J_t(\theta)\). The total loss is the average of these per-step losses across the corpus.
In practice, we don’t compute the total loss over the whole corpus at once, since that would be far too expensive. Instead, we compute the loss and its gradients over a finite span of tokens and update the parameters with stochastic gradient descent (SGD), iterating over many such spans.
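A sketch of how a corpus might be cut into finite training spans. Each span pairs its inputs with targets shifted by one token. The span length, the non-overlapping layout, and shuffling the span order each epoch (to make the updates stochastic) are illustrative assumptions; the model's forward pass and gradient step are left out.

```python
import random

def make_spans(tokens, span_len):
    """Cut a token sequence into non-overlapping (inputs, targets) spans.

    The target at each position is the next token, so targets are the inputs
    shifted by one. Trailing tokens that cannot fill a full span (plus the
    one-token lookahead) are dropped.
    """
    spans = []
    for start in range(0, len(tokens) - span_len, span_len):
        inputs = tokens[start:start + span_len]
        targets = tokens[start + 1:start + span_len + 1]
        spans.append((inputs, targets))
    return spans

def epoch_order(spans, rng):
    """Return the spans in a random order for one epoch of SGD updates."""
    order = list(spans)
    rng.shuffle(order)
    return order

corpus = "the two countries went to war".split()
spans = make_spans(corpus, 2)
for inputs, targets in spans:
    print(inputs, "->", targets)
# ['the', 'two'] -> ['two', 'countries']
# ['countries', 'went'] -> ['went', 'to']

# One epoch: for each span in shuffled order, run the forward pass over
# `inputs`, compute the loss against `targets`, and take one gradient step
# (model and optimizer omitted here).
shuffled = epoch_order(spans, random.Random(0))
```

With `span_len=2` the final transition from "to" to "war" is dropped; a real pipeline might keep a shorter last span instead.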
Example:
Character-level language models have been used by Facebook Research to achieve state-of-the-art NLP results. As a simple example, let’s assume the very small vocabulary {‘h’,‘e’,‘l’,‘o’}, where each token is a single letter represented at the input as a one-hot encoded vector.
RNN language model example: training (ref). Note that in practice, instead of one-hot encoded vectors, we would use learned embeddings.
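The vocabulary and one-hot encoding for this example could be set up as follows. The `char_to_ix` mapping and the ordering h, e, l, o are our own choices; the inputs are the first four letters of “hello” and the targets are the letters that follow them.

```python
vocab = ['h', 'e', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}   # 'h'->0, 'e'->1, 'l'->2, 'o'->3

def one_hot(ch):
    """Return the one-hot vector for a single-letter token."""
    vec = [0] * len(vocab)
    vec[char_to_ix[ch]] = 1
    return vec

word = "hello"
inputs = [one_hot(ch) for ch in word[:-1]]        # h, e, l, l
targets = [char_to_ix[ch] for ch in word[1:]]     # e, l, l, o

print(inputs)   # [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]]
print(targets)  # [1, 2, 2, 3]
```

The input ‘l’ appears twice with different targets (‘l’, then ‘o’), so a model that only sees the current letter cannot get both right; the RNN’s hidden state is what lets it distinguish the two cases.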
Let’s feed the sequence “hello” into the RNN. The letters arrive one at a time, and each goes through a forward pass that produces the output \(\mathbf{\hat y}_t\), indicating which letter the network expects next. Since training has only just started, the network does not yet predict correctly; this improves as the model is trained on more sequences formed from our limited vocabulary.
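To see this untrained behavior concretely, the sketch below pushes “hello” through a tiny RNN with random weights and reports, at each step, the most likely next letter and the loss \(-\log \mathbf{\hat y}_{t,\mathbf x_{t+1}}\). The hidden size, Gaussian initialization, \(\tanh\) activation, and the omission of bias terms are assumptions; the predictions depend on the random seed and mean nothing until the weights are trained.

```python
import math
import random

vocab = ['h', 'e', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}
H, K = 3, len(vocab)          # hypothetical hidden size, vocabulary size

rng = random.Random(42)

def rand_mat(rows, cols):
    """Untrained weights: small Gaussian values (assumed initialization)."""
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

W, U, V = rand_mat(H, H), rand_mat(H, K), rand_mat(K, H)

def forward(text):
    """Feed text one letter at a time; at each step return the most likely
    next letter and the cross-entropy loss for the true next letter."""
    h = [0.0] * H                                                # h_0
    results = []
    for cur, nxt in zip(text, text[1:]):
        x = [1.0 if ch == cur else 0.0 for ch in vocab]          # one-hot x_t
        h = [math.tanh(sum(W[i][j] * h[j] for j in range(H)) +
                       sum(U[i][k] * x[k] for k in range(K)))
             for i in range(H)]                                  # new h_t (no bias)
        logits = [sum(V[k][j] * h[j] for j in range(H)) for k in range(K)]
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        y_hat = [e / total for e in exps]                        # softmax
        pred = vocab[max(range(K), key=lambda k: y_hat[k])]
        results.append((pred, -math.log(y_hat[char_to_ix[nxt]])))
    return results

for (pred, loss), target in zip(forward("hello"), "ello"):
    print(f"predicted {pred!r}, expected {target!r}, loss {loss:.3f}")
```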
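During inference, we use the language model to generate text: at each step the next token is chosen from \(\mathbf{\hat y}_t\) (for example, by sampling) and fed back as the input for the following step.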
RNN language model example: generating the next token (ref).
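Generation is autoregressive: each produced token becomes the next input. The sketch below accepts any step function `step(x_ix, h) -> (h_next, y_hat)`. Since we have no trained weights here, `toy_step` is a hand-written, hypothetical stand-in for a trained RNN step whose “hidden state” is just a count of tokens seen, which is enough to send the first ‘l’ to ‘l’ and the second to ‘o’. Greedy decoding picks the most likely token; passing a random generator samples from \(\mathbf{\hat y}_t\) instead.

```python
import random

vocab = ['h', 'e', 'l', 'o']

def generate(step, h, first_ix, length, rng=None):
    """Autoregressive generation: feed each produced token back as input.

    With rng=None the most likely token is chosen (greedy decoding);
    otherwise the next token is sampled from y_hat.
    """
    out = [first_ix]
    x = first_ix
    for _ in range(length):
        h, y_hat = step(x, h)
        if rng is None:
            x = max(range(len(y_hat)), key=lambda k: y_hat[k])
        else:
            x = rng.choices(range(len(y_hat)), weights=y_hat, k=1)[0]
        out.append(x)
    return ''.join(vocab[i] for i in out)

def toy_step(x_ix, h):
    """Hypothetical stand-in for a trained RNN step; h counts tokens seen."""
    ch = vocab[x_ix]
    if ch == 'h':
        y_hat = [0.05, 0.85, 0.05, 0.05]   # mostly 'e'
    elif ch == 'e':
        y_hat = [0.05, 0.05, 0.85, 0.05]   # mostly 'l'
    elif ch == 'l' and h == 2:
        y_hat = [0.05, 0.05, 0.85, 0.05]   # first 'l' -> mostly 'l'
    elif ch == 'l':
        y_hat = [0.05, 0.05, 0.05, 0.85]   # later 'l' -> mostly 'o'
    else:
        y_hat = [0.85, 0.05, 0.05, 0.05]   # after 'o', mostly 'h'
    return h + 1, y_hat

print(generate(toy_step, 0, 0, 4))                     # greedy: hello
print(generate(toy_step, 0, 0, 4, random.Random(1)))   # sampled; varies with seed
```

Greedy decoding always returns the same string, while sampling can occasionally pick a less likely letter, which is why sampling is often preferred when varied text is wanted.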