Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency¶
会议: ICLR 2026
arXiv: 2510.08431
代码: 项目页
领域: 图像生成
关键词: 连续时间一致性模型, Score蒸馏, 大规模蒸馏, JVP, 少步生成
一句话总结¶
提出 rCM(score-regularized continuous-time consistency model),首次将连续时间一致性蒸馏扩展到 14B 参数的文生图/视频模型,通过结合前向散度(一致性)和反向散度(score蒸馏),在保持多样性的同时匹配 DMD2 的质量,实现 15-50× 加速。
研究背景与动机¶
- sCM(连续时间一致性模型)理论优雅,但在大规模文生图/视频模型上的适用性不明——JVP 计算与 FlashAttention、并行训练不兼容
- sCM 在细节生成上存在质量问题(误差累积 + 前向散度的 mode-covering 特性导致质量扩散)
- Score/对抗蒸馏方法(如 DMD2)在质量上领先,但存在模态坍塌和多样性不足
- 前向散度(一致性模型)与反向散度(score蒸馏)具有互补性
方法详解¶
整体框架¶
rCM 要解决的问题是:让理论优雅的连续时间一致性蒸馏(sCM)真正在 14B 级文生图/视频模型上跑起来,并补上它在精细画质上的短板。整体怎么转——以冻结的教师模型为监督源,学生网络同时吃两路互补的监督信号:一路是 sCM 的前向 KL(mode-covering,负责覆盖全部模态、保住多样性),另一路是 DMD 的反向 KL(mode-seeking,负责把分布往高密度区收、锐化细节保质量),二者加权成统一的 rCM 目标。这两路都各有工程前提:sCM 路依赖一个能在大模型上算梯度的 JVP 内核加一套数值稳定化,DMD 路依赖把学生当前生成的样本喂给一个 fake score 网络来估反向 score。训练时两套网络交替更新——学生用组合后的 rCM 目标更新,fake score 网络则用 flow matching loss 在学生最新生成的数据上跟训。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
T["教师模型<br/>(冻结·蒸馏 CFG)"] --> S["学生 F_θ<br/>少步生成器"]
S -->|"沿 PF-ODE 求切向"| JVP["FlashAttention-2<br/>JVP 内核 (兼容 FSDP/CP)"]
JVP --> ST["稳定时间导数<br/>半连续 / 时间嵌入 FP32"]
ST --> LSCM["sCM 损失·前向 KL<br/>mode-covering 保多样性"]
S -->|"Rollout 多步采样"| X0["学生样本 x_0 ~ p_θ"]
X0 --> FAKE["fake score 网络<br/>flow matching 跟训"]
FAKE --> LDMD["DMD 损失·反向 KL<br/>mode-seeking 提质量"]
LSCM --> RCM["Score 正则化<br/>L_rCM = L_sCM + λ·L_DMD"]
LDMD --> RCM
RCM -->|"反传更新学生 θ"| S
关键设计¶
1. FlashAttention-2 的 JVP 内核:让连续时间一致性能在大模型上算梯度
sCM 的训练依赖雅可比向量积(JVP)来估计 teacher 沿概率流 ODE 的切向,但现成的 FlashAttention 只暴露前向输出、不返回 JVP,导致 sCM 在 10B 级模型上根本无法接入主流注意力实现与并行训练。作者用 Triton 写了一个把 JVP 计算直接嵌进 FlashAttention-2 前向传播的内核,自注意力和交叉注意力都覆盖,并且与 FSDP、Context Parallelism 兼容,从而把 JVP-based 的 sCM 训练首次撑到了 14B 参数规模——这是整套方法能落地的工程前提。
2. Score 正则化:用反向 KL 的 DMD 项给前向 KL 的 sCM 补质量
纯 sCM 的前向散度天然 mode-covering,叠加少步生成的误差累积后,在文字渲染等精细场景会出现明显质量缺陷。作者把 DMD loss 当作一个「长跳」正则器加到 sCM 上,得到组合目标 \(\mathcal{L}_{\text{rCM}}(\theta) = \mathcal{L}_{\text{sCM}}(\theta) + \lambda \mathcal{L}_{\text{DMD}}(\theta)\)。其中 sCM 提供 mode-covering 的多样性,DMD 的反向 KL 则 mode-seeking 地把分布往高密度区收以提升质量,两者方向互补。权重 \(\lambda=0.01\) 在所有模型和任务上通用,无需逐任务调参,也省去了 GAN 式蒸馏的对抗调优。
3. 稳定的时间导数计算:压住大模型 JVP 的数值不稳定
大模型上直接算 JVP 的时间分量容易数值发散,作者给出两条可叠加的稳定化方案。一是半连续时间:空间部分仍走精确 JVP,时间方向改用步长 \(\Delta t = 10^{-4}\) 的有限差分近似,避开最不稳定的那一项;二是高精度时间:对时间嵌入层强制 FP32 精度,防止低精度下时间导数被舍入误差吞掉。两招让 14B 规模下的 sCM 训练得以收敛。
4. Rollout 多步采样:让学生既能一步也能多步,并稳定反传 DMD
学生被训练成支持任意步采样,训练时随机抽步数 \(N \in [1, N_{\max}]\) 做 rollout,只对最后一步反传 DMD loss,并用随机时间步保证整个 \([0,1]\) 时间范围都被覆盖到。这样单步与多步推理共享同一组参数,推理时可按质量/速度权衡自由选 NFE,而只回传末步梯度也避免了多步展开带来的显存和不稳定。
损失函数 / 训练策略¶
sCM 项采用切线归一化形式 \(\mathcal{L}_{\text{sCM}} = \mathbb{E}\left[\left\|\mathbf{F}_\theta - \mathbf{F}_{\theta^-} - \frac{\mathbf{g}}{\|\mathbf{g}\|_2^2 + c}\right\|_2^2\right]\),其中 \(\mathbf{F}_{\theta^-}\) 是 EMA 目标网络、\(\mathbf{g}\) 为切向、\(c\) 为数值稳定常数。DMD 项依据 fake score 网络与 teacher score 之间的差异把学生分布往真实分布拉,而 fake score 网络本身用 flow matching loss 在学生当前生成的数据上交替训练,二者互为依赖、轮流更新。
实验关键数据¶
主实验(GenEval T2I)¶
| 模型 | 参数 | NFE | Overall | Counting | Position |
|---|---|---|---|---|---|
| FLUX.1-dev | 12B | 50 | 0.66 | 0.74 | 0.22 |
| Cosmos-Predict2 14B (teacher) | 14B | 70 | 0.84 | 0.79 | 0.64 |
| Cosmos-Predict2 + DMD2 | 2B | 4 | 0.80 | 0.70 | 0.57 |
| Cosmos-Predict2 + rCM | 2B | 4 | 0.81 | 0.73 | 0.58 |
| Cosmos-Predict2 + rCM | 14B | 4 | 0.83 | 0.80 | 0.59 |
| Cosmos-Predict2 + rCM | 14B | 1 | 0.82 | 0.84 | 0.49 |
VBench 视频实验¶
| 模型 | 参数 | NFE | Total Score | Throughput(FPS) |
|---|---|---|---|---|
| Wan2.1 14B (teacher) | 14B | 100 | 83.58 | 0.18 |
| Wan2.1 + DMD2 | 1.3B | 4 | 84.56 | 14.6 |
| Wan2.1 + rCM | 1.3B | 4 | 84.43 | 14.6 |
| Wan2.1 + rCM | 14B | 2 | 85.05 | 8.3 |
关键发现¶
- rCM 在质量上匹配或超过 DMD2,同时在多样性上明显优于 DMD2(Figure 1 显示 DMD2 生成物体位置/姿态趋同)
- 14B rCM 4步 GenEval 0.83,接近 teacher 70步的 0.84
- 视频任务中 rCM 2步即可达到接近 teacher 的 VBench 分数
- \(\lambda=0.01\) 在质量和多样性之间取得最佳平衡
- 纯 sCM 在文字渲染等精细场景存在明显质量缺陷,rCM 成功修复
亮点与洞察¶
- 首次将 JVP-based 连续时间一致性扩展到 14B 参数和 5 秒视频
- 从前向/反向散度互补性的角度理解蒸馏方法的统一框架
- 无需 GAN 调优或大量超参搜索,\(\lambda=0.01\) 跨任务通用
- rCM 的多样性优势对交互式 world model 等需要多样响应的场景尤为重要
局限与展望¶
- 需要额外的 fake score 网络(内存开销)
- JVP 计算仍比标准前向传播慢,训练成本高
- 1步视频生成质量仍有明显下降(VBench 从 85.05 降至 83.02)
- 对 autoregressive video diffusion 的扩展仅有展望
相关工作与启发¶
- sCM 和 MeanFlow 提供了理论基础
- DMD/DMD2 提供了反向散度蒸馏的实践方案
- DDO 和 DDRL 的前向+反向散度联合思想是 rCM 的哲学基础
- 为大规模视觉生成模型的部署提供了实用加速方案
技术细节补充¶
- TrigFlow 噪声调度:\(\alpha_t = \cos(t), \sigma_t = \sin(t)\),与 rectified flow 通过 SNR 匹配互转
- Fake score 网络用 flow matching loss 在学生生成数据上训练,交替优化
- Selective Activation Checkpointing (SAC) 用于减少内存消耗
- Teacher 使用 CFG,CFG 同时蒸馏到学生中
- 全参数微调(不用 LoRA),强调 rCM 的稳定性
- 实验涵盖 Cosmos-Predict2(0.6B/2B/14B T2I)和 Wan2.1(1.3B/14B T2V)
- Wan2.1 14B 2步加速达 8.3 FPS vs teacher 的 0.18 FPS(约 46× 加速)
评分¶
- 新颖性: ⭐⭐⭐⭐ 前向+反向散度结合的理论洞察有价值,但各组件已知
- 实验充分度: ⭐⭐⭐⭐⭐ 验证规模前所未有(14B参数、T2I+T2V、多步消融)
- 写作质量: ⭐⭐⭐⭐⭐ 理论分析清晰,工程细节详尽
- 价值: ⭐⭐⭐⭐⭐ 解决了大规模扩散模型加速的核心问题,实用性极强