跳转至

Scale-wise Distillation of Diffusion Models

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=Z06LNjqU1g
论文: Project Page
代码: https://github.com/yandex-research/swd (项目页提供)
领域: 扩散模型 / 模型蒸馏 / 图像生成
关键词: 扩散蒸馏, 尺度递进生成, few-step采样, MMD, 文生图/文生视频

一句话总结

SwD 提出"按尺度蒸馏"框架,把任意预训练扩散模型蒸成一个 few-step 生成器,并让它在每个采样步上逐级提高分辨率——前几步在低分辨率上跑、最后才到全分辨率,从而在不增加步数的前提下又把单步算量砍掉一半;同时配套提出一个基于 MMD 的 patch 级蒸馏损失,单独用就能逼近 SOTA,让文生图提速约 2×、文生视频提速约 3× 且画质不降。

研究背景与动机

领域现状:把扩散模型蒸馏成 1–4 步的 few-step 生成器,是当前给大规模文生图/文生视频提速最成功的方向之一。以 DMD2、ADD 为代表的分布匹配(distribution matching)方法已经能在约 4 步内逼近教师模型的质量。

现有痛点:这些方法几乎都把注意力集中在"减少采样步数"这一个维度上,而把模型架构、输入分辨率等同样有潜力的自由度冻结住了。问题是,步数继续往下压(4 步→2 步→1 步)会越来越难、质量明显劣化,说明单靠减步数的红利快被挖空了,提速应该从别的轴上找。

核心矛盾:扩散过程本质上是"从粗到细"的——近年工作(Rissanen、Dieleman)指出反向扩散其实在隐式做"频谱自回归",即高噪声步只在恢复低频结构、高频细节要到低噪声步才出现。既然高噪声步根本没有高频信息,在全分辨率上算高噪声步就是浪费:那些被噪声掩盖掉的高频分量,在低分辨率 latent 里压根不存在,算了也白算。

本文目标:(1) 验证这种"频谱自回归"是否也适用于隐空间(VAE latent)以及视频的时间维;(2) 据此设计一个能让 few-step 模型在采样过程中递进升分辨率的蒸馏框架;(3) 顺手给分布匹配蒸馏家族补一个更简单高效的损失。

切入角度:作者先对 SD3.5、Wan2.1 等模型的 VAE latent 做频谱分析(RAPSD),确认 latent 的空间/时间分辨率确实随扩散过程"隐式增长"——高噪声步可以安全地用更低分辨率表示而不丢信号。这给"什么时候该用多大分辨率"提供了原则性依据。

核心 idea:用一个单一扩散过程、单一 few-step 模型实现多尺度递进生成(噪声→低分辨率→逐级升到全分辨率),把中间高噪声步的冗余计算省掉;再用一个特征空间的 MMD 损失把蒸馏做得又快又好。

方法详解

整体框架

SwD(Scale-wise Distillation)的目标是把一个普通的预训练扩散模型,蒸成一个 few-step(4 或 6 步)生成器,但这个生成器与众不同的地方是:它的每一步都对应一个递增的分辨率。作者预先定义一组 few-step 时间步表 \([t_1,\dots,t_N]\),并给每个 \(t_i\) 配一个非递减的尺度 \(s_i\)\([s_1,\dots,s_N]\))。采样从最低分辨率 \(s_1\) 上的高斯噪声开始,每往前走一步就把分辨率往上抬一级,直到最后一步到达全分辨率。这与"级联扩散"(cascaded DM,每级都从头跑一遍完整扩散)截然不同——SwD 全程是一个扩散过程、一个模型。

整条 pipeline 由三件事拼起来:一是"什么时候该用多大分辨率"——由前面频谱分析得到的尺度/时间步调度决定;二是"跨尺度时怎么把中间样本升上去而不破坏噪声统计"——靠 few-step 模型天然的"预测干净样本 \(\hat{x}_0\) → 再加噪"机制,在 \(\hat{x}_0\) 上升采样后再重新加噪;三是"用什么损失把学生对齐教师"——分布匹配损失(DMD/ADD)外加新提出的 patch 级 MMD 损失。训练时模型在相邻尺度对 \([s_i, s_{i+1}]\) 上迭代,学会既生成又当一个鲁棒的"升采样器"。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["高斯噪声<br/>最低分辨率 s1"] --> B["频谱分析驱动的<br/>尺度/时间步调度<br/>高噪声步→低分辨率"]
    B --> C["Scale-wise 采样<br/>预测 x̂0 → 升采样 → 重新加噪"]
    C -->|未到全分辨率,升一级 s_i→s_{i+1}| C
    C -->|到达全分辨率| D["全分辨率图像/视频"]
    C -.训练时对齐教师.-> E["Patch级 MMD 蒸馏损失<br/>+ DMD/GAN"]

关键设计

1. 频谱分析驱动的尺度调度:从"高噪声步没有高频"推出"高噪声步该用低分辨率"

这是整套方法的物理依据,也直接回答了"每一步该配多大分辨率"。作者对 SD3.5(128×128 latent)和 Wan2.1(21×60×104 latent)做了径向平均功率谱密度(RAPSD)分析,发现 latent 的频谱大致服从幂律,且加噪过程会逐步滤掉高频:在 \(t=800\) 时,噪声已经把 32×32 以上分辨率才会出现的高频分量盖住,于是把 128×128 的 latent 下采样 4× 不会丢信号("绿区"),但下采样 8× 就会破坏数据信号("红区")。视频的时间维同理——\(t=600\) 时 21 帧的有效信号只需约 11 帧就能表示。结论一句话:隐空间扩散在高噪声步可以用更低分辨率建模而不丢信号,空间和时间维都成立。正因如此,在高噪声步用全分辨率算就是把算力花在被噪声掩盖、看不见的高频上。不过作者也诚实指出,频谱分析只给平均结论、没考虑升采样伪影,所以最终的尺度/时间步表仍是超参——实践中取默认 few-step 时间步表、略往噪声更大的方向平移,尺度表从 2–4× 低分辨率起步逐级升到全分辨率,且方法对具体调度不敏感。

2. Scale-wise 采样与训练:在 \(\hat{x}_0\) 上升采样再加噪,跨尺度而不破坏噪声统计

有了"每步用多大分辨率",还得解决"怎么从低分辨率样本平滑升到高分辨率"这个跨尺度难题。最朴素的做法是直接升采样含噪 latent \(x_{t_i}\),但这会扭曲噪声方差、引入局部噪声相关性,破坏分布。作者的关键观察是:few-step 模型采用"预测干净样本 \(\hat{x}_0\) → 重新加噪到下一噪声级"的随机多步采样,这天然提供了一个正确的升采样时机——在 \(\hat{x}_0\)(干净、无噪声统计约束)上升采样,再加噪,就能在高分辨率上保持正确的噪声统计。论文用一个对照实验佐证(Table 1,FID-5K,64→128):策略 A 是直接对原始 \(x_0\) 加噪(参考),策略 B 是先升采样 \(x_0^{down}\) 再加噪,策略 C 是先对 \(x_0^{down}\) 加噪再升采样;结果 C 产出严重 OOD 的含噪 latent(FID 122/223/327),而 B 在高噪声级上与参考 A 接近(如 \(t=800\) 时 14.7 vs 13.7),因为重噪声会掩盖插值伪影。具体实现上,空间维用 bicubic 插值、时间维用相邻帧融合。采样时:给定 \(s_{i-1}\) 上的含噪样本,模型预测 \(\hat{x}_0^{i-1}\),升采样到 \(s_i\) 后用前向扩散重新加噪得到更高分辨率的 \(\hat{x}_{t_i}\),再预测下一步。训练时在相邻尺度对 \([s_i, s_{i+1}]\) 上迭代:把全分辨率图像在像素空间下采样到两个尺度(作者发现像素空间下采样明显优于 latent 下采样)再编码进 VAE,把低尺度 latent 升采样并加噪到 \(t_i\),喂给模型预测目标尺度 \(s_{i+1}\)\(\hat{x}_0\),最后算蒸馏损失。这样训出的生成器同时也是个鲁棒的升采样器,能消掉单靠重新加噪去不干净的升采样伪影。

3. Patch 级 MMD 蒸馏损失:用预训练 DM 的特征空间做分布匹配,不需额外可训练模型

分布匹配蒸馏(DMD/GAN)通常要额外训练一个 fake-score 网络或判别器,又重又慢。作者提出一个简单到出奇的替代:直接在预训练教师 DM 的中间特征空间里做学生与目标分布的最大均值差异(MMD)匹配。MMD 的定义为 \(\text{MMD}^2(P,Q)=\mathbb{E}_{x,x'\sim P}[k(x,x')]+\mathbb{E}_{y,y'\sim Q}[k(y,y')]-2\mathbb{E}_{x\sim P,y\sim Q}[k(x,y)]\)。做法是:先把生成样本和目标样本都在一个预设噪声区间内加噪(利用 DM 能在不同噪声级提取结构/细粒度信号的能力,实践中低-中噪声区间略好),再从教师 transformer 的中间 block 抽特征图 \(F\in\mathbb{R}^{N\times L\times C}\)\(L\) 是空间 token 数,对应 ViT 的 patch 表示),然后在空间 token 分布上算 MMD。作者比较了线性核(\(k(x,y)=x^\top y\))和 RBF 核,两者表现相近,于是简化用线性核——此时损失退化为逐图像、对空间 token 求均值后的 MSE

\[\mathcal{L}_{\text{MMD}}=\sum_{n=1}^{N}\left\|\frac{1}{L}\sum_{l=1}^{L}F^{\text{real}}_{n,l,\cdot}-\frac{1}{L}\sum_{l=1}^{L}F^{\text{fake}}_{n,l,\cdot}\right\|^2.\]

一个关键细节:特征均值必须逐图像算,而不是在整个 batch 上算——后者会冲掉条件相关信息、导致文本相关性下降。这个损失可看作 GAN 训练里"feature matching loss"在扩散蒸馏上的改造,但有三点不同:(i) 用预训练 DM 而非可学习判别器;(ii) 利用了不同噪声级的反馈;(iii) 逐图像而非整 batch 算特征均值。总损失把它叠在 DMD/GAN 上:\(\mathcal{L}_{\text{SwD}}=\mathcal{L}_{\text{MMD}}+\alpha\cdot\mathcal{L}_{\text{DMD}}+\beta\cdot\mathcal{L}_{\text{GAN}}\)。最妙的是它不需要任何额外可训练模型,所以又快又容易接进现有蒸馏管线,单独用还能当一个有竞争力的 baseline。

损失函数 / 训练策略

  • 总目标:\(\mathcal{L}_{\text{SwD}}=\mathcal{L}_{\text{MMD}}+\alpha\cdot\mathcal{L}_{\text{DMD}}+\beta\cdot\mathcal{L}_{\text{GAN}}\);SDXL 与 Wan2.1 因为基础模型在低分辨率上生成质量不佳、DMD 损失在 scale-wise 设定下表现差,只用 \(\mathcal{L}_{\text{MMD}}\) 训练。
  • 数据:全程只用教师生成的合成数据(隔离蒸馏设定,避免外部数据偏置),蒸馏收敛很快(约 3K 迭代),数据需求远小于训练 DM。
  • 设定:蒸到 4 或 6 步;文生图尺度从 256×256 或 512×512 起、升到 1024×1024;文生视频从 21×160×272 起、升到 81×480×832。

实验关键数据

主实验(文生图,Table 3 节选)

模型 步数 延迟(s/图) PS↑(MJHQ) HPSv3↑ IR↑ GenEval↑
SD3.5-L(教师) 28 8.3 21.8 10.4 1.04 0.70
SD3.5-L-Turbo 4 0.63 21.7 9.9 0.9 0.70
SD3.5-L-SwD 4 0.39 21.8 11.1 1.22 0.71
FLUX(教师) 30 10.0 21.7 10.7 0.93 0.66
FLUX-Schnell 4 1.41 21.5 10.3 0.96 0.69
FLUX-SwD 4 0.72 21.9 11.6 1.06 0.71

SwD 在各自模型家族内取得 PS/HPSv3/IR/GenEval 的 SOTA,且延迟相比最快的同类再快近 2×,往往比教师还好但快 10× 以上。人类偏好研究(Figure 6)显示 SwD 在图像复杂度、美学上优于多数对手(包括更贵的教师及其蒸馏版),文本相关性与缺陷率持平。

主实验(文生视频,Table 2)

模型 延迟(s/视频) VisionReward↑ VideoReward↑ VBench2 Overall↑
Wan2.1(教师) 137 0.038 5.43 51.6
CausVid(3步) 4.2 0.042 6.21 52.3
Spatial SwD(4步) 2.1 0.064 6.15 52.8
SwD(4步,时空双维) 1.8 0.064 6.27 53.2

SwD 比教师快 72× 且质量更好;比 CausVid 快约 2.3×、质量相当;时空双维比只做空间维不掉质量还更快。

消融实验(MMD 损失,SD3.5-M SwD,MJHQ30K,Table 6)

配置 PS↑ HPSv3↑ IR↑ FID↓ 说明
\(\mathcal{L}_{\text{SwD}}\)(完整) 21.8 10.7 1.11 13.6 主模型
\(\mathcal{L}_{\text{MMD}}\) 21.5 10.5 1.15 13.8 单用 MMD 仍很强
\(\mathcal{L}_{\text{SwD}}\) w/o \(\mathcal{L}_{\text{MMD}}\) 21.2 9.7 0.91 19.5 去掉 MMD 明显劣化
A: RBF 核 21.8 10.8 1.09 13.7 与线性核相近
B: batch 平均 21.5 10.5 0.97 16.4 改 batch 平均掉点
C: 不加噪 21.3 10.2 1.01 16.6 不加噪掉点

关键发现

  • MMD 损失是质量主力:从完整损失里去掉 \(\mathcal{L}_{\text{MMD}}\),FID 从 13.6 恶化到 19.5、HPSv3 从 10.7 掉到 9.7;而单用 \(\mathcal{L}_{\text{MMD}}\) 几乎不掉,证明它本身就是有竞争力的独立蒸馏目标,且因不训练额外模型,迭代速度快 7× 以上(Table 5)。
  • 两个设计细节缺一不可:把逐图像特征均值改成整 batch 平均(B)、或去掉特征提取前的加噪(C),FID 都从约 13.7 退到 16+,说明"逐图像均值保留条件信息"和"用不同噪声级反馈"都很关键。
  • scale-wise 在等步数下不掉质量、等耗时下明显更好:4-vs-4、6-vs-6 同步数对比中 scale-wise 与全分辨率人评无明显差异(自动指标甚至更优);而把耗时对齐(scale-wise 4 步 vs 全分辨率 2 步)时,scale-wise 在降缺陷、提升复杂度上明显占优,因为 2 步全分辨率 baseline 缺陷率很高。
  • 提速来源清晰:相同步数下,scale-wise 相比全分辨率在训练和采样上都约 2×(文生图)、约 3×(文生视频)(Table 4/5)。

亮点与洞察

  • 把"减步数"换成"减每步分辨率":当步数红利见顶时,作者从频谱分析里挖出"高噪声步根本没高频、不必全分辨率算"这条新的提速轴,思路很漂亮——这是一个可迁移的视角,任何 few-step 扩散都能问一句"我每一步真的需要这么高分辨率吗"。
  • 升采样时机的洞察:在 \(\hat{x}_0\) 上升采样再加噪 vs 在含噪 latent 上升采样,看似细节,实则决定了跨尺度过渡是否保持正确噪声统计;few-step 模型的"预测干净样本"机制恰好天然提供了这个正确时机,设计与已有采样算法严丝合缝。
  • 不训练额外模型的分布匹配损失:MMD 直接借用教师自己的特征空间,省掉了 DMD/GAN 必须的可训练 score/判别器,单用就接近 SOTA、迭代快 7×,对想低成本接入蒸馏的人非常实用。

局限与展望

  • 低分辨率生成质量是前提:当基础模型在低分辨率上生不出合理图像时(如 SDXL、Wan2.1),DMD 损失在 scale-wise 设定下会失效,只能退回纯 MMD 训练,这也带来 SDXL-SwD 缺陷略多于 DMD2 的代价(作者在附录 B 讨论)。
  • 调度仍是超参:频谱分析只给平均结论、未建模升采样伪影,所以尺度/时间步表最终还是靠经验设定,虽然作者称方法对调度不敏感,但缺乏自动确定最优调度的方法。
  • FID 自身的局限:作者也承认 FID 与人类感知相关性差,主要靠人评和 PS/HPSv3 等偏好指标支撑结论。
  • 展望:作者期待把 MMD 损失发展成一个完全自包含、不需任何额外可训练模型的蒸馏管线。

相关工作与启发

  • vs 级联/多阶段渐进扩散(cascaded DM、Edify、多阶段 pipeline): 它们要么每级从头跑一遍完整扩散、要么需专门技巧处理"跳点"保持采样连续;SwD 用单一扩散过程、单一 few-step 模型,靠 \(\hat{x}_0\) 升采样自然处理跨尺度,且能直接接进现有蒸馏流程,把任意预训练 DM 改成渐进 few-step 模型。
  • vs DMD2 / ADD(分布匹配蒸馏): SwD 与它们正交互补——既复用它们的 few-step 采样算法,又用 MMD 损失补强;且 SwD 在不增加步数的前提下进一步砍每步算量。
  • vs IMM / DMMD 等 MMD 用法: IMM 在原始生成预测上、用固定核做一致性蒸馏;DMMD 用噪声自适应判别器做 MMD 梯度流。SwD 的不同在于把 MMD 放进预训练 DM 的特征空间、利用多噪声级反馈、逐图像算均值,得到一个更强的分布匹配目标。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ "减每步分辨率"作为新提速轴 + 特征空间 MMD 损失,两个点都有原创性
  • 实验充分度: ⭐⭐⭐⭐⭐ 覆盖 SDXL/SD3.5/FLUX/Wan2.1 多模型、文生图+文生视频、自动指标+人评+消融
  • 写作质量: ⭐⭐⭐⭐ 频谱分析→方法→实验逻辑清晰,部分调度/设定细节下放附录
  • 价值: ⭐⭐⭐⭐⭐ 即插即用、提速 2–3×不掉质量,且 MMD 损失低成本可复用,实用价值高