Following up on the last post, I want to write about this new paper from NVIDIA on quantization-aware distillation. We learned about quantization last time, so let's talk about distillation now. Model distillation is the process of transferring knowledge from one model (the teacher) to another, usually smaller, one (the student). Similar to quantization, the goal is to produce a smaller model that uses less memory and compute while still retaining intelligence.
| Benchmark performance of DeepSeek-R1 distillations. Source: DeepSeek. |
A fairly well-known example is the DeepSeek-R1 release from a year ago, which was the first major open-source reasoning model. As part of their release, they distilled DeepSeek-R1 into various smaller Llama and Qwen models to demonstrate that the reasoning capability could, at least in part, be transferred to other models. The idea is that you can take these small models, which lack strong reasoning capabilities, show them DeepSeek-R1's output (more specifically, the output probability distributions), and ask them to mimic it. This substantially improved the small models' performance on tasks like math and coding that benefit from careful reasoning.
With respect to NVFP4, the problem to solve is how to obtain the set of NVFP4 weights that corresponds to the strongest model. There are three approaches discussed in the paper.
- Post-training quantization (PTQ): Start from the full-precision weights and use calibration to find the scale factors that map them to NVFP4 weights. This works well for large models, but observed performance on smaller models is poor.
- Quantization-aware training (QAT): Simulate quantization during the training process to allow the model to adjust for the bias introduced by quantization (a rough sketch of this simulated quantization follows this list).
- Quantization-aware distillation (QAD): Distill knowledge directly from a high-precision, post-trained model (of the same size) to the quantized one.
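To make the QAT bullet concrete, here's a minimal PyTorch sketch of what "simulating quantization" could look like for an NVFP4-style format: weights are grouped into small blocks, each block gets a scale, and values are snapped to the nearest representable FP4 (E2M1) level before being dequantized back to full precision for the forward pass. The block size, scale computation, and straight-through estimator here are my own assumptions for illustration, not the paper's exact recipe.

```python
import torch

# FP4 (E2M1) representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_GRID = torch.cat([-FP4_LEVELS.flip(0), FP4_LEVELS])  # duplicate 0 is harmless here

def fake_quantize_nvfp4(w: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Quantize then immediately dequantize, so the forward pass sees the
    quantization error while the weights themselves stay in float."""
    orig_shape = w.shape
    blocks = w.reshape(-1, block_size)  # assumes numel is divisible by block_size
    # Per-block scale so the largest magnitude lands on the largest FP4 level (6).
    scale = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 6.0
    scaled = blocks / scale
    # Snap each scaled value to the nearest representable FP4 level.
    grid = FP4_GRID.to(w.device, w.dtype)
    idx = (scaled.unsqueeze(-1) - grid).abs().argmin(dim=-1)
    q = grid[idx] * scale  # dequantize back to float
    # Straight-through estimator: forward uses the quantized values,
    # gradients flow as if no quantization had happened.
    out = blocks + (q - blocks).detach()
    return out.reshape(orig_shape)
```

In actual QAT, something like this would be applied to the weights (and possibly activations) inside each layer's forward pass, so the optimizer adapts the full-precision weights to the error the quantizer introduces.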
In the context of the paper, the teacher model uses bfloat16 while the student model of course uses NVFP4. To perform QAD, they train the quantized model using the Kullback-Leibler divergence between the teacher and student probability distributions as the loss function. This is in contrast to traditional pre-training and QAT where the model tries to replicate the training data itself. While distillation typically uses larger teacher models, the authors observe that keeping the teacher model the same size as the student works better for QAD, likely because it's easier for the student to recover its own distribution rather than learning a new one.
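For intuition, here's a minimal sketch of what that distillation objective could look like in PyTorch, assuming we have teacher and student logits for the same batch of tokens. The reduction and the exact plumbing are my guesses, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def qad_distillation_loss(student_logits: torch.Tensor,
                          teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(teacher || student), averaged per token.

    Both logit tensors have shape (batch, seq_len, vocab_size);
    the teacher is frozen, so its logits carry no gradient.
    """
    vocab = student_logits.size(-1)
    student_logprobs = F.log_softmax(student_logits, dim=-1).reshape(-1, vocab)
    teacher_logprobs = F.log_softmax(teacher_logits.detach(), dim=-1).reshape(-1, vocab)
    # kl_div takes log-probabilities for the input and, with log_target=True,
    # for the target as well; "batchmean" gives the mean KL per token here.
    return F.kl_div(student_logprobs, teacher_logprobs,
                    log_target=True, reduction="batchmean")
```

Compared to the cross-entropy on next-token labels used in pre-training and QAT, the target here is the teacher's full distribution rather than the one-hot token from the data.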
| KL divergence vs cross-entropy loss. Source: NVIDIA. |
One particularly fascinating result is that, although both QAT and QAD models achieve similar cross-entropy loss on the dataset (fairly close to that of the bfloat16 model), the KL divergence of the QAD model is substantially better on held-out samples. The takeaway is that, although QAT adjusts for quantization well during training, the resulting model behaves very differently from its high-precision reference.
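A toy, made-up example of why this can happen: two student distributions can assign identical probability to the observed next token (hence identical cross-entropy on the data) while spreading the remaining mass very differently, which is exactly what the KL divergence to the teacher picks up.

```python
import torch
import torch.nn.functional as F

# Hypothetical next-token distributions over a 3-word vocabulary;
# the observed token in the training data is index 0.
teacher   = torch.tensor([0.60, 0.30, 0.10])
student_a = torch.tensor([0.60, 0.30, 0.10])   # matches the teacher everywhere
student_b = torch.tensor([0.60, 0.02, 0.38])   # same mass on the label, different elsewhere

label = 0
for name, p in [("a", student_a), ("b", student_b)]:
    ce = -torch.log(p[label]).item()  # cross-entropy on the observed token
    kl = F.kl_div(p.log(), teacher.log(), log_target=True, reduction="sum").item()
    print(f"student_{name}: CE={ce:.3f}  KL(teacher||student)={kl:.3f}")
# Both students get CE ~0.51, but student_b's KL to the teacher is far larger (~0.68 vs 0).
```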
On top of that, QAD is much simpler to perform on extensively post-trained models. Models these days undergo significant post-training via supervised fine-tuning (SFT) and/or reinforcement learning (RL), and it can be quite challenging to keep these processes stable under quantization. In fact, the paper finds that QAT actually degrades performance relative to PTQ, sometimes losing the capabilities gained during RL training. QAD sidesteps the need to redo post-training under quantization, with the tradeoff of needing a high-precision model in which the heavy lifting of post-training has already been done. The paper shows even larger wins for QAD over QAT on these heavily post-trained models.
| Recovering performance via QAD with limited training data. Source: NVIDIA. |
A final, interesting observation made in the paper is that QAD as a process is robust to incomplete training data. That is, even when presented with only math training data or only code training data, the model recovers performance on both domains. The paper suggests that the teacher's output probability distributions carry information about all domains, even when the input tokens come from a narrow one. So as long as you present the model with some amount of high-quality training data, it can perform well generically.
Distillation as an LLM training mechanism is a powerful tool, which intuitively suggests that mimicking intelligence is computationally simpler than deriving it, whether from scratch (as with DeepSeek-R1) or as part of recovering performance in a quantized scenario. The fact that smaller (or quantized) models can mimic successfully also suggests that model strength is gated less by size or precision than by the process of synthesizing behaviors into the parameters. This is already happening to some extent, but my guess is that, long-term, we will have lots of small, quantized models running at the edge (e.g. on phones, computers, browsers) that are distilled from centralized, intelligent teachers.