Malav Patel

GPT2 Walkthrough

In this note, we follow the path of input data through the GPT-2 architecture to better understand how the transformer works. First, let us show an overview of the model.

Overview

This section outlines how data is processed by GPT-2. Pseudocode for a full forward pass is shown below.

FORWARD PASS

def forward(tokens):
    """
    T : current sequence length
    C : embedding dimension          (default 768)
    L : number of transformer blocks (default 12)
    V : vocabulary size              (default 50257)
    maxT : max sequence length       (default 1024)
    """
    
    x = encoder(tokens) # (1, T) -> (T, C)
    residual = x        # store residual

    for i in range(L):
        x = layernorm(x, scale_1[i], shift_1[i]) # (T, C) -> (T, C)

        ##### Attention Block #################################
        x = matmul(x, W_1[i], bias_1[i]) # (T, C) -> (T, 3C)
        x = MHAttention(x)               # (T, 3C) -> (T, C)
        x = matmul(x, W_2[i], bias_2[i]) # (T, C) -> (T, C)
        #######################################################

        x = x + residual # add residual
        residual = x     # store new residual

        x = layernorm(x, scale_2[i], shift_2[i]) # (T, C) -> (T, C)

        #### FeedForward Block ################################
        x = matmul(x, W_3[i], bias_3[i]) # (T, C) -> (T, 4C)
        x = GELU(x)                      # (T, 4C) -> (T, 4C)
        x = matmul(x, W_4[i], bias_4[i]) # (T, 4C) -> (T, C)
        #######################################################

        x = x + residual # add residual
        residual = x     # store new residual

    x = layernorm(x, scale_f, shift_f) # (T, C) -> (T, C)
    x = matmul(x, wte.T)               # (T, C) -> (T, V)
    x = softmax(x)                     # (T, V) -> (T, V)

    return x
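To make the loop around this forward pass concrete, here is a minimal sketch of greedy autoregressive generation. The `forward` interface, the `toy_forward` stand-in, and all sizes here are hypothetical illustrations; a real GPT-2 forward pass would replace `toy_forward`.

```python
import numpy as np

def generate(forward, tokens, n_new, max_t=1024):
    """Greedy autoregressive decoding: repeatedly run the forward
    pass and append the most likely next token."""
    tokens = list(tokens)
    for _ in range(n_new):
        context = tokens[-max_t:]               # crop to the model's max sequence length
        probs = forward(context)                # (T, V) matrix of next-token probabilities
        next_token = int(np.argmax(probs[-1]))  # the last row predicts the next token
        tokens.append(next_token)
    return tokens

# A stand-in "model" that always predicts (last_token + 1) mod V,
# just to exercise the loop; a real GPT-2 forward pass would go here.
V = 10
def toy_forward(context):
    p = np.zeros((len(context), V))
    for i, t in enumerate(context):
        p[i, (t + 1) % V] = 1.0
    return p

print(generate(toy_forward, [3], 4))  # [3, 4, 5, 6, 7]
```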

Here is the full architecture as a diagram:

[Diagram: the full GPT-2 architecture. The Encoder maps $[1, T] \to [T, C]$. Then, repeated 12 times: Layernorm $([T, C] \to [T, C])$, Matmul with weights of shape $[C, 3C]$ $([T, C] \to [T, 3C])$, Attention $([T, 3C] \to [T, C])$, Matmul with weights of shape $[C, C]$ $([T, C] \to [T, C])$, residual add; Layernorm, Matmul with weights of shape $[C, 4C]$ $([T, C] \to [T, 4C])$, GeLU, Matmul with weights of shape $[4C, C]$ $([T, 4C] \to [T, C])$, residual add. Finally: Layernorm $([T, C] \to [T, C])$, Unembedding with $\boldsymbol{W}_{\text{wte}}^T$ $([T, C] \to [T, V])$, and Softmax $([T, V] \to [T, V])$, producing a matrix $\boldsymbol{P}$ of shape $[T, V]$, where $\boldsymbol{P}_{ij} :=$ the probability of the $(i+1)$-th token in the sequence having token id $j$.]

Forward Pass

Here we will outline some of the main architectural blocks involved in the forward pass shown above. We begin with a length $T$ array of integers, called tokens, which represent a string of text. As a reminder, $T$ is the sequence length.

TODO: insert image of string and corresponding tokens

Encoder $(1, T) \to (T, C)$

The encoder has two sets of parameters, the weight token embedding matrix and the positional embedding matrix. The weight token embedding matrix has shape ($V$, $C$) and the positional embedding matrix has shape ($\text{max}T$, $C$).

Generally, if the $j$-th token in the sequence has token id $i$, then we add the $j$-th row of the positional embedding matrix to the $i$-th row of the weight token embedding matrix to produce the $j$-th row of the output. Pseudocode is presented below.

ENCODER

INPUT: u   # a [1, T] vector of tokens
       wte # weight token embedding matrix
       wpe # weight positional embedding matrix
       
OUTPUT: out # [T, C] matrix of embedded tokens

for i, token in enumerate(u):
    e = wte[token, :] + wpe[i, :]  # embedded vector for token at position i
    out[i, :] = e                  # store result in the output

For example, if the first token in the input sequence is 31373, then we would extract the 31373-th row from the weight token embedding matrix and the first row of the positional embedding matrix and add them together, placing the resulting row vector in the first row of the output. A diagram of the process is shown below.

[Diagram: the encoder. Token ids 31373, 31374, 31372 select the corresponding rows of the Weight Token Embedding Matrix $\boldsymbol{W}_{\text{wte}}$ (shape $V \times C$), while positions 0, 1, 2 select rows 0, 1, 2 of the Weight Position Embedding Matrix (shape $\text{max}T \times C$). Corresponding rows are added to form the rows of the $T \times C$ output.]
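The encoder pseudocode above can be sketched in NumPy as follows. The sizes and random parameter matrices here are illustrative stand-ins (GPT-2's defaults are $V = 50257$, $C = 768$, $\text{max}T = 1024$), not real GPT-2 weights.

```python
import numpy as np

# Small illustrative sizes; GPT-2's defaults are V = 50257, C = 768, maxT = 1024.
V, C, maxT = 1000, 64, 128
rng = np.random.default_rng(0)
wte = rng.standard_normal((V, C))    # weight token embedding matrix (V, C)
wpe = rng.standard_normal((maxT, C)) # weight positional embedding matrix (maxT, C)

def encoder(tokens):
    """Embed a sequence of token ids into a (T, C) matrix."""
    T = len(tokens)
    out = np.empty((T, C))
    for i, token in enumerate(tokens):
        out[i, :] = wte[token, :] + wpe[i, :]  # token embedding + position embedding
    return out

x = encoder([42, 7, 999])
print(x.shape)  # (3, 64)
```

The loop can also be written in one vectorized line as `wte[tokens] + wpe[:len(tokens)]`.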

LayerNorm $(T, C) \to (T, C)$

The layernorm operation defines a transformation that first normalizes each row in the input tensor to zero mean and unit variance and then applies a scale and shift operation to each element. Since there are $C$ elements in a row, there are $C$ scale parameters and $C$ shift parameters to learn.

The pseudocode below shows how to apply the layernorm operation to a row of the input.

LAYERNORM

INPUTS : row
OUTPUTS: output

m = mean(row)              # calculate mean of row
std = sqrt(var(row) + eps) # calculate std of row; eps is a small constant for numerical stability

normalized = (row - m) / std # normalize to zero mean and unit variance

for i = 0 to length(row) - 1:
    output[i] = normalized[i] * scale[i] + shift[i]

Below is a diagram showing how a single row from the tensor of shape $(T, C)$ is transformed by the layernorm. Note that this operation is done for all the rows in the tensor.

[Diagram: a single row of the $(T, C)$ tensor is first normalized to zero mean and unit variance, then multiplied element-wise by the scale parameters and shifted element-wise by the shift parameters to produce the output row.]
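A vectorized NumPy sketch of the layernorm, applied to all rows at once, might look like this (the `eps` constant and the random test data are assumptions for illustration):

```python
import numpy as np

def layernorm(x, scale, shift, eps=1e-5):
    """Apply layernorm row-wise to a (T, C) tensor."""
    m = x.mean(axis=-1, keepdims=True)       # per-row mean
    v = x.var(axis=-1, keepdims=True)        # per-row variance
    normalized = (x - m) / np.sqrt(v + eps)  # zero mean, unit variance per row
    return normalized * scale + shift        # learned element-wise scale and shift

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
out = layernorm(x, scale=np.ones(8), shift=np.zeros(8))
print(out.mean(axis=-1))  # each row's mean is ~0
```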

Multi Head Attention Block

The following block is most commonly referred to as the multi-head attention block. It is split into three subblocks. The first is a linear projection that produces the queries, keys, and values. The second subblock is where the actual attention mechanism is applied. The third subblock is another linear projection back into the latent space. We go through each of these subblocks below.

Matmul $(T, C) \to (T, 3C)$

Following the layernorm is a matrix multiplication operation that creates query, key, and value vectors for each token in the sequence. The following diagram shows how the $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are created for an input $\boldsymbol{A}$ of shape $(T, C)$.

[Diagram: the input $\boldsymbol{A}$ (shape $T \times C$) is multiplied by the parameter matrix $\boldsymbol{W}_1$ (shape $C \times 3C$) and bias$_1$ is added, producing a $T \times 3C$ output whose columns split into the $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices.]
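A minimal NumPy sketch of this projection, assuming illustrative sizes and random parameters in place of trained weights:

```python
import numpy as np

T, C = 5, 768
rng = np.random.default_rng(0)
A = rng.standard_normal((T, C))               # input activations after layernorm
W_1 = rng.standard_normal((C, 3 * C)) * 0.02  # combined query/key/value projection
bias_1 = np.zeros(3 * C)

x = A @ W_1 + bias_1                   # (T, C) -> (T, 3C)
Q, K, V_mat = np.split(x, 3, axis=-1)  # each (T, C)
print(Q.shape, K.shape, V_mat.shape)   # (5, 768) (5, 768) (5, 768)
```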

Scaled Dot Product Attention $(T, 3C) \to (T, C)$

The data is now ready to be processed by the attention mechanism. First, the $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are partitioned into $H$ submatrices each, where $H$ is the number of attention heads. The trio of submatrices $\boldsymbol{Q}_h, \boldsymbol{K}_h, \boldsymbol{V}_h$, for $h = 1 \ldots H$, is processed by the $h$-th attention head. A diagram outlining the partition and routing of submatrices is shown below.

[Diagram: each of $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ (shape $T \times C$) is partitioned column-wise into $H$ submatrices; the trio $\boldsymbol{Q}_h, \boldsymbol{K}_h, \boldsymbol{V}_h$ is routed to Attention Head $h$.]

Now we can take a closer look at what happens at the $h$-th attention head. The inputs are the trio of submatrices $\boldsymbol{Q}_h, \boldsymbol{K}_h, \boldsymbol{V}_h$. We will focus on what happens to the $t$-th token (i.e. how the $t$-th row of this attention head's output is constructed). The first step is to gather the query vector for the $t$-th token and the key vectors for the first $t$ tokens. Next, we take the dot product between the query vector and each of the key vectors, storing the results in a buffer. Then we scale the buffer by $1/\sqrt{C/H}$ to account for the dimensionality of the vectors in the dot product. Then we apply a softmax operation. The resulting array represents the “weights”, or the importance of previous tokens to the current token. We scale the corresponding value vectors by these “weights” and sum the resulting vectors. This gives the output of the $h$-th attention head for the $t$-th token. The diagram below illustrates this.

[Diagram: inside the $h$-th attention head, for $t = 1 \ldots T$: 1. take the dot product between the query vector for the $t$-th token and the key vectors for the first $t$ tokens, scaling the result by a factor of $1/\sqrt{C/H}$; 2. apply softmax to the elements, interpreting each entry of the output as a “weight” or importance of the corresponding token; 3. use the “weights” to scale the value vectors for the corresponding tokens and add them together to produce the $t$-th row of the output.]

To recap, the first diagram shows how the input $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are partitioned into submatrices and processed by the attention heads to produce an output of shape $(T, C)$. The second diagram shows how one attention head processes a trio of query, key, value submatrices. Note that the attention layer is the only layer where information is mixed across the tokens (i.e. $t = 1 \ldots T$). In other words, this layer is where tokens can “talk” to each other and feel each other’s presence in the sequence.
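Putting these steps together, here is a NumPy sketch of causal scaled dot-product attention for a single head. The sizes and random data are illustrative; a full implementation would run $H$ such heads and concatenate their outputs.

```python
import numpy as np

def attention_head(Qh, Kh, Vh):
    """Causal scaled dot-product attention for one head.
    Qh, Kh, Vh have shape (T, C/H); the output has shape (T, C/H)."""
    T, d = Qh.shape
    scores = Qh @ Kh.T / np.sqrt(d)                # (T, T) scaled dot products
    mask = np.tril(np.ones((T, T), dtype=bool))    # token t may only attend to tokens 1..t
    scores = np.where(mask, scores, -np.inf)       # block attention to future tokens
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True) # softmax: each row sums to 1
    return weights @ Vh                            # weighted sum of value vectors

rng = np.random.default_rng(0)
T, d = 6, 64
Qh, Kh, Vh = (rng.standard_normal((T, d)) for _ in range(3))
out = attention_head(Qh, Kh, Vh)
print(out.shape)                   # (6, 64)
print(np.allclose(out[0], Vh[0]))  # True: the first token attends only to itself
```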

Matmul $(T, C) \to (T, C)$

The next step is to transform the results of the attention mechanism by a linear projection (i.e. matrix multiplication plus bias). This is a simple transformation. A diagram is shown below.

[Diagram: the $T \times C$ attention output is multiplied by the parameter matrix $\boldsymbol{W}_2$ (shape $C \times C$) and bias$_2$ is added, producing a $T \times C$ output.]

FeedForward Block

This block is a standard 2-layer MLP with a GELU activation between the layers. It consists of a linear projection, a GELU operation, followed by another linear projection. Each subblock is shown below.

Matmul $(T, C) \to (T, 4C)$

The first is a matmul block which applies a simple linear transformation to the input, projecting it to a higher dimensional space: a $C$-dimensional vector is mapped to a $4C$-dimensional vector. A diagram is shown below.

[Diagram: the $T \times C$ input is multiplied by the parameter matrix $\boldsymbol{W}_3$ (shape $C \times 4C$) and bias$_3$ is added, producing a $T \times 4C$ output.]

GeLU $(T, 4C) \to (T, 4C)$

Following this matmul we have a nonlinearity that is applied element-wise to the data. This is necessary because the next block is another linear transformation. Without this nonlinearity in between, the two linear transformations in this FeedForward block would be equivalent to a single composed linear transformation. We use an approximation to the GeLU function for computational speed. The approximation applies the following nonlinear transformation:

\[\text{GeLU}(x) \approx \frac{1}{2}x\left[1+\tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^{3}\right)\right)\right]\]
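This approximation is straightforward to sketch in NumPy (the sample inputs are illustrative):

```python
import numpy as np

def gelu(x):
    """Tanh approximation of GeLU, applied element-wise."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

print(gelu(np.array([-2.0, 0.0, 2.0])))  # approximately [-0.0454, 0.0, 1.9546]
```

Note that unlike ReLU, GeLU is smooth and lets small negative values through, which is visible in the output above.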

Matmul $(T, 4C) \to (T, C)$

Following the nonlinearity is another linear transformation. This transformation projects the input back into the original latent space. In other words, it transforms a $4C$-dimensional vector into a $C$-dimensional vector. A diagram is shown below.

[Diagram: the $T \times 4C$ input is multiplied by the parameter matrix $\boldsymbol{W}_4$ (shape $4C \times C$) and bias$_4$ is added, producing a $T \times C$ output.]

Unembedding $(T, C) \to (T, V)$

The unembedding is a matrix multiplication with the transpose of the weight token embedding matrix from the encoder. This weight tying provides some implicit regularization: the matrix responsible for encoding the token sequence into the latent space is also responsible for decoding out of it. Note that there is no bias parameter in this block. A diagram is shown below.

[Diagram: the $T \times C$ input is multiplied by $\boldsymbol{W}_{\text{wte}}^T$, the transposed Weight Token Embedding Matrix (shape $C \times V$), producing a $T \times V$ output.]
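A minimal NumPy sketch of the unembedding, with small illustrative sizes in place of GPT-2's defaults:

```python
import numpy as np

# Small illustrative sizes; GPT-2's defaults are C = 768, V = 50257.
T, C, V = 3, 64, 1000
rng = np.random.default_rng(0)
wte = rng.standard_normal((V, C))  # the same token embedding matrix used by the encoder
x = rng.standard_normal((T, C))    # hidden states after the final layernorm

logits = x @ wte.T                 # (T, C) -> (T, V); note there is no bias term
print(logits.shape)  # (3, 1000)
```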

Softmax $(T, V) \to (T, V)$

The final operation is a softmax on the inputs, called logits. The softmax operation is applied row-wise: each row of the input is normalized so that its entries are non-negative and sum to 1. Pseudocode for the operation is shown below. Note that we subtract the row maximum before exponentiating, a standard numerical-stability trick that prevents overflow during computation.

SOFTMAX

INPUT: u   # a [T, V] matrix of logits
       
OUTPUT: P # [T, V] matrix where each row is a probability distribution over the vocabulary

for i in range(T):
    row = u[i, :]        # get the i-th row 
    maxval = max(row)    # get the maximum value for the i-th row

    row = row - maxval   # subtract off the maximum value from each element in the row
    row = exp(row)       # exponentiate each element of the row
    row = row / sum(row) # divide each element by the sum of all elements in the row

    P[i, :] = row      # assign the softmax'd row to the i-th row of output

We can interpret each $V$-dimensional row of the output as a probability distribution over the vocabulary. For example, let us call the output of the softmax operation $\boldsymbol{P}$, which is a matrix of shape $(T, V)$. Then $\boldsymbol{P}_{ij}$ is the probability that the $(i+1)$-th token in the token sequence has token ID equal to $j$.