Distilled Decoding 2: One-step Sampling of Image Auto-regressive Models with Conditional Score Distillation¶
会议: NeurIPS 2025
arXiv: 2510.21003
代码: GitHub
领域: 图像生成 / 模型加速
关键词: 自回归模型加速, 一步生成, 分数蒸馏, 条件分数蒸馏, 图像生成
一句话总结¶
本文提出Distilled Decoding 2(DD2),通过将自回归图像模型重新解读为条件分数模型,设计了条件分数蒸馏(CSD)损失,将多步AR采样压缩为一步生成,在ImageNet-256上实现FID从3.40到5.43的微小退化同时获得8.0x加速(VAR)和238x加速(LlamaGen),相比DD1缩小了67%的性能差距且训练快12.3倍。
研究背景与动机¶
领域现状:图像自回归(AR)模型(如VAR、LlamaGen)已经达到了最先进的图像生成质量,甚至超越了VAE、GAN和扩散模型。这些模型通过逐个预测离散token来生成图像,保持了良好的likelihood建模特性。
现有痛点:AR模型的核心瓶颈在于逐步采样——VAR需要10步,LlamaGen需要256步。每步都需要完整的模型前向传播,导致生成速度远慢于GAN等一步生成方法。集合预测方法(如MaskGIT式的一次预测多个token)在极端情况(一步生成)下会完全丢失token间的相关性。推测解码(speculative decoding)在图像AR上仅能实现不到3倍的有限压缩。DD1是第一个实现一步采样的方法,但依赖预定义的noise→data映射,这个映射难以学习,导致FID退化明显(VAR-d20从3.40到9.55),且训练较慢。
核心矛盾:DD1将AR模型的每步采样转化为flow matching的ODE求解来构建确定性映射,本质上只用了分数信号来建映射。但这种预定义映射(1)对模型学习困难,(2)限制了灵活性——GAN、VAE等不依赖显式映射的模型反而在下游任务中适用性更广。
本文目标 能否训练一个一步生成器,让其输出分布匹配给定的AR模型,而不依赖任何预定义映射?
切入角度:将AR模型重新解读为条件分数模型——给定前面所有token后,AR模型输出的next-token概率向量可以解析地计算codebook embedding空间中的条件score。然后借鉴扩散模型score distillation的思想,匹配生成器和teacher AR模型在每个token位置上的条件分数。
核心 idea:把AR模型的next-token概率向量当作条件score的来源,通过条件分数蒸馏(CSD)在所有token位置上同时对齐生成器与teacher的条件分布,实现无需预定义映射的一步AR图像生成。
方法详解¶
整体框架¶
整个训练分为两个阶段:(1)初始化阶段:将teacher AR模型的分类头替换为MLP头,用GTS loss微调成AR-diffusion模型;(2)CSD训练阶段:用该AR-diffusion模型初始化generator和guidance network,交替训练generator(用CSD loss)和guidance network(用FCS loss)。推理时,generator一次前向传播输出整个token序列。
关键设计¶
-
Teacher AR模型作为条件分数模型:
- 功能:将离散概率向量转化为连续的score信号
- 核心思路:将token \(q_i\)的采样视为从"加权Dirac函数之和→Gaussian"的flow matching过程。给定teacher输出的概率向量\(p = (p_1,...,p_V)\)和flow matching时间步\(t\),条件score可以解析计算为:\(s(x_t, t, p) = -\frac{\sum_j p_j(x_t - (1-t)c_j)e^{-\frac{(x_t-(1-t)c_j)^2}{2t^2}}}{t^2 \sum_j p_j e^{-\frac{(x_t-(1-t)c_j)^2}{2t^2}}}\)。与DD1不同,DD2不用这个score去构建ODE映射,而是直接用于蒸馏
- 设计动机:AR模型在每个token位置已经隐式包含了完整的条件分数信息,DD1只用了这个信息的一部分(构建映射),DD2要更充分地利用
-
条件分数蒸馏损失(CSD Loss):
- 功能:训练一步generator,使其输出序列的联合分布匹配teacher AR模型
- 核心思路:在每个token位置\(i\)上,对齐teacher的条件score \(s_\Phi(q_i^{t_i}, t_i | q_{<i})\)和guidance network学到的fake条件score \(s_{\text{fake}}(q_i^{t_i}, t_i | q_{<i})\)。CSD loss是所有位置的score distillation loss之和:\(\mathcal{L}_{CSD} = \mathbb{E}\sum_{i=1}^n d(s_\Phi(q_i^{t_i}, t_i | sg(q_{<i})), s_{\text{fake}}(q_i^{t_i}, t_i | sg(q_{<i})))\),其中\(sg(\cdot)\)表示stop gradient。采用SiD loss形式。关键理论保证(Proposition 1):最小化CSD loss → generator的联合分布匹配teacher
- 设计动机:渐进对齐逻辑——先对齐\(q_1\)的分布(无条件),再在此基础上对齐\(q_2|q_1\),依此类推。这比对齐整个联合分布更容易优化
-
AR-diffusion初始化策略:
- 功能:为generator和guidance network提供良好的初始化
- 核心思路:直接用teacher AR模型的权重初始化不可行,因为teacher输出离散概率而generator需要输出连续值。解决方案:将teacher的分类头替换为MLP,用Ground Truth Score (GTS) loss微调:\(\mathcal{L}_{GTS} = \mathbb{E}\sum_{i=1}^n \|s_\psi(q_i^{t_i}, t_i | q_{<i}) - s_\Phi(q_i^{t_i}, t_i | q_{<i})\|^2\),即直接回归teacher的解析score。微调后的模型同时作为generator和guidance network的初始化
- 设计动机:实验证明没有良好初始化score distillation会完全collapse。即使只随机初始化generator的最后一层,性能也会严重退化。GTS loss比标准AR-diffusion loss收敛更快更稳定
损失函数 / 训练策略¶
- Generator训练:CSD loss(Eq. 4),采用SiD loss形式
- Guidance network训练:FCS loss(Eq. 5),标准的score matching loss
- 两个网络交替训练
- GTS初始化阶段做全量微调,CSD阶段的训练远短于DD1
实验关键数据¶
主实验¶
ImageNet-256上的一步生成质量对比:
| 方法 | 模型 | FID↓ | IS↑ | 步数 | 加速比 |
|---|---|---|---|---|---|
| VAR-d20原始 | 600M | 3.40 | 305.1 | 10 | 1x |
| DD1 | VAR-d20 | 9.55 | 197.2 | 1 | - |
| DD2 | VAR-d20 | 5.43 | 233.7 | 1 | 8.0x |
| LlamaGen-L原始 | 343M | 4.11 | 283.5 | 256 | 1x |
| DD1 | LlamaGen-L | 11.35 | 193.6 | 1 | - |
| DD2 | LlamaGen-L | 8.59 | 229.1 | 1 | 238x |
DD2在所有config下1步FID优于DD1的2步结果。
消融实验¶
| 配置 | FID (LlamaGen) | FID (VAR-d24) | 说明 |
|---|---|---|---|
| Gui-Init ✓ Gen-Init ✓ | 14.77 | 11.53 | 完整初始化(最佳) |
| Gui-Init ✓ Gen-Init ✗ | 16.08 | >200 (collapse) | 缺Generator初始化 |
| Gui-Init ✗ Gen-Init ✓ | 21.76 | >200 (collapse) | 缺Guidance初始化 |
训练效率对比:
| 方法 | 模型 | GPU小时(8xA800) | 训练加速比 |
|---|---|---|---|
| DD1 | LlamaGen-L | 647.7 | 1x |
| DD2 | LlamaGen-L | 52.6 | 12.3x |
| DD1 | VAR-d24 | 604.2 | 1x |
| DD2 | VAR-d24 | 96.1 | 6.3x |
关键发现¶
- DD2大幅超越DD1:VAR上缩小了teacher和one-step之间67%的性能差距
- 初始化至关重要:缺少任一组件初始化会导致collapse(VAR-d24上FID>200),即使只随机初始化最后一层也会严重退化
- 训练效率极高:最多12.3倍训练加速,因为不需要像DD1那样预先构建所有noise-data映射
- PPL更低:DD2的Perceptual Path Length(7231.9 vs 18437.6)远低于DD1,说明DD2学到了更光滑的latent space
- 多步采样锦上添花:DD2 3步(4.88 FID)就优于DD1 6步(5.03 FID),提供灵活的质量-速度trade-off
亮点与洞察¶
- AR → score model的重新解读极其巧妙:原本AR模型输出的是next-token的离散概率,本文把它视为连续embedding空间中的score信号。这个视角的转换使得扩散蒸馏的整套理论和工具可以直接迁移到AR加速中。这种跨paradigm的知识迁移值得学习
- "消除预定义映射"的核心洞察:DD1的mapping是hard target(必须精确学到每个noise→data的对应),DD2的CSD是soft target(只需分布匹配),后者显然更容易优化。打个比方:DD1是学填空的唯一答案,DD2是学整个答案的分布
- 初始化策略的通用价值:GTS loss(用teacher的解析score代替Monte Carlo估计)收敛更快更稳定,这个技巧可推广到其他score distillation设置
局限与展望¶
- 仍有性能差距:VAR-d20从3.40到5.43仍有2.03的FID退化,对质量要求极高的场景可能不够
- 仅验证了VQ-based AR模型:对MAR等continuous-space AR模型的兼容性仅做了理论讨论,未实验验证
- 仅在ImageNet-256上评测:未扩展到文本→图像等更实用的场景
- 训练仍需8×A800 GPU:虽然比DD1快12x,绝对成本仍然不低
- 初始化的强依赖是双刃剑:高度依赖GTS初始化,如果teacher模型quality差,可能影响DD2效果
相关工作与启发¶
- vs DD1:DD1用flow matching构建确定性映射然后拟合,DD2用score distillation做分布匹配。DD2在性能(-67%性能差距)、训练速度(12.3x)、latent smoothness(PPL减半)上全面超越
- vs 扩散Score Distillation(DMD/SiD):虽然都做score matching,但AR模型和diffusion model的生成过程完全不同——AR是顺序条件生成,diffusion是iterative denoising。DD2的创新在于将score distillation从独立变量扩展到具有AR依赖的序列变量
- vs Set-of-token预测(MaskGIT/VAR):一步预测所有token会丢失token间的相关性(如{(0,0),(1,1)}的例子),DD2通过CSD在所有位置做条件score对齐来保持相关性
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 将score distillation引入AR模型加速,AR→score model的重新解读非常有创见
- 实验充分度: ⭐⭐⭐⭐ 两种AR模型(VAR+LlamaGen)、多种size、完整的消融和对比
- 写作质量: ⭐⭐⭐⭐ 动机和理论推导清晰,但第3节公式密度较高
- 价值: ⭐⭐⭐⭐⭐ 对AR图像生成加速具有里程碑意义,可能改变AR vs Diffusion的速度对比格局