Dissecting GPT-2
[Sorry for the delay. I traveled to India for my brother’s wedding, but it had to be postponed because my parents got COVID-19. We spent the last 12 days keeping them in isolation, and they are fine now!]
Side thought: The whole system of education is so confusing because it is hard to determine when you should transition from just consuming to taking the lead. I will go with my gut feeling that says “as soon as you learn anything.”
Now back to the topic. GPT is insanely famous in the ML world. I am not the best judge of the novelty of the techniques in it, but I initially approached it as though every bit of it was new, and that is not completely true.
These are the building blocks that constitute it:
Encoder-decoder structure. Remember that I covered the Transformer last time? Because of the popularity of that paper, I was led to believe that encoder-decoder was its breakthrough. But that architecture has been around for some time (and I knew it but had forgotten!! Old problem). LSTM-based models do the same thing (that is the only other model I have read in any detail). Anyway, the innovation of a model is mostly in the circuits within the individual blocks, the encoder and the decoder in this case.
Attention. Again, it has been around for some time.
Attention supposedly highlights the parts of the context that are important for the current word being processed. So, if we have the sentence “Manu thinks he is a superstar,” here ‘he’ refers to Manu, and the attention block’s ideal output should give the most weight to Manu when the input is ‘he’. (The input comes in the form of word/token embedding + positional encoding, which respectively carry the identity of the word and its position relative to the text around it. Positional encoding can take multiple sentences into account, depending on how big we allow the context size to be during training: GPT-2 uses a context of 1024 tokens, and its embedding size ranges from 768 in the smallest model to 1600 in the largest. The average sentence length in English is about 20 words.)
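To make the input representation concrete, here is a minimal numpy sketch of how a token’s vector is assembled before it reaches any attention block. The token ids and both embedding tables are random stand-ins for illustration; in the real model the tables are learned.

```python
import numpy as np

rng = np.random.default_rng(0)

# Real GPT-2 (small) numbers: vocab_size = 50257, d_model = 768, context = 1024.
# Toy sizes here to keep the sketch light; the mechanics are identical.
vocab_size, d_model, context = 1000, 16, 32

token_embedding    = rng.normal(size=(vocab_size, d_model))  # learned in the real model
position_embedding = rng.normal(size=(context, d_model))     # GPT-2 learns these as well

tokens = [42, 7, 905, 3]  # hypothetical token ids for a four-token sentence
x = token_embedding[tokens] + position_embedding[: len(tokens)]
print(x.shape)  # (4, 16): one vector per token, carrying both word identity and position
```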
How is this done?
We train magic-like vectors Q, K, V in an attention block.
Now comes the part where I have interpreted these matrices in terms that make it easier for me to think about and visualize what is going on.
Remember word embeddings? Think of them as vectors in a high-dimensional space, one vector representing each word. The main feature is that the position of a word in this space tells us about its relationship with other words. We can think of the dimensions of the space as different features of a word: adjective, male/female, verb, article, and so on. The famed example of V(king) − V(man) + V(woman) ≈ V(queen) tells us that a word can be represented as a linear combination of the vectors of the words “nearby.”
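As a toy illustration of that arithmetic (with made-up 3-dimensional “embeddings,” nothing like real trained ones), the nearest stored vector to king − man + woman indeed comes out as queen:

```python
import numpy as np

# Toy 3-dimensional "embeddings"; the axes loosely mean [royalty, male, female].
emb = {
    "king":  np.array([0.9, 0.8, 0.1]),
    "queen": np.array([0.9, 0.1, 0.8]),
    "man":   np.array([0.1, 0.9, 0.1]),
    "woman": np.array([0.1, 0.1, 0.9]),
}

target = emb["king"] - emb["man"] + emb["woman"]

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

closest = max(emb, key=lambda w: cosine(emb[w], target))
print(closest)  # "queen": the nearest stored vector to king - man + woman
```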
Now, when we want to represent a word such that it also carries hints of what has preceded it, we want some sort of linear combination of the vectors of the words that are related to it.
For example, take the sentence, “Memories are like mulligatawny soup in a cheap restaurant. It is best not to stir them.”
Here, “stir” is influenced more by “soup” than by “restaurant” or “cheap.” To capture this relation, we train two sets of vectors for each word (Q and K), so that we can take a dot product between them and get a relevancy score for any pair of words.
Once we have relevancy scores, we need a third vector per word (V), which is what actually gets used in the linear combination. The scores are normalized with a softmax so they act as weights, and the weighted sum of the V vectors becomes the output: a “word-vector” for the current word that has paid “attention” to the words before it.
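Here is a minimal numpy sketch of a single attention head along these lines. The weight matrices are random stand-ins (learned in the real model), and the head width of 64 matches GPT-2 small, which splits its 768-dimensional vectors across 12 such heads:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

d_model, d_head = 768, 64          # GPT-2 small: 12 heads, each 64 wide (12 * 64 = 768)
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))
W_v = rng.normal(size=(d_model, d_head))

x = rng.normal(size=(5, d_model))  # 5 token vectors (embedding + positional encoding)

Q, K, V = x @ W_q, x @ W_k, x @ W_v

scores = Q @ K.T / np.sqrt(d_head)  # relevancy score for every pair of words

# GPT-2 only lets a word attend to itself and the words before it,
# so scores for future positions are masked out before the softmax.
mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
scores[mask] = -np.inf

weights = softmax(scores)   # each row sums to 1: how much "attention" a word pays to the others
out = weights @ V           # weighted combination of the V vectors
print(out.shape)            # (5, 64): one attended vector per word
```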
For more details, and a different spin on this, please refer to Jay Alammar’s post as always.
Putting it all together for GPT-2:
Once we have this attention vector, we compare it with all the words in the vocabulary to see which one is the closest match. How do we do it? We train yet another matrix of size (vocabulary size) × (size of the output vector above), holding one vector per vocabulary word. We take the dot product of our output vector with each row of this huge matrix, which gives a score for every word in the vocabulary. The word with the highest score is the closest match.
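A quick sketch of that last step, again with random stand-ins for the trained matrix and for the output vector:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; in GPT-2 small these would be 50257 and 768.
vocab_size, d_model = 1000, 16

W_vocab = rng.normal(size=(vocab_size, d_model))  # one trained vector per vocabulary word
h = rng.normal(size=(d_model,))                   # the output vector from the blocks above

scores = W_vocab @ h                  # dot product with every row: one score per word
next_word = int(np.argmax(scores))    # the highest-scoring word is the closest match
print(next_word)
```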
Note that this setup also gives us a hint about how the model is trained. GPT-2 consists of only decoders stacked atop each other, and the last matrix is the one that gives us the scores for the next word. While training, we already know the next word, so we can measure how far off the prediction is and backpropagate from there; since it is all matrix multiplication, the gradients flow straight through.
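To see why knowing the next word makes training straightforward, here is the usual cross-entropy recipe on a pretend four-word vocabulary (all numbers made up):

```python
import numpy as np

# Hypothetical scores over a pretend vocabulary of 4 words.
logits = np.array([2.0, 0.5, -1.0, 3.2])
true_next = 3                          # we already know that word 3 actually came next

probs = np.exp(logits - logits.max())
probs /= probs.sum()                   # softmax: turn scores into probabilities

loss = -np.log(probs[true_next])       # cross-entropy on the known next word
print(loss)                            # small (~0.32), since word 3 already scores highest
# Backpropagation pushes this loss back through all the matrix multiplications
# (the output matrix, the attention blocks, the embeddings) to update them.
```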
PS: I am sort of stuck because I am visiting home and forgot to bring my iPad to treat all of us to amazing visuals to go with the blog post. But I will resolve this soon and update the post with my artwork!