Running frontier LLMs is slow, but that's the tradeoff we make to get more intelligent output. But if you think about the text (tokens) that an LLM produces (or that a human produces, for that matter), you might have an intuition that a lot of it does not actually require high intelligence. There's a lot of "filler" in language to make things work syntactically that is easy to predict compared to, for example, the crux of an argument. This is evidenced by Shannon's old experiments in which he had subjects predict the next letter in a text and they got it right on the first try 75% of the time. The actual information content of most text is concentrated in a few key words. This property is doubly true of programming languages, where syntax makes up a huge portion of the content.
So a natural question one might ask is -- can we make LLMs faster at producing "easy" tokens than "hard" ones? This thinking is the motivation behind speculative decoding, a technique for speeding up LLM inference. It's a remarkably simple idea that requires just a bit of math to make it work.
Suppose we have a model $M_p$ that produces a distribution $p(x)$ from which we sample. Given some input context $X_c$, the process of generating tokens from $M_p$ looks like:

$$x_t \sim p(x_t \mid X_c, x_1, \ldots, x_{t-1})$$
This is the standard autoregressive model sampling process where each token depends on the previous tokens. It's also one of the fundamental reasons why LLMs are hard to speed up -- there is an inherent sequentiality to how they produce tokens. But what if we can quickly "guess" the tokens that the LLM is going to generate with high probability, the same way that people did in Shannon's experiment?
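For concreteness, this sequential process can be sketched as follows, where `model` is a hypothetical callable that returns a next-token probability distribution given the tokens so far:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_autoregressive(model, context, n_tokens):
    """Standard sequential decoding: one full forward pass per token."""
    tokens = list(context)
    for _ in range(n_tokens):
        p = model(tokens)                       # forward pass over the whole prefix
        tokens.append(int(rng.choice(len(p), p=p)))  # sample x_t ~ p(. | X_c, x_<t)
    return tokens
```

Each iteration must wait for the previous one to finish -- this is the sequentiality that speculative decoding attacks.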
| Speculative decoding example. Green = speculated, red = rejected, blue = corrected. Source: Leviathan et al. |
Speculative decoding involves using a faster, independent process $M_q$ to "speculate" potential future tokens of $M_p$. Traditionally, $M_q$ is another, smaller LLM, but there are multiple options (more on this below) -- the important thing is that $M_q$ produces a distribution $q(x)$ from which we can sample tokens. Incorporating speculation into the decoding process is straightforward: have $M_q$ generate $N$ (small, e.g. 4-8) tokens at each step, run $M_p$ on $N+1$ tokens simultaneously (which is possible because we now have actual tokens to condition on), and then run a "correction process" to ensure that the output distribution matches. Assuming $M_q$ is much faster than $M_p$, we get an overall speedup because $M_p$ can decode all $N+1$ of those tokens in approximately the same amount of time that it takes to decode a single one. This is a consequence of the observation that LLM decoding is usually limited by memory bandwidth, i.e. loading all of the weights from HBM to perform the forward pass.
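One round of this procedure can be sketched as below -- a minimal version assuming `target` and `draft` are hypothetical callables mapping a token list to a next-token distribution, with the correction process applied exactly as described in the rest of this post:

```python
import numpy as np

rng = np.random.default_rng(0)

def speculative_step(target, draft, tokens, n_spec=4):
    """One round of speculative decoding; returns the tokens accepted
    this round (always at least one)."""
    # 1. Draft n_spec tokens cheaply and sequentially with the small model.
    spec, q_dists = [], []
    for _ in range(n_spec):
        q = draft(tokens + spec)
        q_dists.append(q)
        spec.append(int(rng.choice(len(q), p=q)))

    # 2. Score all n_spec + 1 positions with the target model. A real
    #    implementation does this in one batched forward pass; the loop
    #    here is only for clarity.
    p_dists = [target(tokens + spec[:i]) for i in range(n_spec + 1)]

    # 3. Accept each drafted token x with probability min(1, p(x)/q(x)).
    accepted = []
    for i, x in enumerate(spec):
        p, q = p_dists[i], q_dists[i]
        if rng.random() < min(1.0, p[x] / q[x]):
            accepted.append(x)
        else:
            # Rejected: resample from norm(max(0, p - q)) and discard
            # the remaining drafts, which depended on the rejected token.
            residual = np.maximum(p - q, 0.0)
            residual /= residual.sum()
            accepted.append(int(rng.choice(len(residual), p=residual)))
            return accepted

    # 4. All drafts accepted: the extra target distribution yields a
    #    free bonus token.
    accepted.append(int(rng.choice(len(p_dists[-1]), p=p_dists[-1])))
    return accepted
```

Note that even a full rejection still produces one valid token (the resample), so $M_p$ never stalls.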
Let's first prove correctness, specifically that we can do this without changing the output token distribution $p(x)$. Otherwise, you end up sacrificing intelligence for performance, which is specifically the tradeoff we want to avoid. The key lies in the correction process mentioned above.
For each token $x$ that we sample from $M_q$, we look at the probability distributions of both models $p(x)$ and $q(x)$ (note that we have $p(x)$ available because we still ran $M_p$). If $q(x) \le p(x)$, then we keep the token. Otherwise, we reject the token with probability $1 - p(x) / q(x)$, which corresponds to how much more likely $M_q$ was to choose the token than $M_p$, i.e. the blue excess over the red in the diagram. When we reject a token, we then re-sample from a modified distribution that spreads the excess probability to the other tokens that have deflated probabilities in $q(x)$ vs $p(x)$.
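A tiny numeric example makes the mechanics concrete. Assume a hypothetical 3-token vocabulary where the draft model over-weights token 0:

```python
import numpy as np

# Hypothetical toy distributions over a 3-token vocabulary.
p = np.array([0.5, 0.3, 0.2])  # target model M_p
q = np.array([0.8, 0.1, 0.1])  # draft model M_q

x = 0  # suppose M_q sampled token 0, which it over-weights (q(x) > p(x))

accept_prob = min(1.0, p[x] / q[x])  # keep with probability 0.5/0.8 = 0.625
reject_prob = 1.0 - accept_prob      # reject with probability 0.375

# On rejection, resample from the normalized excess of p over q.
residual = np.maximum(p - q, 0.0)
residual /= residual.sum()           # ≈ [0, 2/3, 1/3]
```

The excess probability that $M_q$ stole from tokens 1 and 2 is exactly where the resample lands.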
| Token probability distributions and the correction process. The excess blue probability of token 1 gets spread out to the other tokens that have excess red probability. |
More precisely, the resampling distribution is $p'(x) = \text{norm}(\max(0, p(x) - q(x)))$. Once we reject a token, all remaining tokens that we sampled from $M_q$ are invalid, since they were conditioned on the rejected token. This process guarantees that, no matter what distribution $M_q$ produces, the resulting output distribution matches exactly what $M_p$ would have produced (in the degenerate case, we end up rejecting everything and only ever have $M_p$ generate tokens directly).
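To see why the output distribution is unchanged, note that a token $x$ can end up in the output via two paths: it was drafted and accepted, or it was resampled after a rejection. Summing the two:

$$
\begin{aligned}
P(x) &= \underbrace{q(x)\min\!\left(1, \frac{p(x)}{q(x)}\right)}_{\text{drafted and accepted}} \;+\; \underbrace{\Big(1 - \sum_{x'} \min(p(x'), q(x'))\Big)\, p'(x)}_{\text{rejected, then resampled}} \\
&= \min(p(x), q(x)) + \max(0,\, p(x) - q(x)) \\
&= p(x),
\end{aligned}
$$

where the second line uses the identity $\sum_{x'} \max(0, p(x') - q(x')) = 1 - \sum_{x'} \min(p(x'), q(x'))$, so the total rejection probability exactly cancels the normalization constant in $p'(x)$.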
The remaining question, then, is how to pick a good $M_q$. Based on the correction process, we see that the quality of $M_q$'s approximation of $M_p$ is the defining factor in how well it speculates. Leviathan et al. analyze this formally, and it turns out that the above process produces an expected number of valid tokens $(1 - \alpha^{N+1}) / (1 - \alpha)$ per step, where $\alpha = E(\min(p, q))$ is the expected overlap between the distributions.
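Plugging in illustrative numbers (the $\alpha = 0.8$ below is an assumption for the sake of example, not a measured value) shows why the overlap matters:

```python
alpha = 0.8  # assumed expected overlap between p and q
N = 4        # speculated tokens per step

# Expected number of valid tokens produced per target-model step.
expected_tokens = (1 - alpha ** (N + 1)) / (1 - alpha)  # ≈ 3.36
```

Each pass of $M_p$ now yields roughly 3.4 tokens instead of 1, while a poor match like $\alpha = 0.3$ would yield barely 1.4.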
In practice, there are two common choices for $M_q$. The simplest is a smaller model from the same family, e.g. Llama-3.1-8B speculating for Llama-3.1-405B. This is effective because models in the same family tend to be trained on the same data and share the biases introduced by their architecture, leading to more overlap in their output distributions. Alternatively, the approach of the Medusa paper is to fine-tune special "decoding heads" on top of the LLM itself that are specifically trained to speculate and can handle multiple branching paths. This is more efficient but requires training, so it can't be applied off the shelf. Empirically, the improvement from speculative decoding seems to end up in the 2-3x range on real data.
| Medusa heads attached to a frozen base LLM and fine-tuned for speculation. Source: Cai et al. |
The obvious downside to speculative decoding is increased computation -- we generate additional tokens from $M_q$, and $M_p$'s total compute does not decrease (the wall-clock win comes from processing the speculated positions in parallel, and any rejected tokens are wasted work). In the degenerate case where everything is rejected, we wastefully generate $N$ tokens from $M_q$ for every single token of $M_p$. Given this, speculative decoding is not a good fit for an inference environment that is compute-bound, e.g. serving large batches, which already amortizes the cost of loading model weights across many generations. Instead, it's well-suited to environments that have spare resources and want lower latency, or as a way to soak up excess compute during periods of low traffic. As LLMs become more prevalent in real-time tasks like live translation, interactive coding (e.g. Cursor's speculative edits), and conversational agents, the demand for faster inference will continue to increase.