Fast Inference from Transformers via Speculative Decoding

Authors: Yaniv Leviathan, Matan Kalman, Yossi Matias Year: 2023 Source: arXiv 2211.17192

One-Sentence Summary

A small, fast “draft” model guesses what a large, slow model would say next, and the large model checks many guesses at once in parallel, producing identical outputs 2-3x faster by doing several steps of work in a single pass.

Problem Statement

Large language models based on the Transformer architecture (see Attention Is All You Need) generate text one token at a time. Each token – a word or word-piece – requires a full forward pass through the entire model. For a model like T5-XXL with 11 billion parameters, or LaMDA with 137 billion parameters, each forward pass is expensive. Generating a sentence of 50 tokens means 50 sequential forward passes, and there is no way around this serial dependency: token 5 depends on token 4, which depends on token 3, and so on.

What makes this especially frustrating is that the bottleneck is often not the arithmetic itself. Modern accelerators like TPUs and GPUs have enormous computational throughput, but they frequently sit idle waiting for data to move between memory and processing units. This is the “memory bandwidth bottleneck”: the model’s weights must be read from memory for every single token, and those reads are slow relative to the actual math. The hardware has spare compute capacity, but the sequential nature of autoregressive decoding (generating one token at a time, left to right) prevents using it.

Prior approaches to speed up inference fell into two categories. Some reduced the cost uniformly for all tokens through techniques like distillation (training a smaller model to mimic the large one), quantization (using lower-precision arithmetic), or architecture changes. Others used adaptive computation, where easier tokens get less processing. Both categories typically required retraining the model, modifying its architecture, or accepting different outputs. None offered a way to accelerate an existing, deployed model while guaranteeing that its output distribution stays exactly the same.

Key Innovation

Imagine you have a slow but meticulous editor who reviews every sentence of a document. You also have a fast but less reliable assistant who can draft text quickly. Instead of waiting for the editor to write one word at a time, you have the assistant draft several words ahead. The editor then reviews the entire draft in one pass, accepting the words the assistant got right and correcting the first mistake. This is faster because the editor’s single review covers multiple words at once, and the accepted words skip what would have been expensive sequential steps.

This is speculative decoding. The “assistant” is a small, cheap approximation model \(M_q\) (for example, a 77-million-parameter T5-small). The “editor” is the large target model \(M_p\) (for example, the 11-billion-parameter T5-XXL). The approximation model generates \(\gamma\) draft tokens autoregressively, one at a time; because \(M_q\) is small, this is cheap. Then the target model evaluates all \(\gamma\) draft positions in a single forward pass – this is the key: the target model reads its weights from memory once and processes all positions simultaneously, converting a memory-bandwidth-limited operation into a compute-limited one.

The paper’s deepest contribution is speculative sampling, a novel acceptance/rejection scheme that guarantees the output distribution is identical to sampling from the target model alone. This is not an approximation – the mathematical guarantee is exact. If the draft token’s probability under the small model \(q(x)\) does not exceed its probability under the large model \(p(x)\), the token is always accepted. If \(q(x) > p(x)\), the token is rejected with probability \(1 - p(x)/q(x)\), and a corrected token is drawn from a residual distribution \(p'(x) = \text{norm}(\max(0, p(x) - q(x)))\). This scheme is strictly better than classical rejection sampling because rejected drafts still produce a correction token, so every iteration of the algorithm outputs at least one new token and potentially up to \(\gamma + 1\).
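The acceptance rule and residual distribution can be verified empirically. The NumPy sketch below (function name is mine, not the paper's) applies speculative sampling for a single token and checks that the output distribution matches \(p\) exactly, regardless of how poor the draft distribution \(q\) is:

```python
import numpy as np

def speculative_sample_step(p, q, x, rng):
    """One speculative-sampling step: given a draft token x ~ q,
    return a token distributed exactly according to p.

    p, q: probability vectors over the vocabulary; x: draft token index.
    """
    if rng.random() <= min(1.0, p[x] / q[x]):
        return x  # q(x) <= p(x) always accepts
    # Reject: resample from the residual p'(x) = norm(max(0, p - q)).
    residual = np.maximum(0.0, p - q)
    residual /= residual.sum()
    return rng.choice(len(p), p=residual)

# Empirical check: the output distribution should match p, not q.
rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.15, 0.05])   # "large model" distribution
q = np.array([0.25, 0.25, 0.25, 0.25])  # deliberately bad draft model
n = 100_000
counts = np.zeros(4)
for _ in range(n):
    x = rng.choice(4, p=q)  # draft token from the small model
    counts[speculative_sample_step(p, q, x, rng)] += 1
print(counts / n)  # close to p, despite drafting from uniform q
```

Even with a uniform draft distribution, the accepted-or-corrected tokens land on \(p\) to within sampling noise, which is the exactness guarantee in miniature.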

Architecture / Method

The method works with any autoregressive model and requires no changes to model architecture, weights, or training. It needs only two things: the target model \(M_p\) and a faster approximation model \(M_q\) for the same task.

Each iteration of the speculative decoding algorithm proceeds in three stages:

Stage 1 – Draft. The approximation model \(M_q\) generates \(\gamma\) candidate tokens autoregressively. For each position \(i\) from 1 to \(\gamma\), \(M_q\) runs on the prefix extended by the previous draft tokens and samples \(x_i \sim q_i(x)\). Since \(M_q\) is roughly 100x smaller than \(M_p\), this is fast (the cost coefficient \(c\), defined as the ratio of \(M_q\)’s runtime to \(M_p\)’s runtime, is typically below 0.05).

Stage 2 – Verify. The target model \(M_p\) evaluates all \(\gamma + 1\) positions in parallel: the original prefix plus each prefix extended by 1, 2, …, \(\gamma\) draft tokens. This produces probability distributions \(p_1(x), p_2(x), \ldots, p_{\gamma+1}(x)\). Because these positions are independent once the tokens are fixed, they can be batched into a single forward pass through \(M_p\), reading the model weights from memory only once.

Stage 3 – Accept/Reject. For each draft position \(i\) from 1 to \(\gamma\), draw a uniform random number \(r_i \sim U(0,1)\). Accept draft token \(x_i\) if \(r_i \leq p_i(x_i) / q_i(x_i)\). The first rejected position \(n\) terminates the chain: all tokens \(x_1, \ldots, x_n\) before it are kept, and a correction token \(t\) is sampled from the adjusted distribution \(p'(x) = \text{norm}(\max(0, p_{n+1}(x) - q_{n+1}(x)))\). If all \(\gamma\) drafts are accepted, \(t\) is sampled directly from \(p_{\gamma+1}(x)\), yielding \(\gamma + 1\) new tokens.
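The three stages can be sketched end to end. This toy version (names and the stand-in "models" are mine; a real implementation batches Stage 2 into one forward pass of \(M_p\)) uses callables that map a prefix to a next-token distribution:

```python
import numpy as np

def speculative_decode_step(prefix, p_dist, q_dist, gamma, rng):
    """One iteration of speculative decoding (toy sketch).

    p_dist(prefix) / q_dist(prefix) return next-token probability
    vectors, standing in for M_p and M_q.
    """
    # Stage 1: draft gamma tokens autoregressively from the small model.
    drafts, qs, ctx = [], [], list(prefix)
    for _ in range(gamma):
        q = q_dist(ctx)
        x = int(rng.choice(len(q), p=q))
        drafts.append(x); qs.append(q); ctx.append(x)
    # Stage 2: target distributions for all gamma + 1 positions
    # (batched into one forward pass in a real implementation).
    ps, ctx = [], list(prefix)
    for x in drafts:
        ps.append(p_dist(ctx)); ctx.append(x)
    ps.append(p_dist(ctx))
    # Stage 3: accept/reject left to right, then one extra token.
    out = []
    for i, x in enumerate(drafts):
        if rng.random() <= min(1.0, ps[i][x] / qs[i][x]):
            out.append(x)
        else:
            residual = np.maximum(0.0, ps[i] - qs[i])
            residual /= residual.sum()
            out.append(int(rng.choice(len(residual), p=residual)))
            return out  # first rejection ends the iteration
    out.append(int(rng.choice(len(ps[-1]), p=ps[-1])))  # all accepted
    return out

rng = np.random.default_rng(1)
p_dist = lambda ctx: np.array([0.6, 0.3, 0.1])  # stand-in for M_p
q_dist = lambda ctx: np.array([0.5, 0.4, 0.1])  # stand-in for M_q
tokens = speculative_decode_step([0], p_dist, q_dist, gamma=4, rng=rng)
print(len(tokens))  # between 1 and gamma + 1
```

Each call returns at least one token (the correction) and at most \(\gamma + 1\), exactly as the algorithm guarantees.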

Speculative decoding illustrated for unconditional language generation

Figure 1: A worked example of speculative decoding generating a news sentence. Each row is one iteration of the algorithm. Green tokens are draft suggestions from the small model (6M parameters) that the large model (97M parameters) accepted. Red tokens are rejected drafts, and blue tokens are the corrections sampled from the adjusted distribution. In the first row, the large model ran once and 5 tokens were produced. The entire 38-token sentence required only 9 serial runs of the large model instead of 38.

Trace diagram comparing speculative decoding to standard decoding

Figure 5: A simplified execution trace for an encoder-decoder Transformer. Top row: speculative decoding with \(\gamma = 7\), where each call to the large model (purple) is preceded by 7 fast calls to the small model (blue). Middle row: \(\gamma = 3\). Bottom row: standard decoding with no speculation. The yellow and orange blocks on the left represent encoder passes. Speculative decoding completes in fewer large-model calls, reducing total walltime despite more total computation.

The algorithm also handles different sampling strategies uniformly. Methods like argmax (greedy), top-k, nucleus sampling, and temperature scaling are all cast into “standardized sampling” by first adjusting the probability distribution (for example, argmax zeros out all non-maximum probabilities and renormalizes), then sampling from that adjusted distribution. This means speculative decoding works with any decoding strategy without modification.
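A minimal sketch of this standardization step (the function name and argument layout are mine, not the paper's): each strategy becomes a transformation of the probability vector, after which ordinary sampling applies.

```python
import numpy as np

def standardize(p, strategy="sample", temperature=1.0, k=None):
    """Cast a decoding strategy as 'adjust the distribution, then sample'."""
    p = np.asarray(p, dtype=float)
    if strategy == "argmax":
        out = np.zeros_like(p)
        out[np.argmax(p)] = 1.0        # all mass on the max token
        return out
    if strategy == "topk":
        out = np.zeros_like(p)
        idx = np.argsort(p)[-k:]       # keep the k largest entries
        out[idx] = p[idx]
        return out / out.sum()         # renormalize
    # Temperature scaling (temperature = 1 leaves p unchanged).
    logits = np.log(p) / temperature
    e = np.exp(logits - logits.max())
    return e / e.sum()

print(standardize([0.1, 0.7, 0.2], strategy="argmax"))  # [0. 1. 0.]
```

Because speculative sampling only ever sees the adjusted distributions \(p\) and \(q\), the acceptance rule is identical for every strategy.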

A key design choice is the number of draft tokens \(\gamma\). Larger \(\gamma\) means more potential tokens per iteration but also more wasted computation when drafts are rejected. The optimal \(\gamma\) depends on the acceptance rate \(\alpha\) (how well \(M_q\) matches \(M_p\)) and the cost coefficient \(c\). For typical values in the paper’s experiments (\(\alpha\) between 0.5 and 0.9, \(c\) around 0.02), optimal \(\gamma\) ranges from 3 to 10.

Mathematical Foundations

Expected Tokens per Iteration (Equation 1)

\[E(\#\ generated\ tokens) = \frac{1 - \alpha^{\gamma + 1}}{1 - \alpha}\]

This is the expected value of a capped geometric distribution. Under the simplifying assumption that acceptance decisions are independent and identically distributed, each draft token is accepted with probability \(\alpha\), so the number of accepted tokens before the first rejection follows a geometric distribution, capped at \(\gamma\); one additional token (the correction, or the extra sample when all drafts are accepted) is always produced. When \(\alpha = 0.8\) and \(\gamma = 5\), this gives \(E = (1 - 0.8^6) / (1 - 0.8) = (1 - 0.262) / 0.2 = 3.69\) tokens per iteration – a 3.69x reduction in the number of serial calls to \(M_p\).

This equation matters because it quantifies the core tradeoff: a better approximation model (higher \(\alpha\)) or more draft tokens (higher \(\gamma\)) yields more tokens per iteration, but with diminishing returns as \(\alpha^{\gamma+1}\) approaches zero.
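Equation 1 is one line of code; reproducing the worked example above (the function name is mine):

```python
def expected_tokens(alpha, gamma):
    """Equation 1: expected tokens generated per iteration of the
    algorithm, for acceptance rate alpha and gamma draft tokens."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

# Worked example from the text: alpha = 0.8, gamma = 5.
print(round(expected_tokens(0.8, 5), 2))  # 3.69
```

Note the diminishing returns: as \(\gamma\) grows, the value saturates at \(1/(1-\alpha)\), so for \(\alpha = 0.9\) no choice of \(\gamma\) can exceed 10 tokens per iteration.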

Acceptance Rate (Theorem 3.5 and Corollary 3.6)

\[\beta = \sum_x \min(p(x), q(x))\]

\[\alpha = E(\beta) = E\left(\sum_x \min(p(x), q(x))\right)\]

The acceptance rate \(\beta\) equals the sum of the pointwise minimum of the two distributions. Intuitively, this measures how much the two distributions overlap. If \(M_q\) and \(M_p\) assign similar probabilities to the same tokens, \(\min(p(x), q(x))\) is large for most \(x\) and \(\beta\) is close to 1. If the distributions disagree, \(\beta\) is low.

This result also connects to a divergence measure defined in the paper: \(D_{LK}(p, q) = 1 - \sum_x \min(p(x), q(x))\), so \(\beta = 1 - D_{LK}(p, q)\). This divergence is symmetric (unlike KL divergence), bounded in \([0, 1]\), and equals zero if and only if \(p = q\).
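Both quantities are direct sums over the vocabulary; a small check (the function name is mine):

```python
import numpy as np

def divergence_lk(p, q):
    """D_LK(p, q) = 1 - sum_x min(p(x), q(x)); the per-step
    acceptance rate is beta = 1 - D_LK(p, q)."""
    return 1.0 - np.minimum(p, q).sum()

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.3, 0.4, 0.3])
beta = 1.0 - divergence_lk(p, q)
print(beta)  # min-sum = 0.3 + 0.3 + 0.2 = 0.8
```

The extreme cases behave as the theory says: identical distributions give \(D_{LK} = 0\) (every draft accepted), and disjoint distributions give \(D_{LK} = 1\) (every draft rejected).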

Walltime Improvement Factor (Theorem 3.8)

\[\text{Speedup} = \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}\]

This equation divides the expected tokens per iteration by the cost per iteration. The numerator captures how many tokens we generate; the denominator captures how long it takes. When \(c\) is close to zero (the approximation model is negligible cost), the speedup approaches the raw token generation formula. When \(c\) is larger, more draft tokens means more time spent on \(M_q\), reducing the benefit.

For a concrete example: with \(\alpha = 0.75\), \(\gamma = 7\), and \(c = 0.02\) (the T5-small/T5-XXL translation configuration from the paper), the predicted speedup is \((1 - 0.75^8) / ((1 - 0.75)(7 \times 0.02 + 1)) = (1 - 0.1) / (0.25 \times 1.14) = 0.9 / 0.285 \approx 3.16\)x, close to the empirically measured 3.4x; the gap is attributable to implementation details and the simplifying i.i.d. assumption.
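The formula also makes choosing \(\gamma\) a trivial one-dimensional search, since the speedup is cheap to evaluate for every small integer. A sketch (function name is mine):

```python
def speedup(alpha, gamma, c):
    """Theorem 3.8: expected walltime improvement factor for
    acceptance rate alpha, gamma draft tokens, and cost ratio c."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))

# T5-small / T5-XXL translation configuration from the text.
print(round(speedup(0.75, 7, 0.02), 2))  # 3.16

# Optimal gamma: scan small integers and keep the best.
best = max(range(1, 21), key=lambda g: speedup(0.75, g, 0.02))
print(best)  # optimal gamma for these alpha and c
```

For these values the search lands in the 3-10 range the paper reports, and the curve is flat near the optimum, so an exact choice of \(\gamma\) matters little in practice.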

Condition for Any Improvement (Corollary 3.9)

\[\alpha > c \implies \exists\ \gamma : \text{Speedup} > 1, \quad \text{and the speedup is at least } \frac{1 + \alpha}{1 + c}\]

This is the minimum viability condition. If the acceptance rate \(\alpha\) exceeds the cost coefficient \(c\), speculative decoding is guaranteed to help. Since \(c\) is typically below 0.05 and \(\alpha\) is typically above 0.5 for reasonable approximation models, this condition is satisfied in virtually all practical scenarios. Even a trivial bigram model (\(\alpha \approx 0.2\), \(c \approx 0\)) meets the threshold and yields up to a 1.25x speedup (the large-\(\gamma\), \(c \to 0\) limit \(1/(1-\alpha)\)).

Results

The authors tested speculative decoding on T5-XXL (11 billion parameters) with three approximation models: T5-small (77M), T5-base (250M), and T5-large (800M). Two tasks were evaluated: English-to-German translation (WMT EnDe) and text summarization (CNN/DailyMail). All experiments used a batch size of 1 on a single TPU-v4.

| Task  | \(M_q\)          | Temp | \(\gamma\) | \(\alpha\) | Speedup |
|-------|------------------|------|------------|------------|---------|
| EnDe  | T5-small (77M)   | 0    | 7          | 0.75       | 3.4x    |
| EnDe  | T5-base (250M)   | 0    | 7          | 0.80       | 2.8x    |
| EnDe  | T5-large (800M)  | 0    | 7          | 0.82       | 1.7x    |
| EnDe  | T5-small (77M)   | 1    | 7          | 0.62       | 2.6x    |
| EnDe  | T5-base (250M)   | 1    | 5          | 0.68       | 2.4x    |
| EnDe  | T5-large (800M)  | 1    | 3          | 0.71       | 1.4x    |
| CNNDM | T5-small (77M)   | 0    | 5          | 0.65       | 3.1x    |
| CNNDM | T5-small (77M)   | 1    | 5          | 0.53       | 2.3x    |

The striking finding is that the smallest approximation model (T5-small, 77M) consistently produces the best speedups despite having a lower acceptance rate \(\alpha\) than larger approximation models. This happens because T5-small has a much lower cost coefficient \(c\) (approximately 0.02, vs. roughly 0.11 for T5-large), so the walltime cost of generating drafts is negligible. The roughly 140x parameter ratio between T5-XXL (11B) and T5-small (77M) hits the sweet spot between acceptance quality and draft cost.

Argmax sampling (temperature=0) yields consistently higher speedups than standard sampling (temperature=1). This makes sense: when the target distribution is sharper (concentrated on fewer tokens), the approximation model’s top predictions are more likely to match, increasing \(\alpha\). The empirical results closely match the theoretical predictions from Theorem 3.8, with differences attributable to implementation overhead and the simplifying i.i.d. assumption.

The authors also measured \(\alpha\) values across model families beyond T5, including a 97M-parameter GPT-like decoder and a 137B-parameter LaMDA model. Approximation models roughly 100x smaller consistently produced \(\alpha\) values between 0.5 and 0.9. Even trivial unigram and bigram models yielded non-zero \(\alpha\) values (0.03-0.23), providing modest but free speedups since their cost is effectively zero.

Limitations

The exactness guarantee is not free: speculative decoding performs more total arithmetic, since rejected drafts and the parallel evaluation of positions that end up discarded are wasted work. It therefore pays off only where spare compute capacity exists – the paper's experiments use a batch size of 1, where memory bandwidth dominates; with large batches the hardware is already compute-bound and the benefit shrinks. The gains also depend on having an approximation model that matches the target well (high \(\alpha\)) while remaining cheap to run (low \(c\)).

Impact and Legacy

Speculative decoding became one of the most widely adopted inference optimization techniques in the years following this paper. Its key advantage – producing mathematically identical outputs with no model changes – made it safe to deploy in production systems where output quality is non-negotiable. By 2024, speculative decoding was integrated into major serving frameworks including vLLM, TensorRT-LLM, and HuggingFace’s text-generation-inference.

The paper spawned a rich line of follow-up research. Contemporaneously, Chen et al. (2023) independently discovered the same technique and demonstrated 2-2.5x speedups on Chinchilla 70B. Subsequent work explored training custom draft models optimized for high \(\alpha\), using non-autoregressive draft models, self-speculative decoding (where a model’s early layers serve as the draft), tree-structured speculation (verifying multiple draft continuations simultaneously), and online adaptation of \(\gamma\) during generation.

The core insight – that a cheap approximation can be verified more efficiently than computed from scratch – generalized beyond decoding. The paper explicitly notes that stochastic speculative execution applies wherever a fast function approximates a slow one’s output distribution, suggesting applications in physics simulations and reinforcement learning. This framing influenced thinking about inference-time compute allocation more broadly, contributing to the shift from “make models smaller” to “make inference smarter.”

Prerequisites

To follow this paper, you need to understand:

Connections