Overview
Speculative decoding is a method to speed up inference from large language models like GPT. At its core, we have a strong target model $M_p$ (e.g., GPT-3) and a weaker draft model $M_q$ (e.g., GPT-2). We use the weaker model to propose new tokens and the stronger model to verify and accept them. This leads to faster inference, i.e., a greater number of tokens generated per second.
To generate a new token, models like GPT produce a probability distribution $p(x)$ over the model's vocabulary, conditioned on the tokens generated so far. We then sample a token $x \sim p(x)$ and append it to our sequence of tokens. The extended sequence is sent through the model again to produce another distribution, from which we sample yet another token. This repeated feed-back of the model's own outputs is what makes generation autoregressive.
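As a concrete illustration, here is a minimal sketch of that autoregressive loop. The model itself is a toy stand-in (`toy_model`, `VOCAB`, and the prefix tokens are all made up for the example; a real model would compute its distribution from the token sequence):

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 8  # toy vocabulary size (assumption for this sketch)

def toy_model(tokens):
    """Stand-in for a language model: returns a probability distribution
    over the vocabulary. A real model would condition on `tokens`."""
    logits = rng.normal(size=VOCAB)
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

def generate(prefix, n_new):
    """Autoregressive loop: run a forward pass, sample a token, append, repeat."""
    tokens = list(prefix)
    for _ in range(n_new):
        p = toy_model(tokens)       # forward pass -> p(x)
        x = rng.choice(VOCAB, p=p)  # sample x ~ p(x)
        tokens.append(int(x))
    return tokens

out = generate([1, 2, 3], n_new=5)
```

Each new token costs one full forward pass, which is why generating $n$ tokens requires $n$ sequential passes through the model.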
Algorithm
To sample $x \sim p(x)$ we do the following:
- sample $x \sim q(x)$
- if $\frac{p(x)}{q(x)} \geq 1$, keep it
- if $\frac{p(x)}{q(x)} < 1$ reject the sample with probability $1 - \frac{p(x)}{q(x)}$
- this can be done by sampling $r \sim U(0,1)$. If $r > \frac{p(x)}{q(x)}$, we reject, otherwise we accept.
- if we rejected the sample in the previous step, sample $x$ again from an adjusted distribution $p'(x) = \text{norm}(\max(0, p(x) - q(x)))$.
For any distributions $p(x)$ and $q(x)$, a sample $x$ drawn this way satisfies $x \sim p(x)$.
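The accept/reject steps above can be sketched directly. This is a toy example over a 3-token vocabulary (the particular vectors `p` and `q` are made up for illustration); sampling many times shows the empirical frequencies matching $p(x)$, not $q(x)$:

```python
import numpy as np

rng = np.random.default_rng(0)

def speculative_sample(p, q):
    """Sample x ~ p(x) using a proposal from q(x), following the
    accept/reject rule above. p and q are probability vectors."""
    x = rng.choice(len(q), p=q)        # sample x ~ q(x)
    r = rng.uniform()                  # r ~ U(0, 1)
    if r <= p[x] / q[x]:               # accept with probability min(1, p(x)/q(x))
        return int(x)
    residual = np.maximum(0.0, p - q)  # rejected: use adjusted distribution
    residual /= residual.sum()         # p'(x) = norm(max(0, p(x) - q(x)))
    return int(rng.choice(len(p), p=residual))

p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.5, 0.3])
samples = [speculative_sample(p, q) for _ in range(20000)]
freq = np.bincount(samples, minlength=3) / len(samples)
# freq is close to p, even though every proposal came from q
```

Note that when $p(x) \geq q(x)$ the ratio is at least 1, so the sample is always kept, matching the first rule above.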
One may ask: how does this speed up our generation process? It seems we need both $p(x)$ and $q(x)$ to get our next token, which should actually increase our generation time, since we now need a forward pass through both $M_q$ and $M_p$. This analysis is correct: generating one token at a time with this algorithm actually slows down generation. However, what if we generate the next $\gamma + 1$ tokens at once? This is what Algorithm 1 from [1] describes. Here is how to do it:
- We can use $M_q$ to autoregressively generate the next $\gamma$ guesses. So, in $\gamma$ forward passes through $M_q$, we get the distributions $q_1(x), \ldots, q_{\gamma}(x)$:
- for $ i = 1 $ to $ \gamma $ do
$ q_i(x) \leftarrow M_q(\mathit{prefix} + [x_1, \ldots, x_{i-1}]) $
$ x_i \sim q_i(x) $
end for
- After this step, we have all the subsequences that were generated by $M_q$ in addition to the prefix itself:
- $prefix$
- $prefix + [x_1]$
- $prefix + [x_1, x_2]$
- $\quad \quad \quad \vdots$
- $prefix + [x_1, x_2, \ldots, x_{\gamma}]$
- Now, we batch the subsequences and pass them through $M_p$ in a single forward pass. So, in a single pass, we get the distributions $p_1(x), \ldots, p_{\gamma + 1}(x)$.
- Armed with $q_1(x), \ldots, q_{\gamma}(x)$ and $p_1(x), \ldots, p_{\gamma + 1}(x)$, we can use the algorithm steps outlined above to do the accept/reject process:
- Sample $\gamma $ uniform random variables $r_1 \sim U(0, 1), \ldots, r_{\gamma} \sim U(0, 1)$.
- Take the first $n$ guesses that are accepted. $n$ can be computed by finding the first rejected guess and subtracting one (or taking $\gamma$ if every guess is accepted): $n = \min\left(\left\{\, i - 1 \;\middle|\; 1 \leq i \leq \gamma,\ r_i > \frac{p_i(x)}{q_i(x)} \,\right\} \cup \{\gamma\}\right)$, where each ratio is evaluated at the drafted token $x_i$.
- If all guesses are accepted ($n = \gamma$), we can sample the last token $x_{\gamma + 1} \sim p_{\gamma + 1}(x)$, since we already computed $p_{\gamma + 1}(x)$ in the forward pass through $M_p$.
- Otherwise ($n < \gamma$), the first rejected guess is $x_{n+1}$. We resample the last token from the adjusted distribution: $x_{n+1} \sim \text{norm}(\max(0, p_{n+1}(x) - q_{n+1}(x)))$.
- The returned sequence is
- $prefix + [x_1, \ldots, x_{n}, x_{n+1}]$
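Putting the pieces together, one round of the algorithm can be sketched as below. Both models are toy stand-ins (`draft_model`, `target_model_batch`, `VOCAB`, and the prefix are assumptions for the example); in practice $M_q$ and $M_p$ would be real networks and the batched call would be a single forward pass through $M_p$:

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 8  # toy vocabulary size (assumption for this sketch)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def draft_model(tokens):
    """Stand-in for the small model M_q (a real model conditions on `tokens`)."""
    return softmax(rng.normal(size=VOCAB))

def target_model_batch(seqs):
    """Stand-in for one batched forward pass through the large model M_p:
    returns one distribution per input subsequence."""
    return [softmax(rng.normal(size=VOCAB)) for _ in seqs]

def speculative_step(prefix, gamma):
    """One round of the algorithm: draft gamma tokens with M_q,
    verify them with M_p, and return the extended prefix."""
    # 1. Draft gamma tokens autoregressively with M_q.
    qs, xs = [], []
    tokens = list(prefix)
    for _ in range(gamma):
        q = draft_model(tokens)
        x = int(rng.choice(VOCAB, p=q))
        qs.append(q); xs.append(x); tokens.append(x)
    # 2. One batched pass through M_p over all gamma + 1 subsequences.
    seqs = [prefix + xs[:i] for i in range(gamma + 1)]
    ps = target_model_batch(seqs)
    # 3. Accept/reject the drafted tokens left to right.
    n = gamma
    for i in range(gamma):
        if rng.uniform() > ps[i][xs[i]] / qs[i][xs[i]]:
            n = i  # first rejection: keep only the first n guesses
            break
    # 4. Sample one extra token: from p_{gamma+1} if all were accepted,
    #    otherwise from the adjusted distribution norm(max(0, p - q)).
    if n == gamma:
        last = int(rng.choice(VOCAB, p=ps[gamma]))
    else:
        residual = np.maximum(0.0, ps[n] - qs[n])
        residual /= residual.sum()
        last = int(rng.choice(VOCAB, p=residual))
    return prefix + xs[:n] + [last]

out = speculative_step([1, 2, 3], gamma=4)
```

Each round therefore produces between 1 and $\gamma + 1$ new tokens for the cost of one forward pass through $M_p$ plus $\gamma$ cheap passes through $M_q$, which is where the speedup comes from.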
References
[1] Leviathan, Yaniv, Matan Kalman, and Yossi Matias. “Fast inference from transformers via speculative decoding.” International Conference on Machine Learning. PMLR, 2023.