Quick summary :
Neural language models typically apply a softmax to the dot product of the hidden state with the word embeddings. This can be viewed as a matrix factorization problem: find a factorization of the conditional log probability matrix A that corresponds to the true data distribution. The authors show that for a softmax-based language model to express the true data distribution of natural language, the dimension d of the word embeddings must be high (at least rank(A) - 1). Typically, these embeddings are O(100)-dimensional, so these language models suffer from what the authors call the softmax bottleneck. A mixture of softmaxes (MoS) is proposed as a way around the bottleneck. It works because, in the MoS model, the log probability matrix becomes a nonlinear function of HW^T, and this nonlinear function can be trained to reach a high rank.
The core idea :
Softmax-based language models with distributed (output) word embeddings do not have enough capacity to model natural language.
Why I liked reading it :
- Grounded my understanding of basic linear algebra concepts in the simple NLP task of language modeling.
- Uses theoretical math with lemmas and proofs to arrive at a state-of-the-art model.
Detailed summary :
- Language modeling mostly relies on breaking the joint probability of a sequence of tokens into a product of conditional probabilities of the next word given its context, and then modeling these conditional probabilities P(x|c). This is taught even in Probability 101, but I didn’t know that it is called an “autoregressive factorization”.
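A minimal sketch of the autoregressive factorization: the log probability of a sequence is the sum of the log conditionals of each token given its context. The bigram table `COND`, its probabilities, and the helper `sequence_log_prob` are all hypothetical; a real language model conditions on the whole prefix, not just the previous token.

```python
import math

# Hypothetical toy model: P(next | previous) over a 3-word vocabulary, plus a start token "<s>".
# A bigram table is the simplest stand-in for P(x | c); real models condition on the full prefix.
COND = {
    "<s>": {"the": 0.6, "cat": 0.3, "sat": 0.1},
    "the": {"the": 0.1, "cat": 0.7, "sat": 0.2},
    "cat": {"the": 0.2, "cat": 0.1, "sat": 0.7},
    "sat": {"the": 0.5, "cat": 0.3, "sat": 0.2},
}

def sequence_log_prob(tokens):
    """log P(x_1..x_T) = sum over t of log P(x_t | context)."""
    context, total = "<s>", 0.0
    for tok in tokens:
        total += math.log(COND[context][tok])
        context = tok  # bigram assumption: the context is just the previous token
    return total

lp = sequence_log_prob(["the", "cat", "sat"])
print(round(math.exp(lp), 3))  # 0.294, i.e. 0.6 * 0.7 * 0.7
```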
- The standard approach in neural language models (NLMs) is to use a recurrent neural network (RNN) to encode the context into a fixed-length vector (the hidden state), take its dot product with each word embedding, and pass the results through a softmax layer to get a categorical probability distribution over the vocabulary. This sounds very similar to the model used in the word2vec paper.
- The input that goes into a softmax function is called a “logit”. In the class of models being referred to in this paper, the logit is the dot product of a hidden state and a word vector.
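To make the logit concrete, here is a small sketch of the output layer for a single context: the logit for each word is the dot product of the hidden state with that word’s output embedding, and a softmax turns the logits into a distribution. The vectors and the helper names (`softmax`, `next_word_distribution`) are hypothetical toy values, not taken from the paper.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtracting the max does not change the result."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def next_word_distribution(h, W):
    """Logit for word x is h . w_x; a softmax over all words gives P(x | c)."""
    logits = [sum(hi * wi for hi, wi in zip(h, w)) for w in W]
    return logits, softmax(logits)

# Hypothetical values: d = 3 hidden units, M = 4 words in the vocabulary.
h = [0.5, -1.0, 2.0]            # context vector produced by the RNN
W = [[1.0, 0.0, 0.0],           # one output embedding (row) per word
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0],
     [1.0, 1.0, 1.0]]
logits, probs = next_word_distribution(h, W)
print(logits)  # [0.5, -1.0, 2.0, 1.5]
```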
- One of the key concepts in this paper is the matrix A, which holds the conditional log probabilities of the true data distribution. Imagine that a language has M tokens and N different contexts. For each token x and context c, there is a true probability P(x|c). A is an N x M matrix where Aij is the conditional log probability of the jth token given the ith context. For a softmax model to match the true distribution, its logit for the ith context and jth token must equal Aij up to a constant that depends only on the context i (the log of the softmax normalizer). Another way to think about it that I found helpful: each row of A is the true log probability distribution over tokens for a given context.
An interesting consequence is that infinitely many logit matrices correspond to the true data distribution, because of the shift invariance of the softmax function, i.e. softmax(x + c) = softmax(x) when the same scalar c is added to every entry of x. The authors use this property to prove that adding an arbitrary row-wise shift to A (a possibly different constant for each row) gives a matrix that also corresponds to the true data distribution.
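A quick numerical check of the shift-invariance argument, as a sketch with a made-up 2 x 3 distribution `P` and arbitrary per-row shifts: adding a constant to each row of A = log P leaves the row-wise softmax unchanged, so the shifted matrix describes exactly the same conditional distributions.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(z - m) for z in row]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical true distribution: N = 2 contexts (rows) x M = 3 tokens (columns).
P = [[0.7, 0.2, 0.1],
     [0.1, 0.1, 0.8]]
A = [[math.log(p) for p in row] for row in P]  # A_ij = log P(x_j | c_i)

# Add an arbitrary constant to each row (illustrative values): a row-wise shift.
shifts = [3.0, -5.0]
A_shifted = [[a + s for a in row] for row, s in zip(A, shifts)]

# Softmax of each shifted row recovers the original probabilities.
recovered = [softmax(row) for row in A_shifted]
print([[round(p, 6) for p in row] for row in recovered])
```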
- Our goal, then, is to find logits A’ such that A’ corresponds to the true data distribution. The authors frame this as finding a matrix factorization A’ = HW^T, where H is an N x d matrix whose rows are the hidden states for all possible contexts and W is an M x d matrix whose rows are the word embeddings. (Why is H N x d? Why are hidden states d-dimensional?)
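The shapes become clearer when the N hidden states are stacked as the rows of H: each logit is the dot product of one row of H with one row of W, which is only defined if both have the same length d. The sketch below uses hypothetical toy values (N = 4, M = 5, d = 2) and a small Gaussian-elimination `matrix_rank` helper written just for this example; it shows that the resulting N x M logit matrix has rank at most d.

```python
def matrix_rank(rows, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in rows]
    n_rows, n_cols = len(m), len(m[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        # Use the remaining row with the largest entry in this column as the pivot.
        pivot = max(range(rank, n_rows), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) <= tol:
            continue  # no usable pivot: this column adds no new dimension
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(rank + 1, n_rows):
            factor = m[r][col] / m[rank][col]
            for c in range(col, n_cols):
                m[r][c] -= factor * m[rank][c]
        rank += 1
    return rank

def matmul_transpose(H, W):
    """Logits = H W^T: entry (i, j) is the dot product of hidden state i and embedding j."""
    return [[sum(h * w for h, w in zip(h_row, w_row)) for w_row in W] for h_row in H]

# Hypothetical toy sizes: N = 4 contexts, M = 5 words, d = 2.
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]              # N x d
W = [[0.5, 1.0], [1.0, -1.0], [2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]  # M x d
logits = matmul_transpose(H, W)                                    # N x M
print(len(logits), len(logits[0]), matrix_rank(logits))  # 4 5 2
```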
- For natural language, language modeling then amounts to factorizing a matrix that corresponds to the true data distribution (A or one of its row-wise shifts) into H and W^T. The product HW^T has rank at most d, and a row-wise shift can lower the rank of A by at most one, so d must be at least rank(A) - 1. The final softmax bottleneck statement says exactly this: if d < rank(A) - 1, then the model cannot express the true data distribution of natural language.
The exact inequality is probably not as important as the concept: the dimension d is typically O(10^2), whereas rank(A) for natural language can be O(10^5), so a single softmax falls far short.
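A small worked case of the bound, as a sketch with a hypothetical 4 x 4 distribution (token i gets probability 0.7 in context i, every other token 0.1) and the same kind of hand-written `matrix_rank` helper. The true log probability matrix has rank 4, so d = 2 satisfies d < rank(A) - 1, and a softmax model with d = 2 cannot represent this distribution exactly, whatever H and W it learns.

```python
import math

def matrix_rank(rows, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in rows]
    n_rows, n_cols = len(m), len(m[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        pivot = max(range(rank, n_rows), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) <= tol:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(rank + 1, n_rows):
            factor = m[r][col] / m[rank][col]
            for c in range(col, n_cols):
                m[r][c] -= factor * m[rank][c]
        rank += 1
    return rank

M = 4
# Hypothetical true distribution: in context i, token i has probability 0.7, every other token 0.1.
P = [[0.7 if i == j else 0.1 for j in range(M)] for i in range(M)]
A = [[math.log(p) for p in row] for row in P]  # A_ij = log P(x_j | c_i)

# A row-wise shift by -log(0.1) keeps the same distributions and gives log(7) * I.
shifted = [[a - math.log(0.1) for a in row] for row in A]

d = 2  # hypothetical embedding size
print(matrix_rank(A), matrix_rank(shifted))  # 4 4
# Any logits H W^T built with d = 2 have rank <= 2, below rank(A) - 1 = 3,
# so no choice of H and W can match A or any of its row-wise shifts.
print(d < matrix_rank(A) - 1)  # True
```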
- The most obvious fix is to increase d, but this increases the number of parameters so much that the model may overfit. Another option is a non-parametric model such as an n-gram model, but this can also overfit because of its large number of parameters.
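Some rough arithmetic on why simply raising d is costly. A has one column per word, so rank(A) is at most the vocabulary size M, and closing the gap means pushing d toward M. This sketch counts only the weights of the output embedding matrix W (M x d), ignoring biases and the recurrent layers (which also grow with d). M = 10,000 is roughly the usual Penn Treebank vocabulary size, and d = 400 is a hypothetical typical embedding size.

```python
def output_embedding_params(vocab_size, d):
    """Weights in the output embedding matrix W, which has one d-dimensional row per word."""
    return vocab_size * d

M = 10_000  # vocabulary size (roughly the standard Penn Treebank vocabulary)
for d in (400, M):  # hypothetical typical d versus d large enough to allow rank M
    print(d, output_embedding_params(M, d))
# 400 4000000
# 10000 100000000
```

That is a 25x increase in this one matrix alone.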
- The proposed solution is a simple one: use a mixture of k softmaxes. It was quite hard to figure out from reading the paper why this method alleviates the bottleneck. The paper states “Because A_MoS is a nonlinear function (log_sum_exp) of the context vectors and word embeddings, A_MoS can be arbitrarily high rank. As a result, MoS does not suffer from the rank limitation, compared to Softmax.”
Because of my lack of linear algebra skills, I was unable to figure out why this is true and didn’t find any answers online. I eventually emailed the authors to clarify this point and Zihang Dai was generous enough to respond :
“As you may know, a linear operation (matrix multiplication) does not change the rank. In contrast, a non-linear operation can (has the capacity) change (not necessarily increase) the rank of a matrix.
However, it is not guaranteed that every non-linear operation will increase the rank. In other words, log_sum_exp may increase the rank of some matrices, but not others.
However, remember that
* (1) the inputs we give to log_sum_exp in MoS are trainable
* (2) the output of log_sum_exp will have better performance if it has a higher rank
* (3) log_sum_exp has the capacity to induce a higher rank for some matrices
Putting these points together, it means MoS has the capacity to induce a higher rank log probability matrix ( i.e. A ^ MoS), and it can be trained to exploit this advantage.”
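A toy sketch of the mixture of softmaxes, following my reading of the paper’s formulation: for context c, the mixture weights pi_{c,k} come from a softmax over projections of the last hidden state g_c, each component uses its own context vector h_{c,k} = tanh(W_{h,k} g_c), and all components share the output embeddings W. Everything here is an assumption for illustration: the sizes, the random parameters standing in for trained ones, and the hypothetical helpers `mos_distribution`, `softmax_distribution`, and `matrix_rank`. The point is only to compare the numerical rank of the two N x M log probability matrices: the single softmax is capped at d + 1 (rank d from HW^T plus one from the row-wise normalizer), while the log of the mixture is not.

```python
import math
import random

def softmax(z):
    """Numerically stable softmax over a list of scores."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def matvec(mat, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]

def matrix_rank(rows, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in rows]
    n_rows, n_cols = len(m), len(m[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        pivot = max(range(rank, n_rows), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) <= tol:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(rank + 1, n_rows):
            factor = m[r][col] / m[rank][col]
            for c in range(col, n_cols):
                m[r][c] -= factor * m[rank][c]
        rank += 1
    return rank

random.seed(0)
N, M, d, K = 6, 6, 2, 3  # contexts, vocabulary size, embedding size, mixture components (toy sizes)

def rand_matrix(rows, cols, scale=2.0):
    return [[random.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]

G = rand_matrix(N, d)                        # g_c: stand-in for the last RNN hidden state per context
W = rand_matrix(M, d)                        # output word embeddings w_x, shared by all components
W_pi = rand_matrix(K, d)                     # projections producing the mixture-weight logits
W_h = [rand_matrix(d, d) for _ in range(K)]  # one projection W_{h,k} per component

def softmax_distribution(g):
    """Baseline: P(x | c) = softmax(W g_c)_x."""
    return softmax(matvec(W, g))

def mos_distribution(g):
    """MoS: P(x | c) = sum_k pi_{c,k} * softmax(W h_{c,k})_x, with h_{c,k} = tanh(W_{h,k} g_c)."""
    pi = softmax(matvec(W_pi, g))
    probs = [0.0] * M
    for k in range(K):
        h_k = [math.tanh(v) for v in matvec(W_h[k], g)]
        p_k = softmax(matvec(W, h_k))
        probs = [p + pi[k] * q for p, q in zip(probs, p_k)]
    return probs

log_p_softmax = [[math.log(p) for p in softmax_distribution(g)] for g in G]
log_p_mos = [[math.log(p) for p in mos_distribution(g)] for g in G]
print("softmax rank:", matrix_rank(log_p_softmax), "(at most d + 1 =", d + 1, ")")
print("MoS rank:", matrix_rank(log_p_mos))
```

With random parameters the MoS matrix typically comes out well above d + 1, which matches point (3) of the email: log_sum_exp has the capacity to raise the rank, and training (points (1) and (2)) can exploit it.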
- First of all, I just want to call out how cool it is that he wrote such a thorough response. While it cleared up a lot for me, it’s still not quite clear to me how nonlinear operations can change the rank but linear operations can’t. I’m guessing it has something to do with linear operations not having any effect on the linear dependence between the vectors of the matrix.
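A tiny numerical illustration that is consistent with that guess: a linear map preserves linear dependences (if row2 = 2 * row1, then row2 B = 2 * (row1 B)), so multiplying by a matrix can never raise the rank, whereas an elementwise function like exp breaks the dependence because exp(2x) is not 2 * exp(x). The matrices `X` and `B` and the `matrix_rank` helper are hypothetical, written just for this example.

```python
import math

def matrix_rank(rows, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    m = [[float(v) for v in row] for row in rows]
    n_rows, n_cols = len(m), len(m[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        pivot = max(range(rank, n_rows), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) <= tol:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(rank + 1, n_rows):
            factor = m[r][col] / m[rank][col]
            for c in range(col, n_cols):
                m[r][c] -= factor * m[rank][c]
        rank += 1
    return rank

def matmul(X, Y):
    """Plain matrix product X Y."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

X = [[1.0, 2.0],
     [2.0, 4.0]]   # second row = 2 * first row, so rank(X) = 1
B = [[3.0, -1.0],
     [0.5, 2.0]]   # an arbitrary linear map
linear = matmul(X, B)                                  # rows stay proportional
nonlinear = [[math.exp(v) for v in row] for row in X]  # elementwise exp, as inside log_sum_exp

print(matrix_rank(X), matrix_rank(linear), matrix_rank(nonlinear))  # 1 1 2
```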
- I did not dive deep into the results, but here’s a bird’s-eye view: this technique improves the state of the art for language modeling on the Penn Treebank and WikiText-2 datasets. The authors also show empirically that the A matrix obtained using MoS is indeed high rank and that performance improves as the rank increases.
Scratchpad (i.e. random thoughts/questions/comments) :