Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency¶
会议: ICLR 2026
arXiv: 2510.08431
代码: 项目页
领域: 图像生成
关键词: 连续时间一致性模型, Score蒸馏, 大规模蒸馏, JVP, 少步生成
一句话总结¶
提出 rCM(score-regularized continuous-time consistency model),首次将连续时间一致性蒸馏扩展到 14B 参数的文生图/视频模型,通过结合前向散度(一致性)和反向散度(score蒸馏),在保持多样性的同时匹配 DMD2 的质量,实现 15-50× 加速。
研究背景与动机¶
- sCM(连续时间一致性模型)理论优雅,但在大规模文生图/视频模型上的适用性不明——JVP 计算与 FlashAttention、并行训练不兼容
- sCM 在细节生成上存在质量问题(误差累积 + 前向散度的 mode-covering 特性导致质量扩散)
- Score/对抗蒸馏方法(如 DMD2)在质量上领先,但存在模态坍塌和多样性不足
- 前向散度(一致性模型)与反向散度(score蒸馏)具有互补性
方法详解¶
整体框架¶
rCM = sCM(前向散度一致性蒸馏)+ DMD(反向散度score蒸馏)+ 基础设施优化。训练交替优化学生模型(rCM loss)和 fake score 网络(flow matching loss)。
关键设计¶
-
FlashAttention-2 JVP 内核: 开发 Triton 内核,将 JVP 集成到 FlashAttention-2 前向传播中,支持自注意力和交叉注意力,兼容 FSDP 和 Context Parallelism,使 sCM 训练可扩展到 10B+ 参数模型。
-
Score 正则化: 将 DMD loss 作为长跳正则器补充 sCM。最终目标: $\(\mathcal{L}_{\text{rCM}}(\theta) = \mathcal{L}_{\text{sCM}}(\theta) + \lambda \mathcal{L}_{\text{DMD}}(\theta)\)$ \(\lambda=0.01\) 跨模型和任务通用。sCM 提供 mode-covering(高多样性),DMD 提供 mode-seeking(高质量)。
-
稳定时间导数计算: 针对大模型 JVP 训练不稳定问题,提出两种方案:
- 半连续时间:空间部分用 JVP,时间部分用有限差分近似(\(\Delta t = 10^{-4}\))
- 高精度时间:对时间嵌入层强制 FP32 精度
-
Rollout 策略: 学生可做任意步采样,随机选择步数 \(N \in [1, N_{\max}]\),仅对最后一步反传 DMD loss,使用随机时间步确保覆盖整个时间范围。
损失函数 / 训练策略¶
- sCM loss(切线归一化):\(\mathcal{L}_{\text{sCM}} = \mathbb{E}\left[\left\|\mathbf{F}_\theta - \mathbf{F}_{\theta^-} - \frac{\mathbf{g}}{\|\mathbf{g}\|_2^2 + c}\right\|_2^2\right]\)
- DMD loss:基于 fake score 和 teacher score 的差异引导学生
- Fake score 网络用 flow matching loss 在学生生成的数据上训练
实验关键数据¶
主实验(GenEval T2I)¶
| 模型 | 参数 | NFE | Overall | Counting | Position |
|---|---|---|---|---|---|
| FLUX.1-dev | 12B | 50 | 0.66 | 0.74 | 0.22 |
| Cosmos-Predict2 14B (teacher) | 14B | 70 | 0.84 | 0.79 | 0.64 |
| Cosmos-Predict2 + DMD2 | 2B | 4 | 0.80 | 0.70 | 0.57 |
| Cosmos-Predict2 + rCM | 2B | 4 | 0.81 | 0.73 | 0.58 |
| Cosmos-Predict2 + rCM | 14B | 4 | 0.83 | 0.80 | 0.59 |
| Cosmos-Predict2 + rCM | 14B | 1 | 0.82 | 0.84 | 0.49 |
VBench 视频实验¶
| 模型 | 参数 | NFE | Total Score | Throughput(FPS) |
|---|---|---|---|---|
| Wan2.1 14B (teacher) | 14B | 100 | 83.58 | 0.18 |
| Wan2.1 + DMD2 | 1.3B | 4 | 84.56 | 14.6 |
| Wan2.1 + rCM | 1.3B | 4 | 84.43 | 14.6 |
| Wan2.1 + rCM | 14B | 2 | 85.05 | 8.3 |
关键发现¶
- rCM 在质量上匹配或超过 DMD2,同时在多样性上明显优于 DMD2(Figure 1 显示 DMD2 生成物体位置/姿态趋同)
- 14B rCM 4步 GenEval 0.83,接近 teacher 70步的 0.84
- 视频任务中 rCM 2步即可达到接近 teacher 的 VBench 分数
- \(\lambda=0.01\) 在质量和多样性之间取得最佳平衡
- 纯 sCM 在文字渲染等精细场景存在明显质量缺陷,rCM 成功修复
亮点与洞察¶
- 首次将 JVP-based 连续时间一致性扩展到 14B 参数和 5 秒视频
- 从前向/反向散度互补性的角度理解蒸馏方法的统一框架
- 无需 GAN 调优或大量超参搜索,\(\lambda=0.01\) 跨任务通用
- rCM 的多样性优势对交互式 world model 等需要多样响应的场景尤为重要
局限与展望¶
- 需要额外的 fake score 网络(内存开销)
- JVP 计算仍比标准前向传播慢,训练成本高
- 1步视频生成质量仍有明显下降(VBench 从 85.05 降至 83.02)
- 对 autoregressive video diffusion 的扩展仅有展望
相关工作与启发¶
- sCM 和 MeanFlow 提供了理论基础
- DMD/DMD2 提供了反向散度蒸馏的实践方案
- DDO 和 DDRL 的前向+反向散度联合思想是 rCM 的哲学基础
- 为大规模视觉生成模型的部署提供了实用加速方案
技术细节补充¶
- TrigFlow 噪声调度:\(\alpha_t = \cos(t), \sigma_t = \sin(t)\),与 rectified flow 通过 SNR 匹配互转
- Fake score 网络用 flow matching loss 在学生生成数据上训练,交替优化
- Selective Activation Checkpointing (SAC) 用于减少内存消耗
- Teacher 使用 CFG,CFG 同时蒸馏到学生中
- 全参数微调(不用 LoRA),强调 rCM 的稳定性
- 实验涵盖 Cosmos-Predict2(0.6B/2B/14B T2I)和 Wan2.1(1.3B/14B T2V)
- Wan2.1 14B 2步加速达 8.3 FPS vs teacher 的 0.18 FPS(约 46× 加速)
评分¶
- 新颖性: ⭐⭐⭐⭐ 前向+反向散度结合的理论洞察有价值,但各组件已知
- 实验充分度: ⭐⭐⭐⭐⭐ 验证规模前所未有(14B参数、T2I+T2V、多步消融)
- 写作质量: ⭐⭐⭐⭐⭐ 理论分析清晰,工程细节详尽
- 价值: ⭐⭐⭐⭐⭐ 解决了大规模扩散模型加速的核心问题,实用性极强