Analytics

Sunday, February 22, 2026

Speculative Decoding in LLM Inference

Running frontier LLMs is slow, but that's the tradeoff we make to get more intelligent output. Yet if you think about the text (tokens) that an LLM produces (or that a human produces, for that matter), you might have an intuition that a lot of it does not actually require high intelligence. There's a lot of "filler" in language that makes things work syntactically and is easy to predict compared to, for example, the crux of an argument. This is evidenced by Shannon's old experiments, in which he had subjects predict the next letter in a text and they got it right on the first try 75% of the time. The actual information content of most text is concentrated in a few key words. This is doubly true of programming languages, where syntax makes up a huge portion of the content.

So a natural question one might ask is -- can we make LLMs faster at producing "easy" tokens vs "hard" tokens? This thinking is the motivation behind the idea of speculative decoding, which is a technique for speeding up LLM inference. It's a remarkably simple idea that requires just a bit of math to make it work.

Suppose we have a model $M_p$ which produces distribution $p(x)$ that we sample from. Assuming some input context $X_c$, the process of generating tokens from $M_p$ looks like:

$p_i(x) = M_p(X_c + [x_1, ..., x_{i-1}])$
$x_i \sim p_i(x)$

This is the standard autoregressive model sampling process where each token depends on the previous tokens. It's also one of the fundamental reasons why LLMs are hard to speed up -- there is an inherent sequentiality to how they produce tokens. But what if we can quickly "guess" the tokens that the LLM is going to generate with high probability, the same way that people did in Shannon's experiment?

Speculative decoding example. Green = speculated, red = rejected, blue = corrected. Source: Leviathan et al.

Speculative decoding involves using a faster, independent process $M_q$ to "speculate" potential future tokens of $M_p$. Traditionally, $M_q$ is another, smaller LLM, but there are multiple options (more on this below) -- the important thing is that $M_q$ produces a distribution $q(x)$ from which we can sample tokens. Incorporating speculation into the decoding process is straightforward: have $M_q$ generate $N$ (small, e.g. 4-8) tokens at each step, run $M_p$ on all $N+1$ positions simultaneously (which is possible because we now have actual tokens to condition on), and then run a "correction process" to ensure that the output distribution matches. Assuming $M_q$ is much faster than $M_p$, we get an overall speedup because $M_p$ can decode all $N+1$ of those positions in approximately the same amount of time that it takes to decode a single one. This is a consequence of the observation that LLM decoding is usually limited by memory bandwidth, i.e. by the time it takes to load all of the weights from HBM for each forward pass, rather than by compute.
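To make the memory-bandwidth argument concrete, here is a back-of-the-envelope roofline estimate. The model and hardware numbers (a 70B-parameter model with 16-bit weights, 3 TB/s of HBM bandwidth, 1 PFLOP/s of compute) are round, hypothetical values chosen for illustration, and the sketch ignores KV-cache traffic and attention FLOPs.

```python
def decode_step_time(n_positions, params=70e9, bytes_per_param=2,
                     hbm_bytes_per_s=3e12, flops_per_s=1e15):
    """Roofline-style estimate (seconds) of one forward pass over n_positions tokens.

    All defaults are hypothetical round numbers. KV-cache reads and attention
    FLOPs are ignored; only weight streaming and matmul FLOPs are modeled.
    """
    t_memory = params * bytes_per_param / hbm_bytes_per_s  # stream every weight from HBM once
    t_compute = n_positions * 2 * params / flops_per_s     # ~2 FLOPs per parameter per token
    return max(t_memory, t_compute)                         # the slower resource sets the pace


for n in (1, 5, 1000):
    print(n, f"{decode_step_time(n) * 1e3:.1f} ms")
```

Under these assumptions, one pass over 5 positions costs the same ~46.7 ms as a pass over a single position, because both are dominated by streaming ~140 GB of weights; compute only becomes the bottleneck at a few hundred positions per pass.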

Let's first prove correctness, specifically that we can do this without changing the token output distribution $p(x)$. Otherwise, you end up sacrificing intelligence for speed, which is exactly the tradeoff we want to avoid. The key lies in the correction process mentioned above.

For each token $x$ that we sample from $M_q$, we look at the probability distributions of both models $p(x)$ and $q(x)$ (note that we have $p(x)$ available because we still ran $M_p$). If $q(x) \le p(x)$, then we keep the token. Otherwise, we reject the token with probability $1 - p(x) / q(x)$, which corresponds to how much more likely $M_q$ was to choose the token than $M_p$, i.e. the blue excess over the red in the diagram. When we reject a token, we then re-sample from a modified distribution that spreads the excess probability to the other tokens that have deflated probabilities in $q(x)$ vs $p(x)$.

Token probability distributions and the correction process. The excess blue probability of token 1 gets spread out to the other tokens that have excess red probability.

More precisely, the resulting distribution is $p'(x) = \text{norm}(\text{max}(0, p(x) - q(x)))$. Once we reject a token, all remaining tokens that we sampled from $M_q$ are invalid, as they would have depended on the rejected token which was changed. This process guarantees that, no matter what distributions $M_q$ produces, the resulting output distribution matches exactly what $M_p$ would have produced (in the degenerate case, we end up rejecting everything and only ever have $M_p$ generate tokens directly).
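Here is a minimal NumPy sketch of the correction process for one round of speculation. It assumes the draft and target distributions are already available as probability vectors over a small vocabulary (in a real system they would come from $M_q$'s sampling steps and $M_p$'s single parallel forward pass), and it uses made-up toy distributions to check empirically that the first emitted token follows $p$ rather than $q$.

```python
import numpy as np


def residual_distribution(p, q):
    """p'(x) = norm(max(0, p(x) - q(x))): mass where the draft under-proposed."""
    r = np.maximum(p - q, 0.0)
    total = r.sum()
    # total == 0 only when p == q, in which case nothing is ever rejected.
    return r / total if total > 0 else p


def verify(draft_tokens, q_dists, p_dists, rng):
    """One round of speculative sampling.

    draft_tokens: N token ids sampled from M_q
    q_dists:      N draft distributions (one per drafted position)
    p_dists:      N + 1 target distributions from one parallel M_p pass
    Returns the tokens to append (between 1 and N + 1 of them).
    """
    out = []
    for x, q, p in zip(draft_tokens, q_dists, p_dists):
        # Keep x with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[x] / q[x]):
            out.append(int(x))
            continue
        # Rejected: resample from the residual and drop all later drafts.
        out.append(int(rng.choice(len(p), p=residual_distribution(p, q))))
        return out
    # Every draft was accepted, so M_p's extra distribution yields a bonus token.
    out.append(int(rng.choice(len(p_dists[-1]), p=p_dists[-1])))
    return out


# Toy check with made-up distributions: the first emitted token should follow p, not q.
rng = np.random.default_rng(0)
p = np.array([0.1, 0.4, 0.2, 0.2, 0.1])  # target M_p
q = np.array([0.4, 0.1, 0.2, 0.1, 0.2])  # draft M_q
trials = 20_000
counts = np.zeros(len(p))
for _ in range(trials):
    x = rng.choice(len(q), p=q)
    counts[verify([x], [q], [p, p], rng)[0]] += 1
empirical = counts / trials
print(np.round(empirical, 3))  # close to p, even though drafts were sampled from q
```

For these toy vectors, a token is accepted with total probability $\sum_x \min(p(x), q(x)) = 0.6$, and the remaining 0.4 of probability mass is redistributed by the residual $[0, 0.75, 0, 0.25, 0]$, which adds back exactly what $p$ is missing.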

The remaining question, then, is how to pick a good $M_q$. Based on the correction process, we see that how closely $M_q$ approximates $M_p$ is the defining factor in how well it speculates. Leviathan et al. analyze this formally and show that, assuming each speculated token is accepted independently with the same probability, the above process produces an expected $(1 - \alpha^{N+1}) / (1 - \alpha)$ tokens per run of $M_p$, where $\alpha = E\left[\sum_x \min(p(x), q(x))\right]$ is the expected overlap between the two distributions (i.e. the expected acceptance rate).
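The formula is easy to evaluate. The sketch below also includes the walltime improvement factor $(1 - \alpha^{N+1}) / ((1 - \alpha)(Nc + 1))$ from Leviathan et al., where $c$ is the cost of one $M_q$ step relative to one $M_p$ step; the $\alpha$, $N$, and $c$ values in the example are illustrative, not measurements.

```python
import numpy as np


def acceptance_rate(p, q):
    """Per-position overlap sum_x min(p(x), q(x)); alpha is its expectation over contexts."""
    return float(np.minimum(p, q).sum())


def expected_tokens_per_step(alpha, n):
    """Expected tokens emitted per M_p pass: (1 - alpha^(n+1)) / (1 - alpha)."""
    if alpha == 1.0:
        return n + 1.0  # every draft accepted, plus the bonus token
    return (1 - alpha ** (n + 1)) / (1 - alpha)


def expected_speedup(alpha, n, c):
    """Leviathan et al.'s walltime improvement factor; c = cost(M_q step) / cost(M_p step)."""
    return expected_tokens_per_step(alpha, n) / (n * c + 1)


# Illustrative values only.
print(acceptance_rate(np.array([0.1, 0.4, 0.2, 0.2, 0.1]),
                      np.array([0.4, 0.1, 0.2, 0.1, 0.2])))  # overlap of two toy distributions
print(expected_tokens_per_step(0.8, 4))  # about 3.36 tokens per M_p pass
print(expected_speedup(0.8, 4, 0.05))    # about 2.8x
```

Note how quickly the payoff saturates: with $\alpha = 0.8$, even infinitely many drafted tokens would yield at most $1 / (1 - \alpha) = 5$ tokens per pass, which is why small $N$ is typical.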

In practice, there are two common choices for $M_q$. The simplest one is choosing a smaller version in the same model family, e.g. Llama-3.1-8B to speculate for Llama-3.1-405B. This is effective because models in the same family tend to be trained on the same data and share the same biases introduced by architecture, leading to more overlap in the output distributions. Alternatively, the approach of the Medusa paper is to fine-tune special "decoding heads" on top of an LLM that are specifically trained to do speculation and can handle multiple branching paths. This is more efficient but requires training and is therefore not universal. Empirically, speculative decoding seems to yield speedups in the 2-3x range on real data.

Medusa heads attached to a frozen base LLM and fine-tuned for speculation. Source: Cai et al.

The obvious downside to speculative decoding is increased computation -- we generate additional tokens from $M_q$ without reducing the total work $M_p$ does (it's faster only because $M_p$ now does that work in parallel), and $M_p$'s work on any rejected tokens is wasted. In the case where we reject everything, we would be wastefully generating (and verifying) $N$ tokens from $M_q$ for each token of $M_p$. Given this, speculative decoding is not a good fit for an inference environment that is compute-bound, e.g. running on big batches, which amortizes the cost of loading the model weights across multiple generations. Instead, it's well-suited for environments that have extra resources but want to provide lower latency, or as a way to leverage excess compute during periods of low traffic. As LLMs become more prevalent in real-time tasks like live translation, interactive coding (e.g. Cursor's speculative edits), and conversational agents, the demand for faster inference will continue to increase.

Wednesday, February 18, 2026

Triton Language

The world of GPU programming for AI has come a long way since I worked on writing a CUDA-based matrix library back in 2009. Both NVIDIA hardware and the CUDA ecosystem have evolved dramatically and are now the basis for the majority of AI compute in the world today (hence the $4T+ market cap). Nevertheless, writing CUDA is still really hard, primarily because you need a good understanding of low-level mechanisms within the GPU (e.g. memory hierarchy, warp scheduling, memory coalescing) to produce performant code. I recently came across a project called Triton, which is a Python-based DSL that makes it easier to build high-performance GPU kernels. I ended up writing a handful of LLM-related kernels to understand Triton better and found it quite interesting, so I want to share a little bit about this technology.

Improving performance of CUDA matrix multiplication kernel. Source: siboehm.com.

To illustrate how Triton works, it's helpful to start out with one of the simplest GPU kernels and the backbone of modern deep learning: matrix multiplication (C = A * B). There is a really nice blog post written by a performance engineer at Anthropic that walks through what it takes to get matrix multiplication performance on par with a mature library like cuBLAS. It validates my point above -- you not only have to understand these low-level GPU concepts, you need to reason carefully about how they interact with your specific computation. Even for something as basic as matrix multiplication, this is quite complex. But at a high level, you can boil it down to: how do we efficiently move data from High Bandwidth Memory (HBM) through the memory hierarchy (e.g. shared memory), and then make sure we have high enough arithmetic intensity to avoid being memory-bound?

By contrast, Triton has a block-centric programming model where you're abstracted away from the details of threads, warps, shared memory, etc. Instead, you schedule the execution of programs across a grid (similar to the CUDA grid), and each program instance operates on "blocks" of data, i.e. sub-regions of tensors. In the official Triton tutorial for matrix multiplication, you still need to be aware of the memory hierarchy and choose the correct ordering to compute the output C. But the coalescing, shared memory caching, blocktiling, vectorization, and warptiling get done behind the scenes as optimizations by the Triton compiler.

Here is the Triton program code from the tutorial:

Let's break it down. First, assume we're launching a grid of programs computing [BLOCK_SIZE_M, BLOCK_SIZE_N] blocks of the output C (we'll show the grid construction later). The above function computes one such block, indexed by the program ID: pid = tl.program_id(axis=0). The first chunk of code decides which block we are going to compute:

# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
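To see what the grouped ordering does, the same index arithmetic can be run in plain Python. The sketch below copies the mapping from the snippet, adds a row-major mapping for comparison, and counts how many distinct row-blocks of A and column-blocks of B the first few programs need, as a rough proxy for what must be resident in L2 at once. The matrix sizes and the `tiles_touched` helper are hypothetical, chosen for illustration.

```python
def cdiv(a, b):
    return (a + b - 1) // b  # same role as tl.cdiv


def grouped_pid_to_block(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    """Same arithmetic as the kernel snippet, taking the block counts directly."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def row_major_pid_to_block(pid, num_pid_m, num_pid_n):
    return pid // num_pid_n, pid % num_pid_n


def tiles_touched(blocks):
    """Distinct row-blocks of A plus column-blocks of B needed by a set of C blocks."""
    return len({m for m, _ in blocks}) + len({n for _, n in blocks})


# M = N = 144 with 16x16 output blocks gives a 9x9 grid of C blocks.
num_pid_m, num_pid_n = cdiv(144, 16), cdiv(144, 16)
grouped = [grouped_pid_to_block(p, num_pid_m, num_pid_n, 3) for p in range(9)]
row_major = [row_major_pid_to_block(p, num_pid_m, num_pid_n) for p in range(9)]
print(grouped[:4])                                       # walks down a column of blocks first
print(tiles_touched(grouped), tiles_touched(row_major))  # 6 vs 10
```

With GROUP_SIZE_M = 3, the first nine programs cover a 3x3 square of C, which needs only 3 row-blocks of A and 3 column-blocks of B, whereas row-major order covers one full row of C and needs 1 + 9 of them.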

For reasons explained in the tutorial, there is a somewhat sophisticated decision for which block to compute in order to improve L2 cache reuse in the GPU. The key insight is that we can reuse the same blocks that we loaded from A and B across consecutive program instances so that they are warm in the L2 cache and don't require HBM reads. This brings to light the importance of understanding program scheduling when writing Triton -- you can't hide all of the complexity! This next section of code is a very Triton-esque pattern:

# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

Based on which block we decided to compute (pid_m, pid_n), we calculate the blocks of A and B that we will perform the computation over. offs_am is an array of row offsets in A (tl.arange returns a range of indices), offs_bn is an array of column offsets in B, and offs_k is an array of column offsets in A and row offsets in B (the dimension that we accumulate over in matrix multiplication). We then broadcast these 1-D offset arrays to produce the 2-D arrays of row + column offsets in A and B that represent the entries we are going to use in the computation.

The final two lines show how, in Triton, we think of tensors as pointers (to the beginning of the data) and address entries within a tensor by their offset from that beginning. That's why this function needs to know the strides, i.e. how far to advance the pointer to move forward one row or column. Here's a visual representation of each of these variables:

Visual representation of offsets and blocks in Triton matrix multiplication kernel.
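The pointer arithmetic can be mimicked in NumPy by treating a flattened array as "memory" and computing integer offsets instead of pointers. This is only an emulation of what `a_ptrs` holds (the sizes and `pid_m` are arbitrary); it also shows the effect of the `% M` in `offs_am`: rows past the edge wrap around to valid rows, so loads stay in bounds, and those rows' results are later discarded by the store mask.

```python
import numpy as np

M, K = 6, 8
A = np.arange(M * K, dtype=np.float32).reshape(M, K)  # contiguous, row-major
# NumPy strides are in bytes; torch's .stride() (what the kernel receives) is in elements.
stride_am, stride_ak = (s // A.itemsize for s in A.strides)  # (8, 1)

BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 4
pid_m = 1
offs_am = (pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)) % M  # [4, 5, 0, 1]: wraps past M
offs_k = np.arange(BLOCK_SIZE_K)
# Broadcasting [BLOCK_SIZE_M, 1] against [1, BLOCK_SIZE_K] gives a 2-D block of element
# offsets, the analogue of `a_ptrs - a_ptr` in the kernel.
a_offs = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak

flat = A.ravel()                                   # memory as seen through a base pointer
block0 = flat[a_offs]                              # like tl.load(a_ptrs) for the first K block
block1 = flat[a_offs + BLOCK_SIZE_K * stride_ak]   # after `a_ptrs += BLOCK_SIZE_K * stride_ak`
print(block0)
```

Advancing along K is just adding a scalar to every offset, which is why the kernel can move through A and B with a single `+=` per loop iteration.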

Once we've figured out the two blocks that we're performing the computation over, it's just a matter of actually computing the result:

# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
    # We accumulate along the K dimension.
    accumulator = tl.dot(a, b, accumulator)
    # Advance the ptrs to the next K block.
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)

This is your typical matrix multiplication inner loop, reframed using Triton block loads and pointer advancement. Note that masking is performed to handle cases where the dimensions are not multiples of our block sizes (another common pattern you'll see in both Triton and CUDA code). Finally, we store the result into the actual output location, as accumulator is a temporarily allocated block.
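One way to convince yourself the masking logic is right is to emulate the whole kernel on the CPU. The NumPy sketch below is not Triton: each iteration of the two outer loops plays the role of one program instance, `masked_load` is a hypothetical stand-in for `tl.load` with `other=0.0`, and plain row-major block order replaces the grouped ordering. It accumulates in fp32, converts to fp16, and uses deliberately awkward sizes so that every mask matters.

```python
import numpy as np


def cdiv(a, b):
    return (a + b - 1) // b


def masked_load(x, rows, cols, mask):
    """Gather x[rows][:, cols], replacing masked-off entries with 0 (like other=0.0)."""
    r = np.clip(rows, 0, x.shape[0] - 1)  # clamp so the gather never goes out of bounds
    c = np.clip(cols, 0, x.shape[1] - 1)
    vals = x[r[:, None], c[None, :]].astype(np.float32)
    return np.where(mask, vals, np.float32(0.0))


def blocked_matmul(a, b, BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=8):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"
    c = np.empty((M, N), dtype=np.float16)
    # Each (pid_m, pid_n) iteration plays the role of one program instance.
    for pid_m in range(cdiv(M, BLOCK_SIZE_M)):
        for pid_n in range(cdiv(N, BLOCK_SIZE_N)):
            offs_m = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)
            offs_n = pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)
            offs_am, offs_bn = offs_m % M, offs_n % N  # wrap rows/cols as in the kernel
            acc = np.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=np.float32)
            for k in range(cdiv(K, BLOCK_SIZE_K)):
                offs_k = k * BLOCK_SIZE_K + np.arange(BLOCK_SIZE_K)
                # Same condition as `offs_k < K - k * BLOCK_SIZE_K` with local offsets.
                a_blk = masked_load(a, offs_am, offs_k, offs_k[None, :] < K)
                b_blk = masked_load(b, offs_k, offs_bn, offs_k[:, None] < K)
                acc += a_blk @ b_blk
            # Store only the in-bounds part of the block (the c_mask in the kernel).
            rows, cols = offs_m[offs_m < M], offs_n[offs_n < N]
            c[np.ix_(rows, cols)] = acc[: len(rows), : len(cols)].astype(np.float16)
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((37, 29)).astype(np.float16)  # none of 37, 29, 45 divide the blocks
b = rng.standard_normal((29, 45)).astype(np.float16)
out = blocked_matmul(a, b)
print(np.abs(out.astype(np.float32) - a.astype(np.float32) @ b.astype(np.float32)).max())
```

The printed difference should be on the order of fp16 rounding error. Zero-filling the K tail is what makes the partial last K block harmless: padded entries contribute 0 to every dot product.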

# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)

We do similar offset computations to identify the correct block of C to write to, and we again mask to handle non-divisible dimensions. Finally, we note that the function we examined is the code for a single program instance, and we still need to know the grid that we're executing these programs on. Here's how that is defined:

import torch
import triton

def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

The main thing worth noting here is that the grid size is not a constant: it is a function of the meta-parameters (META) chosen for the kernel. This takes advantage of Triton's autotuning capability, which benchmarks a list of candidate configurations (block sizes such as BLOCK_SIZE_M and BLOCK_SIZE_N, plus GROUP_SIZE_M, in this case) to find the best one for a particular input size. This helps better align the program scheduling with the underlying hardware.
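Because `grid` is a callable, each candidate configuration gets its own launch size. A plain-Python mirror of the lambda makes this visible; the two configurations below are hypothetical examples, not the tutorial's autotuning list.

```python
def cdiv(a, b):
    return (a + b - 1) // b  # same role as triton.cdiv


def grid(META, M, N):
    """Number of program instances for one candidate config (mirrors the lambda above)."""
    return (cdiv(M, META["BLOCK_SIZE_M"]) * cdiv(N, META["BLOCK_SIZE_N"]),)


# Two hypothetical candidate configs for a 4096 x 4096 output.
for META in ({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}, {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}):
    print(META, grid(META, 4096, 4096))
```

Larger blocks mean fewer, heavier programs (512 vs 4096 here); which one wins depends on the GPU, which is exactly what autotuning measures.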

The Triton code is far from simple, but it does eliminate having to think about many of the harder aspects of GPU programming. I won't go into the details here, but Triton does this by compiling the DSL code into an LLVM-based intermediate representation where it can perform optimizations. Because the Triton compiler sees block-level access patterns rather than per-thread accesses, it can perform dataflow analysis on blocks and reason at a much higher level about what the code does. In doing so, it can decide when to prefetch memory, do hierarchical tiling (e.g. to fit the thread/warp model), schedule to coalesce memory accesses, manage and synchronize shared memory, and more. See the original paper (not up-to-date for Triton 2.0) for more details.

Triton vs cuBLAS matrix multiplication performance. Source: triton-lang.org.

As a user of Triton, writing code in Python that is mostly describing the logic of your computation and getting close to optimized CUDA kernel performance is a huge win. To get the most out of the GPU, you still need to understand memory hierarchy, program scheduling, and data layouts, but Triton gets you off the ground very quickly. Besides, we can't have programming be too easy, right?

For further examples, my triton-practice repository has a few well-documented implementations of core LLM building blocks like Rotary Positional Embeddings (RoPE) and Flash Attention.

Saturday, February 14, 2026

Linear Representations and Superposition

As LLMs become larger, more capable, and more ubiquitous, the field of mechanistic interpretability -- that is, understanding the inner workings of these models -- becomes increasingly interesting and important. Similar to how software engineers benefit from having good mental models of file systems and networking, AI researchers and engineers should strive to have some theoretical basis for understanding the "intelligence" that emerges from LLMs. A strong mental model would improve our ability to harness the technology. In this post, I want to cover two fundamental and related concepts in the field (each with their own paper) that I find fascinating from a mathematical perspective: the linear representation hypothesis (Park et al.) and superposition (Anthropic).

The linear representation hypothesis (LRH) has existed for quite some time, ever since people noticed that the word embeddings produced by Word2Vec satisfied some interesting properties. If we let $E(x)$ be the embedding vector of a word, then we observe the approximate equivalence

$E(``\text{king"}) - E(``\text{man"}) + E(``\text{woman"}) \approx E(``\text{queen"})$.

Observations of this form suggest that concepts (i.e. gender in the example) are represented linearly in the geometry of the embedding space, which is a simple but non-obvious claim.
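A toy example shows what the claim means geometrically and how analogy evaluation usually works (nearest neighbor by cosine similarity, excluding the query words). The vectors below are hypothetical and built by hand so that gender is a single direction `female`; they illustrate the hypothesis rather than provide evidence for it.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 50
# Hypothetical concept directions; real Word2Vec vectors are learned, not built like this.
royal, common, female = (rng.standard_normal(d) for _ in range(3))


def noisy(v):
    return v + 0.1 * rng.standard_normal(d)


E = {
    "king": noisy(royal),
    "queen": noisy(royal + female),
    "man": noisy(common),
    "woman": noisy(common + female),
    "apple": noisy(rng.standard_normal(d)),  # unrelated distractors
    "river": noisy(rng.standard_normal(d)),
}


def analogy(a, b, c):
    """Word whose embedding is closest (cosine) to E[a] - E[b] + E[c], excluding a, b, c."""
    target = E[a] - E[b] + E[c]

    def cos(w):
        return E[w] @ target / (np.linalg.norm(E[w]) * np.linalg.norm(target))

    return max((w for w in E if w not in (a, b, c)), key=cos)


print(analogy("king", "man", "woman"))
```

Subtracting `man` and adding `woman` cancels the `common` component and adds `female`, leaving roughly `royal + female`, which is where `queen` lives.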

Simplified model of an LLM in terms of embeddings and unembeddings.

Fast forward to modern LLMs, and the LRH remains a popular way to interpret what is going on inside these models. The Park et al. paper presents a mathematical framing of the hypothesis to try and formalize the idea. It uses a simplified model of an LLM where most of the inner workings (multilayer perceptron, attention, etc) are treated as a black box, and the interpretation of the LRH happens in two separate representation spaces with the same dimensionality as the model:

  • The "embedding space" where the final hidden states of the network live ($E(x)$ for an input context $x$). This is similar to the word embedding formulation and is where you would perform interventions that affect the model's behavior.
  • The "unembedding space" where the rows of the unembedding matrix live ($U(y)$ for each output token $y$). The concept direction measured by a linear probe over the hidden state (to evaluate the presence of the concept) corresponds to a vector in this space.

There are analogous statements of the LRH in the two respective spaces. Suppose $C$ represents the directional concept of gender, i.e. male => female. Then any pairs of input contexts that differ only in that concept should satisfy, e.g.

$E(``\text{Long live the queen"}) - E(``\text{Long live the king"}) = \alpha \cdot E_C$

where $\alpha > 0$ and $E_C$ is a constant vector in the embedding space referred to as the embedding representation. Similarly, any pairs of output tokens that differ only in that concept should satisfy, e.g.

$U(``\text{queen"}) - U(``\text{king"}) = \beta \cdot U_C$

where $\beta > 0$ and $U_C$ is a constant vector in the unembedding space referred to as the unembedding representation. Basically, applying the concept has a (directional) linear effect in both spaces.

The paper goes into much more detail that I'll skip over here, but they show that the embedding and unembedding representations are isomorphic, which unifies the intervention and linear probe ideas. They then empirically verify on Llama 2 that they can find the representations for a variety of concepts (e.g. present => past tense, noun => plural, English => French) that approximately fit into their theoretical framework -- cool!

Approximate orthogonality of concept representations in Llama 2. Source: Park et al.

Okay, so let's assume concepts do in fact have linear representations. Then it would stand to reason that unrelated concepts have orthogonal directions. Otherwise, applying the male => female concept could influence the presence of the English => French concept, which doesn't make sense. One of the key results from Park et al. is that this orthogonality doesn't occur under the standard Euclidean inner product but instead under a "causal inner product" that is derived from the unembedding matrix. Only by looking at concept representations through that lens do we get the orthogonality we expect.
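As I understand Park et al., their estimate of the causal inner product is $\langle x, y \rangle_C = x^\top \text{Cov}(\gamma)^{-1} y$, where $\gamma$ ranges over the unembedding vectors $U(y)$. Equivalently, it is the ordinary Euclidean inner product after whitening the unembedding space. A NumPy sketch, with a random correlated matrix standing in for a real unembedding matrix:

```python
import numpy as np


def whitening(U):
    """Return (W, mu) where W = Cov(U)^(-1/2); rows of U are unembedding vectors."""
    mu = U.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov(U, rowvar=False))
    W = evecs @ np.diag(evals ** -0.5) @ evecs.T  # symmetric inverse square root
    return W, mu


def causal_inner_product(x, y, U):
    """<x, y>_C = x^T Cov(U)^{-1} y (assumes the covariance is full rank)."""
    return float(x @ np.linalg.solve(np.cov(U, rowvar=False), y))


# Stand-in "unembedding matrix": 2,000 tokens in a correlated 8-dim space.
rng = np.random.default_rng(0)
U = rng.standard_normal((2000, 8)) @ rng.standard_normal((8, 8))
W, mu = whitening(U)
G = (U - mu) @ W  # whitened unembeddings have identity covariance
x, y = rng.standard_normal(8), rng.standard_normal(8)
print(causal_inner_product(x, y, U), (W @ x) @ (W @ y))  # the two agree
```

Orthogonality under this inner product is ordinary orthogonality after the whitening transform, which is the sense in which unrelated concept directions are expected to be orthogonal.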

But in these models, the representation space is relatively small (most ranging from 2K to 16K dimensions). So how do these spaces fit such a large number of language features that far exceeds their dimensionality? It's impossible for all such features to be orthogonal, no matter the geometry.

The interference effect of non-orthogonal features. Source: Anthropic.

This is where superposition comes into play. In low-dimensional spaces, the intuition is that, when you have $N$ vectors in a $d$-dimensional space with $N > d$, they start to interfere substantially (inner product has a large magnitude). This is one of those examples where low-dimensional intuition does not extend to higher dimensions, however, as evidenced by the Johnson-Lindenstrauss lemma. An implication of the lemma is that you can choose exponentially (in the number of dimensions) many vectors that are almost-orthogonal -- that is, the inner products between any pair are bounded by a small constant. You can think of this as the flip side of the curse of dimensionality.
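A quick NumPy experiment makes the high-dimensional intuition concrete: draw random unit vectors and measure the largest pairwise |cosine|. Random vectors are just one easy way to get near-orthogonality (not the construction used in the lemma's proof, and not what a trained model learns), and the sizes are arbitrary.

```python
import numpy as np


def max_abs_cosine(n, d, seed=0):
    """Largest |cos| between any two of n random unit vectors in d dimensions."""
    rng = np.random.default_rng(seed)
    V = rng.standard_normal((n, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)
    G = V @ V.T               # pairwise cosines
    np.fill_diagonal(G, 0.0)  # ignore each vector's similarity with itself
    return float(np.abs(G).max())


print(max_abs_cosine(8, 3))       # typically large: 8 directions crowd 3-D space
print(max_abs_cosine(2048, 512))  # 4x more vectors than dimensions, yet no pair is close to parallel
```

Pairwise cosines of random unit vectors concentrate around 0 with standard deviation about $1/\sqrt{d}$, so adding dimensions shrinks interference even as the number of vectors grows well past $d$.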

The Anthropic paper demonstrates the superposition phenomenon in toy models on small, synthetic datasets. One particularly interesting observation is that superposition does not occur with no activation function (purely linear computation), but it does occur with a nonlinear one (ReLU in their case). The idea is that the nonlinearity allows the model to manage the interference in a productive way. But this still only works well because of the natural sparsity of these features in the data -- models learn to superimpose features that are unlikely to be simultaneously present, minimizing interference.
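The Anthropic toy model reconstructs its input as $\text{ReLU}(W^\top W x + b)$, compared with $W^\top W x + b$ in the linear case. A hand-picked example (weights chosen by me, not learned) with 2 features stored as opposite directions in a single hidden dimension shows why the nonlinearity matters:

```python
import numpy as np

# Hand-picked weights: n = 2 features squeezed into m = 1 hidden dimension,
# stored as opposite directions (an "antipodal pair"). Not learned here.
W = np.array([[1.0, -1.0]])  # shape [m, n]
b = np.zeros(2)


def linear_model(x):
    return W.T @ (W @ x) + b


def relu_model(x):
    return np.maximum(W.T @ (W @ x) + b, 0.0)


only_first = np.array([0.7, 0.0])
both = np.array([0.7, 0.4])
print(linear_model(only_first))  # [0.7, -0.7]: interference leaks into feature 1
print(relu_model(only_first))    # [0.7, 0.0]: ReLU clips the negative interference
print(relu_model(both))          # [0.3, 0.0]: breaks when both features are active
```

The arrangement only works when the two features are rarely active at the same time, which is the sparsity condition described above.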

Visualization of a square antiprism, the energy-minimizing arrangement of 8 points on a 3-D unit sphere.

In experimental setups on synthetic data with independent features of equal importance and sparsity, they observe that the embedding vectors learned by the model form regular structures in the embedding space, e.g. a tetrahedron, pentagon, or square antiprism. Coincidentally, these are the same types of structures that I found in some old research I did on spherical codes. These structures emerged from using gradient descent-like algorithms to minimize the energy (analogous to that described by the Thomson problem) of arrangements of points on unit hyperspheres. Fun to see the overlap of multiple fields!
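For the curious, here is a minimal sketch of that kind of energy minimization: projected gradient descent on the Thomson energy (sum of $1/r$ over all pairs) for 8 points on the 3-D unit sphere. The step size, step count, and random initialization are arbitrary choices for this example, not the algorithm from my old research.

```python
import numpy as np


def thomson_energy(X):
    """Sum of 1/distance over all pairs of points (rows of X)."""
    i, j = np.triu_indices(len(X), k=1)
    return float((1.0 / np.linalg.norm(X[i] - X[j], axis=1)).sum())


def minimize_thomson(n, steps=10_000, lr=0.02, seed=0):
    """Projected gradient descent on the unit sphere in 3-D."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, 3))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    for _ in range(steps):
        diff = X[:, None, :] - X[None, :, :]  # x_i - x_j, shape [n, n, 3]
        dist = np.linalg.norm(diff, axis=-1)
        np.fill_diagonal(dist, np.inf)        # drop self-interaction terms
        force = (diff / dist[..., None] ** 3).sum(axis=1)  # negative gradient of the energy
        X = X + lr * force
        X /= np.linalg.norm(X, axis=1, keepdims=True)      # project back onto the sphere
    return X


X = minimize_thomson(8)
print(thomson_energy(X))  # should land near 19.675, the known minimum for 8 points
```

For comparison, the more "obvious" cube arrangement of 8 points has energy of about 19.74, so descent ends up at the square antiprism rather than the cube.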

To conclude, viewing features as linear representations, even if it is not the complete story, is a valuable framework to help us interpret and intervene in LLMs. It has a solid theoretical basis that is backed up empirically. Sparsity, superposition, and the non-intuitive nature of higher-dimensional spaces give us a window into how the complexity of language (and intelligence?) gets captured by these models. Mechanistic interpretability has a long way to go, but it's reassuring to see us slowly uncovering the nature of LLMs and why they're able to do such incredible things.

Wednesday, February 11, 2026

Reflections on using Claude Code

This is the follow-up to my post from two weeks ago on observations using Claude Code (Opus 4.5 to start, 4.6 after it released) to rebuild the kfchess.com website, with the constraint that I would not write any of the code myself. 60k lines of code and 100 commits later, I'm happy to share that I've successfully deployed the new site to production! It took me 60-80 hours of total work over the course of 3.5 weeks, which I estimate is ~3x less time than it would have taken me to do myself. All of the code can be found in this repository.

My goal for this post is to document what I found to be easy vs hard for Claude Code (CC) in order to refine my mental model about what AI coding tools are currently capable of. In the same way that a good software engineer understands when to use Postgres vs Redis (I use both in kfchess), the engineers who have the best mental models of CC will get the most out of it. And increasingly, your ability to leverage tools like CC is strongly correlated with how quickly you can build software.

So, let's start with the good stuff -- things where I was genuinely impressed and found a ton of value in what CC did for me. There's a lot.

  1. Project bootstrap. I mentioned this in my previous post, but I was able to get started developing the new version of kfchess much faster than I would have otherwise. CC does a great job of choosing solid technologies and putting together a development environment that is both modern and easy to work with.
  2. Anything CRUD+. For your typical website, you have lots of pretty straightforward CRUD operations that just need to be built. It's no surprise that CC is good at doing this pattern matching. But it was also really good at the slightly less common, but still standard parts of the site: Google OAuth, email verification, WebSocket setup, database migrations, etc.
  3. Architecture design. I asked CC to design a multi-server architecture that allows games to persist across restarts and move between processes. This is a nontrivial capability, and it came up with a solid design that I was able to work off of. I had to steer the implementation and help it resolve a bunch of edge cases, but it was directionally good.
  4. Bonus UX. After I describe a feature to CC, it does its own small extrapolation of what would be useful from a UX perspective and automatically implements those features. Some examples of things I didn't specify but were great: the entire lobby feature, replay playback controls, game over modal, pagination of games/replays.
  5. Writing CSS + responsiveness. My own CSS skill is limited at best, and CC is great at covering for that. CSS as a domain is easily verifiable -- after I ask CC to make a change, it is easy for me to see if it worked and provide feedback. It is substantially more pleasant to write CSS through CC than it is to write it manually.
  6. Extensive unit testing. To double down on the importance of verifiability, unit tests are really important for CC as a way for it to understand whether the changes it makes are good. The best thing here is that CC itself can write all the tests, and it can write way more tests than a human typically would because it's essentially "free." The test coverage on this repo is, without a doubt, the highest of any code I've ever written.
  7. Easy debugging. Most of my debugging workflows look like: describe the observed issue to CC, ask it to investigate and fix, then have it write a test to catch the issue in the future. Probably 80-90% of bugs are fixed in this way without me having to do a deeper dive into the code to understand why the bug was happening at all.
It's easy to see how you get a boost to productivity with all of the above. There's also a thread running through most of these: CC is at its best when the problem is well-defined and the output is easy to verify. CSS looks right or it doesn't. Tests pass or they don't. CRUD endpoints either work or they throw errors. But when you stretch it beyond these domains and things get fuzzier, there are some clear limitations.
  1. Game engine and AI player. Using CC to build the game engine and AI player took about the same amount of time as it would have taken me to build them myself. In the case of the game engine, there were so many edge cases in how a real-time chess game behaves that CC didn't cover, so I had to discover them myself and then teach CC how to fix them. On the AI player front, this is an inherently open-ended problem of coming up with a set of cheap heuristics that "feel" strong to play against, so I had to keep going back and forth on the ideas that CC would implement. Both of these domains lack the verifiability aspect that allows CC to thrive.
  2. Complex debugging. This is the hardest aspect of software engineering, so it's not surprising that CC struggles here. There were three bugs in particular that CC was never able to solve even with many iterations: board resizing causing an infinite loop, the AI player incorrectly counting dangerous positions as safe, and stale games getting stuck in the registry. Interestingly, in each of these cases, gpt-5.3-codex (xhigh) made significantly more progress in identifying root causes.
  3. Campaign levels. Designing campaign levels has an aspect of creativity and taste to it so that the levels are interesting and fun to play. CC's hit rate on levels that felt good to me was about 10%, and it mostly just created variants of other levels I had already designed. This is actually the only time where I did write "code" myself, i.e. describing the campaign levels.
  4. Multi-system interactions. As the codebase grew over the weeks, I noticed that it became harder and harder for CC to keep track of exactly how different parts of the system should interact. For example, the games played as part of the campaign feature didn't integrate well with the replay feature. Increasingly, I felt that my role was to probe these multi-system interactions, much like how a senior engineer would inspect a junior engineer's work.
The mental model I came away with is that CC replaces most of the time-consuming, boilerplate parts of engineering, which lets me focus on the more open-ended and deep problems. That's a fantastic improvement to the workflow, and it's the first time that an AI coding tool has enabled me to build much better (not just faster) than before.

Saturday, February 7, 2026

Quantization-Aware Distillation

Following up on the last post, I want to write about this new paper from NVIDIA on quantization-aware distillation. We learned about quantization last time, so let's talk about distillation now. Model distillation is the process of transferring knowledge from one model (the teacher) to another, usually smaller, one (the student). Similar to quantization, the goal is to produce a smaller model that uses less memory and compute while still retaining intelligence.

Benchmark performance of DeepSeek-R1 distillations. Source: DeepSeek.

A fairly well-known example is the DeepSeek-R1 release from a year ago, which was the first major open-source reasoning model. As part of their release, they distilled DeepSeek-R1 into various smaller Llama and Qwen models to demonstrate that the reasoning capability could, in part, be transferred to other models. The idea is that you can take these small models, which lack strong reasoning capabilities, show them DeepSeek-R1's outputs, and train them to mimic those outputs. (DeepSeek did this with supervised fine-tuning on samples generated by R1; distillation can also train the student to match the teacher's full output probability distributions, which is the flavor QAD uses.) This substantially improved the small models' performance on tasks like math and coding which benefit from careful reasoning.

With respect to NVFP4, the problem to solve is how you get the set of NVFP4 weights that correspond to the strongest model. There are three approaches discussed in the paper.

  1. Post-training quantization (PTQ): Start from the full-precision weights and use calibration to choose scaling factors that map them to NVFP4 weights. This works well for large models, but performance degrades noticeably for smaller models.
  2. Quantization-aware training (QAT): Simulate quantization during the training process to allow the model to adjust for the bias introduced by quantization.
  3. Quantization-aware distillation (QAD): Distill knowledge directly from a high-precision, post-trained model (of the same size) to the quantized one.

In the context of the paper, the teacher model uses bfloat16 while the student model of course uses NVFP4. To perform QAD, they train the quantized model using the Kullback-Leibler (KL) divergence between the teacher and student output probability distributions as the loss function. This is in contrast to traditional pre-training and QAT, where the model tries to replicate the training data itself. While distillation typically uses larger teacher models, the authors observe that keeping the teacher the same size as the student works better for QAD, likely because it's easier for the student to recover the distribution of its own high-precision version than to learn a new one.
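To make the objective concrete, here is a minimal Python sketch of a KL-based distillation loss. It assumes forward KL, $KL(p_{teacher} \| q_{student})$, averaged over token positions; the direction, the averaging, and the lack of a temperature are my assumptions rather than details confirmed by the paper. The `teacher`/`student` logits are hypothetical; in QAD they would come from the bfloat16 teacher and the NVFP4 student on the same input.

```python
import math


def softmax(logits):
    # subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    # KL(p || q) = sum_x p(x) * log(p(x) / q(x)); terms where p(x) = 0 contribute nothing
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(teacher_logits, student_logits):
    """Mean per-position KL(teacher || student) over a sequence of next-token logits."""
    losses = [
        kl_divergence(softmax(t), softmax(s))
        for t, s in zip(teacher_logits, student_logits)
    ]
    return sum(losses) / len(losses)


# hypothetical logits for 2 positions over a 4-token vocabulary
teacher = [[2.0, 1.0, 0.1, -1.0], [0.5, 2.5, 0.0, 0.0]]
student = [[1.8, 1.1, 0.3, -0.9], [0.4, 2.0, 0.2, 0.1]]
print(distillation_loss(teacher, teacher))  # identical distributions: 0.0
print(distillation_loss(teacher, student))  # small positive value
```

Only the student's parameters would be updated against this loss; the teacher stays fixed and just supplies target distributions.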

KL divergence vs cross-entropy loss of QAT compared to QAD. Source: NVIDIA.

One particularly fascinating result is that, although the QAT and QAD models achieve similar cross-entropy loss on the dataset (fairly close to that of the bfloat16 model), the QAD model's KL divergence from the teacher is substantially lower on held-out samples. The takeaway is that, although QAT adjusts for quantization well during training, the resulting model behaves very differently from its high-precision reference.
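A toy, single-token illustration (hypothetical numbers, not the paper's data) of how two students can have identical cross-entropy on the data while differing in KL divergence from the teacher: cross-entropy only looks at the probability assigned to the observed token, whereas KL compares the whole distribution.

```python
import math


def cross_entropy(q, label):
    # loss against the data: only the probability of the observed token matters
    return -math.log(q[label])


def kl_divergence(p, q):
    # KL(p || q): compares the whole distribution, not just the observed token
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.6, 0.3, 0.1]       # hypothetical high-precision next-token distribution
student_a = [0.6, 0.3, 0.1]     # hypothetical student that matches the teacher
student_b = [0.6, 0.05, 0.35]   # hypothetical student with the same top-token probability
label = 0                       # hypothetical observed next token in the data

for name, q in [("student_a", student_a), ("student_b", student_b)]:
    print(f"{name}: CE = {cross_entropy(q, label):.4f}, KL = {kl_divergence(teacher, q):.4f}")
```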

On top of that, QAD is much simpler to perform on extensively post-trained models. Models these days undergo significant post-training via supervised fine-tuning (SFT) and/or reinforcement learning (RL), and it can be quite challenging to keep these processes stable under quantization. In fact, the paper finds that QAT actually degrades performance relative to PTQ, sometimes losing the capabilities gained during RL training. QAD bypasses the need to rerun those post-training stages under quantization, with the tradeoff that you need the high-precision model in which the heavy lifting of post-training has already been done. The paper shows even larger QAD-over-QAT gains on these heavily post-trained models.

Recovering performance via QAD with limited training data. Source: NVIDIA.

A final, interesting observation made in the paper is that QAD is robust to incomplete training data. That is, even when given only math training data or only code training data, the model recovers performance in both domains. The paper suggests that the teacher's output probability distributions carry information about all domains, even when conditioned on inputs from a limited set of domains. So as long as you present the model with some amount of high-quality training data, it can perform well across domains.

Distillation is a powerful LLM training tool, and its effectiveness suggests that mimicking intelligence is computationally simpler than deriving it, whether deriving means building it from scratch (as with DeepSeek-R1) or re-learning it while recovering performance after quantization. The fact that smaller (or quantized) models can successfully mimic their teachers also indicates that model strength is gated less by size or precision than by the process of synthesizing behaviors into the parameters. This is already happening to some extent, but my guess is that, long-term, we will have lots of small, quantized models running at the edge (e.g. on phones, computers, browsers) that are distilled from centralized, intelligent teachers.

Tuesday, February 3, 2026

LLM Quantization and NVFP4

With the rise of large language models and the desire to run them more cheaply and efficiently, quantization has gained a lot of traction. By representing the LLM's weights with data types that use fewer bits, you reduce both the GPU memory needed to load the model and the memory bandwidth consumed by operations like matrix multiplication.
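Some back-of-the-envelope arithmetic for the weight memory of a hypothetical 7-billion-parameter model at different bit widths. This counts weights only (no KV cache or activations); since decoding a token reads roughly every weight once, the bytes moved per token shrink by the same factor.

```python
def weight_memory_gb(num_params, bits_per_weight):
    # bytes = params * bits / 8, reported in GB (1e9 bytes)
    return num_params * bits_per_weight / 8 / 1e9


NUM_PARAMS = 7e9  # hypothetical 7B-parameter model

for fmt, bits in [("float32", 32), ("bfloat16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{fmt:>8}: {weight_memory_gb(NUM_PARAMS, bits):5.1f} GB")
```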

As a quick refresher, floating-point numbers consist of a single sign bit $s$, a number of exponent bits (E), and a number of mantissa bits (M). If $e$ is the exponent (the value stored in the exponent bits minus a fixed bias, $2^{E-1} - 1$ for IEEE formats) and $m$ is the value of the mantissa bits, then the (normal) floating-point number represented is

$f = (-1)^s \cdot 2^e \cdot (1 + \frac{m}{2^M})$

Intuitively, the exponent determines the rough scale of the number, and the mantissa determines the precise value within that scale.
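To check the formula against a real format, this sketch pulls the sign, exponent, and mantissa bits out of a float32 with Python's `struct` module and rebuilds the value. It only handles normal numbers; zero, subnormals, infinities, and NaN use special exponent encodings that the formula doesn't cover.

```python
import struct

FLOAT32_E, FLOAT32_M = 8, 23
FLOAT32_BIAS = 2 ** (FLOAT32_E - 1) - 1  # 127


def decode_float32(x):
    """Return (s, e, m, value) for a normal float32, with value = (-1)^s * 2^e * (1 + m / 2^M)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    s = bits >> 31
    stored_exp = (bits >> FLOAT32_M) & ((1 << FLOAT32_E) - 1)
    m = bits & ((1 << FLOAT32_M) - 1)
    if stored_exp in (0, (1 << FLOAT32_E) - 1):
        raise ValueError("zero, subnormals, inf, and NaN need special handling")
    e = stored_exp - FLOAT32_BIAS  # remove the exponent bias
    value = (-1) ** s * 2.0 ** e * (1 + m / 2 ** FLOAT32_M)
    return s, e, m, value


print(decode_float32(-6.5))  # (1, 2, 5242880, -6.5), since -6.5 = -(2^2) * 1.625
```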

float32 representation. Source: Wikipedia.

The full-precision reference format commonly used for LLMs is float32, which is your standard float data type in C and has E = 8, M = 23 for 32 bits total. For quantizing down to 16 bits, it's become popular to use bfloat16, a newer format developed specifically for machine learning. The bfloat16 format uses E = 8, M = 7, versus traditional float16, which uses E = 5, M = 10. This is because capturing a wider range of scale (e.g. for large gradients or activations) matters more in ML than high precision.
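One way to see the range-versus-precision tradeoff is to compute, from E and M alone, each format's largest finite value, smallest normal value, and epsilon (the gap between 1.0 and the next representable number). This sketch assumes IEEE-style conventions (bias $2^{E-1} - 1$, all-ones exponent reserved for infinity/NaN), which float32, float16, and bfloat16 all follow.

```python
def format_stats(E, M):
    """Largest finite value, smallest normal value, and epsilon of an IEEE-style format."""
    bias = 2 ** (E - 1) - 1
    largest = (2 - 2.0 ** -M) * 2.0 ** bias  # all-ones exponent is reserved for inf/NaN
    smallest_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -M                      # gap between 1.0 and the next value up
    return largest, smallest_normal, epsilon


for name, E, M in [("float32", 8, 23), ("bfloat16", 8, 7), ("float16", 5, 10)]:
    largest, smallest, eps = format_stats(E, M)
    print(f"{name:>8}: max {largest:.3e}, min normal {smallest:.3e}, epsilon {eps:.1e}")
```

bfloat16 keeps essentially all of float32's range (about $3.4 \times 10^{38}$) at the cost of a much coarser epsilon, while float16 is more precise but tops out at 65504.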

bfloat16 representation. Source: Wikipedia.

Quantization can go further: 8-bit, 4-bit, and even 2-bit formats are common now. By the time you get down to 4 bits, it's no longer obvious that anything will work -- after all, 4 bits can only represent 16 values! To understand how this could work, let's look in particular at NVFP4, which is NVIDIA's own data type that is targeted specifically at maintaining model accuracy.

It's a bit of a misclassification to consider NVFP4 a data type at all, as it's not a standalone representation like traditional floating-point numbers. Rather, it is a format for an entire tensor, the building block of neural networks. It consists of standard 4-bit floating-point values (E = 2, M = 1) used in conjunction with per-block scaling factors and a per-tensor scaling factor. You can think of it as a generalization of the floating-point concept across multiple values instead of within a single one.
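To see just how coarse the 4-bit elements are, this sketch decodes all 16 bit patterns of the E2M1 layout, assuming the usual FP4 E2M1 conventions (exponent bias 1, subnormals when the exponent field is 0, no infinity or NaN encodings). The positive half of the grid is only 0, 0.5, 1, 1.5, 2, 3, 4, and 6.

```python
E2M1_BIAS = 1  # exponent bias for the 2-bit exponent field


def decode_e2m1(code):
    """Decode a 4-bit E2M1 code: 1 sign bit, 2 exponent bits, 1 mantissa bit, no inf/NaN."""
    s = (code >> 3) & 1
    exp = (code >> 1) & 0b11
    m = code & 1
    if exp == 0:
        # subnormal: no implicit leading 1, exponent fixed at 1 - bias
        magnitude = (m / 2) * 2.0 ** (1 - E2M1_BIAS)
    else:
        magnitude = (1 + m / 2) * 2.0 ** (exp - E2M1_BIAS)
    return -magnitude if s else magnitude


# 16 bit patterns, but +0 and -0 compare equal, so 15 distinct values remain
values = sorted({decode_e2m1(code) for code in range(16)})
print(values)
```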

NVFP4 representation. Source: NVIDIA.
 
There is a single 32-bit tensor scaling factor that determines the "global scale" of the tensor, and then an 8-bit scaling factor for every 16 values in the tensor. These scaling factors mitigate the 4-bit format's limited range, and NVIDIA chose the E and M values for the block scaling factors (E4M3, an FP8 format) to most accurately reconstruct true values in practice, as measured by mean squared error. The scaling factors add an overhead of about 8 bits per 16 values, or 0.5 bits per 4-bit value (12.5% overhead), which is the tradeoff against a standard float4. It might seem complex to manipulate tensors in this format, but NVIDIA has built hardware support for NVFP4 into their Blackwell architecture, so their GPUs natively understand how to do it, abstracting the complexity away from developers.
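The sketch below simulates NVFP4-style quantize-then-dequantize on a flat list of weights: a per-tensor float32 scale, an E4M3 scale per 16-value block, and E2M1 elements. The scale recipe (tensor scale = amax / (6 * 448), so the largest block scale lands at the top of E4M3's range) is my assumption of a reasonable choice, and rounding is simplified (nearest grid point, ties toward the smaller magnitude for FP4). Real kernels and NVIDIA's tooling may differ in these details.

```python
import math
import random

FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # positive E2M1 magnitudes
FP4_MAX = 6.0     # largest E2M1 magnitude
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
BLOCK = 16        # values sharing one block scale


def round_to_fp4(x):
    # nearest E2M1 magnitude with the sign kept; ties go to the smaller magnitude
    # (a simplification), and anything beyond 6 saturates to 6
    mag = min(FP4_GRID, key=lambda g: abs(abs(x) - g))
    return math.copysign(mag, x)


def round_to_e4m3(x):
    # round a non-negative scale to E4M3 (3 mantissa bits, clamped to 448; no NaN handling)
    if x <= 0.0:
        return 0.0
    x = min(x, E4M3_MAX)
    _, e = math.frexp(x)      # x = f * 2**e with 0.5 <= f < 1
    exp = max(e - 1, -6)      # -6 is the smallest normal exponent; subnormals share its spacing
    step = 2.0 ** (exp - 3)   # spacing between neighbors with 3 mantissa bits
    return min(round(x / step) * step, E4M3_MAX)


def quantize_nvfp4(values):
    """Simulate NVFP4-style quantize-then-dequantize and return the reconstructed values."""
    amax = max(abs(v) for v in values)
    # per-tensor scale chosen so the largest block scale lands at E4M3's maximum
    tensor_scale = amax / (FP4_MAX * E4M3_MAX) if amax > 0 else 1.0
    out = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        block_amax = max(abs(v) for v in block)
        # per-block scale, stored in E4M3 relative to the tensor scale
        scale = round_to_e4m3(block_amax / FP4_MAX / tensor_scale) * tensor_scale
        if scale == 0.0:
            out.extend(0.0 for _ in block)
        else:
            out.extend(round_to_fp4(v / scale) * scale for v in block)
    return out


random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(64)]  # hypothetical weight tensor
recon = quantize_nvfp4(weights)
mse = sum((w - r) ** 2 for w, r in zip(weights, recon)) / len(weights)
print(f"mean squared error: {mse:.3e}")
```

Because each block of 16 gets its own scale, a single outlier only coarsens the grid for its own block rather than for the whole tensor.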

One thing I've glossed over is how the quantization actually happens, which isn't trivial. Another reason that bfloat16 is preferred over float16 is that quantizing from float32 to bfloat16 is easy because they have the same range (number of exponent bits), so you just round off the low 16 mantissa bits. But if you quantize to 8 or 4 bits, you lose that property, so you have to apply scaling as described in this blog post. The post covers techniques for post-training quantization that choose quantized weights to maintain model accuracy, but I won't be covering them here. Instead, this post is a setup for my next post on quantization-aware distillation, an alternative approach from a paper that NVIDIA recently published -- stay tuned!
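Because float32 and bfloat16 share the same 8 exponent bits, converting between them amounts to rounding away the low 16 mantissa bits. Here is a minimal sketch using round-to-nearest-even (NaN inputs are not handled), with Python's built-in half-precision `struct` format `"e"` for contrast: float16 can't even represent 100,000 without scaling.

```python
import struct


def to_bfloat16(x):
    """Round a value to bfloat16 (round-to-nearest-even) and return it as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 is the top 16 bits of float32: sign, 8 exponent bits, 7 mantissa bits
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round half to even; NaN is not handled
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_float16(x):
    # struct's "e" format is IEEE half precision; it raises OverflowError out of range
    return struct.unpack("<e", struct.pack("<e", x))[0]


for x in [0.1, 1e5]:
    try:
        f16 = to_float16(x)
    except OverflowError:
        f16 = "overflow"
    print(f"{x}: bfloat16={to_bfloat16(x)}, float16={f16}")
```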