In this note, we follow the path of input data through the GPT-2 architecture to better understand how the transformer works. We begin with an overview of the model.
Overview
The pseudocode below outlines how GPT-2 processes data in a single forward pass.
FORWARD PASS
def forward(tokens):
    """
    T    : current sequence length
    C    : embedding dimension (default 768)
    L    : number of transformer blocks (default 12)
    V    : vocabulary size (default 50257)
    maxT : max sequence length (default 1024)
    """
    x = encoder(tokens)                             # (1, T) -> (T, C)
    residual = x                                    # store residual
    for i in range(L):
        x = layernorm(x, scale_1[i], shift_1[i])    # (T, C) -> (T, C)
        ##### Attention Block #################################
        x = matmul(x, W_1[i], bias_1[i])            # (T, C) -> (T, 3C)
        x = MHAttention(x)                          # (T, 3C) -> (T, C)
        x = matmul(x, W_2[i], bias_2[i])            # (T, C) -> (T, C)
        #######################################################
        x = x + residual                            # add residual
        residual = x                                # store new residual
        x = layernorm(x, scale_2[i], shift_2[i])    # (T, C) -> (T, C)
        #### FeedForward Block ################################
        x = matmul(x, W_3[i], bias_3[i])            # (T, C) -> (T, 4C)
        x = GELU(x)                                 # (T, 4C) -> (T, 4C)
        x = matmul(x, W_4[i], bias_4[i])            # (T, 4C) -> (T, C)
        #######################################################
        x = x + residual                            # add residual
        residual = x                                # store new residual
    x = layernorm(x, scale_f, shift_f)              # (T, C) -> (T, C)
    x = matmul(x, wte.T)                            # (T, C) -> (T, V)
    x = softmax(x)                                  # (T, V) -> (T, V)
    return x
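As a sanity check on the shapes annotated in the pseudocode, we can tally the number of learnable parameters they imply. The sketch below is illustrative Python rather than part of the model code; the helper name `count_parameters` is hypothetical, and it assumes the final matmul reuses `wte` and therefore adds no parameters of its own.

```python
def count_parameters(C=768, L=12, V=50257, maxT=1024):
    """Tally learnable parameters implied by the shapes in the forward pass."""
    wte = V * C                                   # token embedding, reused by the final matmul
    wpe = maxT * C                                # positional embedding
    layernorm = 2 * C                             # C scale + C shift parameters
    attn = (C * 3 * C + 3 * C) + (C * C + C)      # W_1, bias_1, W_2, bias_2
    mlp = (C * 4 * C + 4 * C) + (4 * C * C + C)   # W_3, bias_3, W_4, bias_4
    block = 2 * layernorm + attn + mlp            # one transformer block
    return wte + wpe + L * block + layernorm      # plus the final layernorm

total = count_parameters()
print(f"{total:,}")  # 124,439,808
```

With the default sizes from the docstring this gives roughly 124 million parameters, most of them in the token embedding matrix and in the two feedforward matmuls of each block.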
Here is the full architecture as a diagram:
Forward Pass
Here we will outline some of the main architectural blocks involved in the forward pass shown above. We begin with a length $T$ array of integers, called tokens, which represent a string of text. As a reminder, $T$ is the sequence length.
Encoder $(1, T) \to (T, C)$
The encoder has two sets of parameters: the weight token embedding matrix (wte), which has shape $(V, C)$, and the weight positional embedding matrix (wpe), which has shape $(\text{max}T, C)$.
For each position $j$ in the sequence, if the token at that position has token ID $i$, the embedded vector is the sum of the $i$-th row of the weight token embedding matrix and the $j$-th row of the positional embedding matrix. Pseudocode is presented below.
ENCODER
INPUT:  u    # a [1, T] vector of tokens
        wte  # weight token embedding matrix
        wpe  # weight positional embedding matrix
OUTPUT: out  # [T, C] matrix of embedded tokens
for i, token in enumerate(u):
    e = wte[token, :] + wpe[i, :]  # embedded vector for token at position i
    out[i, :] = e                  # store result in the output
For example, if the first token in the input sequence is 31373, then we would extract the 31373-th row from the weight token embedding matrix and the first row of the positional embedding matrix and add them together, placing the resulting row vector in the first row of the output. A diagram of the process is shown below.
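The lookup-and-add step can also be written as a small runnable sketch. This is plain Python on nested lists with toy sizes chosen for illustration ($V = 4$, $C = 3$, $\text{max}T = 2$); the function name `encode` and the toy matrices are assumptions, not the real GPT-2 weights.

```python
def encode(tokens, wte, wpe):
    """Return a T x C list of rows, where row t is wte[tokens[t]] + wpe[t]."""
    out = []
    for position, token in enumerate(tokens):
        row = [a + b for a, b in zip(wte[token], wpe[position])]
        out.append(row)
    return out

# Toy parameters (illustrative values only).
wte = [[0.0, 0.0, 0.0],
       [1.0, 1.0, 1.0],
       [2.0, 2.0, 2.0],
       [3.0, 3.0, 3.0]]
wpe = [[0.5, 0.5, 0.5],
       [0.25, 0.25, 0.25]]

embedded = encode([2, 0], wte, wpe)
print(embedded)  # [[2.5, 2.5, 2.5], [0.25, 0.25, 0.25]]
```

The first output row combines row 2 of the token matrix with row 0 of the positional matrix, mirroring the 31373 example above on a much smaller vocabulary.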
LayerNorm $(T, C) \to (T, C)$
The layernorm operation defines a transformation that first normalizes each row in the input tensor to zero mean and unit variance and then applies a scale and shift operation to each element. Since there are $C$ elements in a row, there are $C$ scale parameters and $C$ shift parameters to learn.
The pseudocode below shows how to apply the layernorm operation to a row of the input.
LAYERNORM
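INPUTS : row, scale, shift   # each of length C
OUTPUTS: output              # normalized, scaled, and shifted row
m = mean(row)                  # calculate mean of row
std = sqrt(var(row) + eps)     # calculate std of row; a small eps (e.g. 1e-5) avoids division by zero
normalized = (row - m) / std   # normalize to zero mean and unit variance
for i in range(length(row)):
    output[i] = normalized[i] * scale[i] + shift[i]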
Below is a diagram showing how a single row from the tensor of shape $(T, C)$ is transformed by the layernorm. Note that this operation is done for all the rows in the tensor.
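For readers who prefer runnable code, the row-wise operation can be sketched in plain Python as follows. The function name `layernorm_row` is hypothetical, and the small constant `eps` added to the variance is an assumption borrowed from common implementations to avoid dividing by zero on constant rows.

```python
import math

def layernorm_row(row, scale, shift, eps=1e-5):
    """Normalize one row to zero mean and unit variance, then scale and shift."""
    n = len(row)
    mean = sum(row) / n
    var = sum((v - mean) ** 2 for v in row) / n   # population (biased) variance
    inv_std = 1.0 / math.sqrt(var + eps)          # eps guards against var == 0
    return [(v - mean) * inv_std * g + b
            for v, g, b in zip(row, scale, shift)]

row = [1.0, 2.0, 3.0, 4.0]
out = layernorm_row(row, scale=[1.0] * 4, shift=[0.0] * 4)
```

With unit scale and zero shift, the output has mean zero and variance very close to one; applying the function to every row of a $(T, C)$ tensor gives the full layernorm.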
Multi Head Attention Block
The following block is most commonly referred to as the multi-head attention block. It is split into three subblocks. The first is a linear projection that produces the queries, keys, and values. The second is where the actual attention mechanism is applied. The third is another linear projection back into the latent space. We go through each of these subblocks below.
Matmul $(T, C) \to (T, 3C)$
Following the layernorm is a matrix multiplication operation that creates query, key, and value vectors for each token in the sequence. The following diagram shows how the $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are created for an input $\boldsymbol{A}$ of shape $(T, C)$.
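A minimal sketch of this projection is given below, using plain Python lists with toy sizes ($T = 2$, $C = 2$). The helper names `linear` and `project_qkv` are hypothetical, and the column ordering (queries first, then keys, then values) is an assumption about how the $(T, 3C)$ result is laid out.

```python
def linear(x, W, b):
    """Compute x @ W + b for x: T x C_in, W: C_in x C_out, b: length C_out."""
    cols = list(zip(*W))  # columns of W
    return [[sum(a * w for a, w in zip(row, col)) + bj
             for col, bj in zip(cols, b)]
            for row in x]

def project_qkv(x, W, b):
    """Project (T, C) -> (T, 3C), then split the columns into Q, K, V."""
    C = len(x[0])
    qkv = linear(x, W, b)
    Q = [row[:C] for row in qkv]
    K = [row[C:2 * C] for row in qkv]
    V = [row[2 * C:] for row in qkv]
    return Q, K, V

x = [[1.0, 2.0],
     [3.0, 4.0]]
W = [[1.0, 0.0, 2.0, 0.0, 3.0, 0.0],   # shape (C, 3C) = (2, 6)
     [0.0, 1.0, 0.0, 2.0, 0.0, 3.0]]
b = [0.0] * 6
Q, K, V = project_qkv(x, W, b)
print(Q)  # [[1.0, 2.0], [3.0, 4.0]]
print(K)  # [[2.0, 4.0], [6.0, 8.0]]
print(V)  # [[3.0, 6.0], [9.0, 12.0]]
```

Each row of the input is transformed independently, so no information moves between tokens at this stage.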
Scaled Dot Product Attention $(T, 3C) \to (T, C)$
The data is now ready to be processed by the attention mechanism. First, the $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are each partitioned column-wise into $H$ submatrices of shape $(T, C/H)$, where $H$ is the number of attention heads. The trio of submatrices $\boldsymbol{Q}_h, \boldsymbol{K}_h, \boldsymbol{V}_h$, for $h = 1 \ldots H$, is processed by the $h$-th attention head. A diagram outlining the partition and routing of submatrices is shown below.
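The partition step can be sketched in a few lines of plain Python. The function name `split_heads` is hypothetical, and the sketch assumes that each head receives a contiguous block of $C/H$ columns.

```python
def split_heads(M, H):
    """Split a T x C matrix column-wise into H submatrices of shape T x (C // H)."""
    C = len(M[0])
    assert C % H == 0, "C must be divisible by the number of heads"
    d = C // H  # width of each head
    return [[row[h * d:(h + 1) * d] for row in M] for h in range(H)]

Q = [[1, 2, 3, 4],
     [5, 6, 7, 8]]   # T = 2, C = 4
heads = split_heads(Q, H=2)
print(heads[0])  # [[1, 2], [5, 6]]
print(heads[1])  # [[3, 4], [7, 8]]
```

The same split is applied to the key and value matrices, so that head $h$ sees matching slices of all three.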
Now we can take a closer look at what happens at the $h$-th attention head. The inputs are the trio of submatrices $\boldsymbol{Q}_h, \boldsymbol{K}_h, \boldsymbol{V}_h$. We will focus on what happens to the $t$-th token (i.e. how the $t$-th row of this attention head’s output is constructed). The first step is to gather the query vector for the $t$-th token and the key vectors for the first $t$ tokens. Next, we take the dot product between the query vector and each of the key vectors, storing the results in a buffer. Then we divide the buffer by $\sqrt{C/H}$, where $C/H$ is the length of the query and key vectors, so that the magnitude of the dot products does not grow with the vector length. Then we apply a softmax operation. The resulting array holds the “weights”, i.e. the importance of each of the first $t$ tokens to the current token. We scale the corresponding value vectors by these “weights” and sum the resulting vectors. This gives the output of the $h$-th attention head for the $t$-th token. The diagram below illustrates this.
To recap, the first diagram shows how the input $\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}$ matrices are partitioned into submatrices and processed by the attention heads to produce an output of shape $(T, C)$. The second diagram shows how one attention head processes a trio of query, key, value submatrices. This diagram shows that the attention layer is the only layer where information is mixed across the tokens (i.e. $t = 1 \ldots T$). In other words, this layer is where tokens can “talk” to each other and feel each other’s presence in the sequence.
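The per-token procedure for a single head can be sketched as follows. This is a minimal, unoptimized illustration on plain Python lists; the function name `attention_head` is hypothetical, and the sketch assumes the scores are scaled by one over the square root of the head width before the softmax.

```python
import math

def attention_head(Q, K, V):
    """Causal scaled dot-product attention for one head; inputs are T x d lists."""
    d = len(Q[0])
    out = []
    for t in range(len(Q)):
        # Dot products of token t's query with the keys of tokens 0..t, scaled.
        scores = [sum(q * k for q, k in zip(Q[t], K[s])) / math.sqrt(d)
                  for s in range(t + 1)]
        # Softmax over the scores (maximum subtracted for numerical stability).
        top = max(scores)
        exps = [math.exp(score - top) for score in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors of tokens 0..t.
        out.append([sum(w * V[s][j] for s, w in enumerate(weights))
                    for j in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = attention_head(Q, K, V)
print(out[0])  # [1.0, 2.0]
```

The first token can only attend to itself, so its output is exactly its own value vector; the second token's output is a weighted average of both value vectors. Concatenating the outputs of all $H$ heads along the column dimension recovers a matrix of shape $(T, C)$.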
Matmul $(T, C) \to (T, C)$
The next step is to transform the results of the attention mechanism by a linear projection (i.e. a matrix multiplication plus a bias). A diagram is shown below.
FeedForward Block
This block is a standard 2-layer MLP with a GELU activation in between the layers. It consists of a linear projection, a GELU operation, and another linear projection. Each subblock is shown below.
Matmul $(T, C) \to (T, 4C)$
The first is a matmul block which applies a simple linear transformation to the input, projecting it to a higher dimensional space: a $C$-dimensional vector is mapped to a $4C$-dimensional vector. A diagram is shown below.
GeLU $(T, 4C) \to (T, 4C)$
Following this matmul we have a nonlinearity that is applied element-wise to the data. This is necessary because the next block is another linear transformation. Without this nonlinearity in between, the two linear transformations in this FeedForward block would be equivalent to a single composed linear transformation. We use an approximation to the GeLU function for computational speed. The approximation applies the following nonlinear transformation:
\[\text{GeLU}(x) \approx \frac{1}{2}x\left[1+\tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^{3}\right)\right)\right]\]
Matmul $(T, 4C) \to (T, C)$
Following the nonlinearity is another linear transformation. This transformation projects the input back into the original latent space. In other words, it transforms a $4C$-dimensional vector into a $C$-dimensional vector. A diagram is shown below.
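Putting the three subblocks of the FeedForward block together gives the sketch below, written in plain Python with a toy width of $C = 1$ (so the hidden width is $4C = 4$). The helper names `gelu`, `linear`, and `feed_forward` and the toy weights are assumptions made for illustration.

```python
import math

def gelu(x):
    """Tanh approximation of GELU for a single number."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def linear(x, W, b):
    """Compute x @ W + b for x: T x C_in, W: C_in x C_out, b: length C_out."""
    cols = list(zip(*W))  # columns of W
    return [[sum(a * w for a, w in zip(row, col)) + bj
             for col, bj in zip(cols, b)]
            for row in x]

def feed_forward(x, W3, b3, W4, b4):
    """(T, C) -> (T, 4C) -> element-wise GELU -> (T, C)."""
    hidden = linear(x, W3, b3)
    hidden = [[gelu(v) for v in row] for row in hidden]
    return linear(hidden, W4, b4)

x = [[1.0], [-1.0]]                   # T = 2, C = 1
W3 = [[1.0, 2.0, -1.0, 0.0]]          # shape (C, 4C)
b3 = [0.0, 0.0, 0.0, 0.0]
W4 = [[1.0], [0.0], [0.0], [0.0]]     # shape (4C, C)
b4 = [0.0]
y = feed_forward(x, W3, b3, W4, b4)
```

With these weights the block returns roughly 0.841 for the input 1 and roughly -0.159 for the input -1. The two outputs are not negatives of each other, which is the effect of the nonlinearity: a purely linear block with zero biases would map $-x$ to $-y$.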
Unembedding $(T, C) \to (T, V)$
The unembedding is a matrix multiplication with the weight token embedding matrix from the encoder. This is done to incorporate some implicit regularization: the matrix responsible for encoding the token sequence into the latent space is also responsible for the decoding. Note, there is no bias parameter in this block. A diagram is shown below.
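In code, reusing the embedding matrix amounts to taking the dot product of each row of the input with each row of the weight token embedding matrix. The sketch below uses plain Python with toy sizes ($V = 3$, $C = 2$, $T = 1$); the function name `unembed` and the toy values are assumptions for illustration.

```python
def unembed(x, wte):
    """Compute logits = x @ wte.T for x: T x C and wte: V x C (no bias)."""
    return [[sum(a * w for a, w in zip(row, emb)) for emb in wte]
            for row in x]

wte = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]
x = [[2.0, 3.0]]
logits = unembed(x, wte)
print(logits)  # [[2.0, 3.0, 5.0]]
```

Each logit measures how well the final latent vector for a position aligns with the embedding of a particular token.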
Softmax $(T, V) \to (T, V)$
The final operation is a softmax on the inputs, called logits. The softmax operation is done row-wise: each row of the input is exponentiated and normalized so that its entries are positive and sum to 1. Pseudocode for the operation is shown below. Note that we use a numerical stabilization trick in the code to prevent overflow errors during computation: subtracting the row maximum before exponentiating leaves the result unchanged, because the common factor cancels between the numerator and the denominator, while keeping every exponent at or below zero.
INPUT:  u  # a [T, V] matrix of logits
OUTPUT: P  # [T, V] matrix where each row is a probability distribution over the vocabulary
for i in range(T):
    row = u[i, :]         # get the i-th row
    maxval = max(row)     # get the maximum value for the i-th row
    row = row - maxval    # subtract off the maximum value from each element in the row
    row = exp(row)        # exponentiate each element of the row
    row = row / sum(row)  # divide each element by the sum of all elements in the row
    P[i, :] = row         # assign the softmax'd row to the i-th row of output
We can interpret each $V$-dimensional row of the output as a probability distribution over the vocabulary. For example, let us call the output of the softmax operation $\boldsymbol{P}$, which is a matrix of shape $(T, V)$. Then $\boldsymbol{P}_{ij}$ is the probability, according to the model, that the $(i+1)$-th token in the token sequence has token ID equal to $j$.
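To make this concrete, here is a small runnable sketch of the row-wise softmax together with one simple way of using its last row. The function name `softmax_row` is hypothetical, and picking the most probable token (greedy selection) is only one possible choice, shown here as an assumption rather than something the model prescribes.

```python
import math

def softmax_row(row):
    """Numerically stable softmax of one row of logits."""
    top = max(row)
    exps = [math.exp(v - top) for v in row]   # all exponents are <= 0
    total = sum(exps)
    return [e / total for e in exps]

logits = [[1.0, 2.0, 3.0],
          [1000.0, 1001.0, 1002.0]]   # math.exp(1000.0) alone would overflow
P = [softmax_row(row) for row in logits]

# The last row is the distribution over the token that follows the sequence.
last = P[-1]
next_token = max(range(len(last)), key=last.__getitem__)
print(next_token)  # 2
```

Both rows produce the same probabilities because they differ only by a constant shift, which is exactly the property the stabilization trick relies on.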