Exploring and Exploiting Stability in Latent Flow Matching¶
会议: ICML 2026
arXiv: 2605.08398
代码: https://github.com/briqr/explo-r-it-ing_lfm_stability
领域: 扩散模型 / Flow Matching / 数据剪枝
关键词: Latent Flow Matching、轨迹稳定性、数据剪枝、Coarse-to-Fine、推理加速
一句话总结¶
本文系统刻画了 Latent Flow Matching(LFM)的"轨迹稳定性"——同一噪声种子下,剪掉 75% 数据、换大小架构、改训练种子都能产生几乎相同的图像;进而把这个性质转化成两个实用算法:(1) 用 balanced-clustering 剪枝可在 CelebA-HQ 上把 50% 数据剪掉而 FID 反而轻微提升、ImageNet 上 75% 数据可剪;(2) Coarse-to-Fine 两段式生成,把 DiT-XL/2 (675M) 和 DiT-S/2 (33M) 拼起来,推理快 2.15×。
研究背景与动机¶
领域现状:扩散模型已经是图像/视频/医学影像生成的主流范式,Flow Matching(FM)作为 DDPM 的 ODE 替代品,因为采样步数少而越来越受欢迎;Latent FM (LFM) 进一步把 FM 搬到 VAE latent 空间,是目前 SD3、Flux 一类大模型的底座。
现有痛点:训练 LFM 巨贵——需要超大数据集、长时间、海量算力,条件模型还需要大量人工标注;但社区从来没系统问过:数据集到底要多大?模型要多大? 一些零散观察提示稳定性的存在(Kadkhodaie 在 score-based diffusion 上看到不同 split 训练的模型趋同),但都没给出可用的剪枝/加速方案,且都局限在 pixel 空间低分辨率。
核心矛盾:理论上 FM 学的是"分布之间的传输",应当对样本分布的小扰动敏感;但实际经验表明,FM 模型在大幅扰动(删一半数据、换 20× 大的架构)下仍然把同一个 \(x_0\) 映射到几乎相同的 \(x_1\)。这个"稳定性"如果是真的,就意味着大量训练数据其实在做重复劳动,可以剪。
本文目标:(1) 在 LFM 上严格度量这种稳定性(同种子下生成的人脸 ArcFace 相似度、ImageNet 的 DINO 相似度);(2) 给出理论解释(基于 Bertrand 2025 的 FM 闭式解里 softmax 极度尖峰);(3) 把稳定性翻译成实用算法——数据剪枝 + 模型剪枝。
切入角度:Bertrand 2025 证明 rectified FM 的最优速度场 \(\hat{u}^*(x,t)=\sum_i \lambda_i(x,t)\frac{x^i-x}{1-t}\) 里,softmax 权重 \(\lambda_i\) 在早期就极度尖峰——单个训练样本主导整个轨迹。所以只要这个"主导样本"留在数据里,剪不剪掉其他样本对轨迹影响极小。
核心 idea:用 LFM 的内在稳定性同时换"训练效率(数据/标注减少)"和"推理效率(大小模型拼接)",并用三种剪枝准则配合 balanced clustering 来实证。
方法详解¶
整体框架¶
方法分两条腿走:
- 数据侧(剪枝):定义三种 sample-scoring 准则(gradient \(\mathcal{G}\)、loss \(\mathcal{L}\)、clustering \(\mathcal{C}\)),把训练集 \(S\) 剪成 \(S'\subset S\),对每个剪枝比例 \(pr\) 训 LFM 比 FID。
- 推理侧(C2F):训两个 DiT,小的 DiT-S/2 (33M) 跑前段 \(t\in[0,t_0)\)、大的 DiT-XL/2 (675M) 跑后段 \(t\in[t_0,1]\);中间用 ODE 反向积分 + seam loss 把两段缝起来。
关键设计¶
-
三种 FM 兼容的剪枝准则:
- 功能:给每个训练样本算一个"重要性分数",按分数取前 \(1-pr\) 留下。
- 核心思路:
- \(\mathcal{G}\)(梯度):用 7% 步数训一个小代理模型,固定 \(M=2\) 个噪声 + \(T=8\) 个 timestep,算每样本梯度范数平方,除以 per-\(t\) 均值消除 timestep bias,得 \(s_i^{\mathcal{G}}\)。
- \(\mathcal{L}\)(loss):把上面公式里梯度换成 loss 值,便宜很多,常用来作主力。
- \(\mathcal{C}\)(聚类):在 CLIP image embedding 空间用 k-means 聚类,分 proportional(按簇大小取样,保持原分布)和 balanced(每簇等量,平衡数据集)两套;簇内可按"离中心近/远/kernel-mean 匹配"选样本。
- 设计动机:discriminative 模型上 \(\mathcal{L}\) 选 hard-example 有效;但 FM 的 loss 大部分来自共享噪声的方差,必须沿"共享噪声路径"+ EMA 才能拿到稳定信号——这是从分类剪枝迁到 FM 的核心适配。
-
Coarse-to-Fine 两段式生成 (C2F):
- 功能:把推理 cost 砍 ≈2.15×,同时保持/提升 FID。
- 核心思路:先在剪枝后的 \(S'\) 上训轻量 Coarse 模型 \(v_C\),覆盖 \(t\in[0,t_0)\);保留预训练 Fine 模型 \(v_F\)(DiT-XL/2)覆盖 \(t\in[t_0,1]\)。为了让 \(t_0\) 处接缝平滑,用 Fine 做 ODE 反向积分 \(x_{k+1}=x_k+h\,v_F(x_k,t_k),\,h<0\) 从 \(x_1\) 回到 \(x_{t_0}\),把这个 \(x_{t_0}\) 当 Coarse 的训练目标,并加 seam loss \(\mathcal{L}_{\text{seam}}^v=\|v_F(x_{t_0},t_0)-v_C(x_{t_0},t_0)\|^2\)。
- 设计动机:稳定性说明大小模型在轨迹上"长得很像"——既然如此,前段噪声主导阶段用小模型就够了,没必要让 675M 参数白跑半条路径;seam loss 只是把两段"接缝"对齐,几个 epoch 就能 fine-tune 出可用 C2F。
-
Balanced Clustering 兼顾公平性 (𝒞ᵦ):
- 功能:在 CLIP embedding 上做 k-means,再每簇等量取样,自动平衡数据集偏差。
- 核心思路:CelebA-HQ 上未剪枝模型生成的图像 gender 分布偏斜(女多男少);用 \(\mathcal{C}_b\) 剪枝后,PaliGemma 算的 gender KL 散度从 0.044 降到 0.016(甚至比显式 label-aware 的 \((\mathcal{C}_b)_{\text{gender}}=0.005\) 仅差一点),age/skin-tone/hair-color 等多个属性 KL 也同步下降。
- 设计动机:稳定性保证了 cluster 间"互不干扰",删一个 cluster 内的样本不影响其他 cluster 的轨迹;所以可以放心地用 cluster-level 平衡来纠数据偏差,同时不损 FID。
损失函数 / 训练策略¶
Coarse 模型的总损失: \(\mathcal{L}_{\text{coarse}}=\mathbb{E}\,\mathcal{L}_{\text{FM}}^{t\in[0,t_0)}+\lambda_v\,\mathcal{L}_{\text{seam}}^v\)。 seam 系数 \(\lambda_v\) 是个超参,文中 \(t_0=0.7\) 时 FID/速度平衡最好;Coarse 用 DiT-S/2,Fine 用 DiT-XL/2,在 H100 上 batch 128、\(256^2\) 分辨率下,C2F 跑 43.53 ms/img,Fine-only 跑 93.95 ms/img。
实验关键数据¶
主实验¶
CelebA-HQ (\(pr=0.5\)) 不同剪枝准则下的 FID(越低越好):
| 方法 | FID | 备注 |
|---|---|---|
| Unpruned | 24.24 | 全数据基线 |
| Random | 25.25±0.38 | 随机剪 |
| \(\mathcal{G}\) (高梯度) | 24.62 | 几乎持平 |
| \(\mathcal{G}^{-1}\) (低梯度) | 29.75 | 显著恶化 |
| \(\mathcal{L}\) (高 loss) | 33.92 | 最差(反直觉,和分类相反) |
| \(\mathcal{L}^{-1}\) (低 loss) | 23.49 | 反而轻微改进 |
| \(\mathcal{C}_p\) | 25.19 | 按比例 |
| \(\mathcal{C}_b\) | 22.80 | balanced clustering 最优 |
| \(\mathcal{C}_b^\kappa\) | 23.42 | kernel 变体 |
ImageNet(DiT-XL/2 conditional,200k 迭代):
| 剪枝率 \(pr\) | FID 趋势 | 备注 |
|---|---|---|
| 0 (unpruned) | 基线 | |
| 0.75 | 略升至 600k 后趋同 | 长期最稳定收益 |
| 0.9 | 200k 前最快,590k 后跌 | 中期最强 |
| 0.95 | 170k 前最快,之后崩 | 短期 sprint |
消融实验¶
C2F 在 CelebA-HQ 上,seam 位置 \(t_0\) 的影响:
| 配置 | FID@\(t_0=0.7\) | 推理速度 (ms/img) | 说明 |
|---|---|---|---|
| Fine-only | 24.24 | 93.95 | 全用 DiT-XL/2 |
| C2F (unpruned Coarse) | 略好 | 43.53 | 2.15× 加速 |
| C2F + \(\mathcal{C}_b\) pruned Coarse | 最优 | 43.53 | 速度+FID 双赢 |
| C2F_male(违反稳定性) | 44.92 | 43.53 | seam loss 救不了 |
关键发现¶
- \(\mathcal{L}\) 在 FM 上的表现与分类模型完全相反:分类里"高 loss 样本"是 hard-example、留下来有用;但 FM 里 \(\mathcal{L}\) 反而最差(FID 33.92),\(\mathcal{L}^{-1}\) 最好。原因是 FM 的高 loss 多半来自"密度低的离群样本",而 FM 主要靠"主导样本"建路径,离群样本反而拖累训练。这是个很反直觉、对从业者有用的发现。
- 不同扰动对稳定性的影响差异巨大:换 DiT-S/2→DiT-XL/2 (s=0.81 几乎不变)、换 U-Net (s=0.55 稍降)、移除一个 gender 模态 (s=0.58),但换 VAE 种子 (s=0.32) 或 flip latent 全部 feature map 符号 (s=0.32) 会完全打破稳定性。这说明稳定性的根源在 latent 空间几何 + FM 目标的耦合,不是架构本身。
- score-based diffusion 不具备相同稳定性:把 FM 换成 score-based 后稳定性完全消失,说明这是 rectified FM 这个特定目标的性质,不是所有 diffusion 都有。
- Balanced clustering 同时减少 bias 且不损 FID:\(\mathcal{C}_b\) 把 gender KL 从 0.044 降到 0.016,FID 不退反进。这给"数据集均衡"提供了一个不需要标签的简洁方案。
亮点与洞察¶
- 把"稳定性"从现象提升为理论解释 + 实用算法 是这篇文章最大的贡献:直接拿 Bertrand 2025 的 closed-form solution 当解释根基,再翻成数据剪枝 + C2F 两个落地方案,理论-实证-工程闭环很完整。
- C2F 的工程价值很大:在不动 Fine 模型权重的前提下,只训一个小 Coarse + seam loss,就能在生产环境拿到 2.15× 加速;这种"模型蒸馏的 partial 版本"对部署 DiT-XL/Flux 级模型非常友好。
- 稳定性的边界条件(VAE 改变、latent 符号翻转就会破)这个发现对 LFM 社区是个警告——任何动 VAE 的操作(换 VAE、改 scaling、归一化)都会让已有 LFM 失效,要重新训。
局限与展望¶
- 主要在中等规模数据集(CelebA-HQ 28k、FFHQ 63k、ImageNet 1.2M)和 DiT 系列上验证;在 web-scale 数据(LAION-5B 量级)+ 大 Flux/SD3 上是否仍然成立,本文没回答。
- \(\mathcal{G}\) 梯度准则计算太贵,文中只用来分析,没在大数据集落地;要让它实用,可能要做随机投影或 sketch。
- C2F 的 seam loss 只对齐了一个时间点,没考虑两段 ODE 的曲率匹配;如果两段速度场二阶导差很多,仍可能出现微小 artifact。
- 文章把 stability 与 generalization 的关系当作 future work——按理说稳定性越强就越接近"复刻训练集",怎么平衡稳定性与多样性是个开放问题。
相关工作与启发¶
- vs Kadkhodaie 2024:他们在 score-based diffusion + pixel 空间观察到分裂训练后趋同;本文把现象搬到 latent FM、给出理论根据、并翻成可用工具。
- vs Bertrand 2025:Bertrand 给出 FM closed-form,但只用来研究"模型何时 generalize";本文借用其 softmax-peaked 性质来论证剪枝可行性,这是非常聪明的"复用"。
- vs 数据集 distillation / coreset:本文证明 LFM 上简单的 cluster-balanced 剪枝就能打败更复杂的 coreset 方法,给生成模型领域的数据高效化提供了一个简洁 baseline。
评分¶
- 新颖性: ⭐⭐⭐⭐ 稳定性现象 + C2F 两段式都不是首创,但首次把它们在 LFM 上系统化、给理论解释
- 实验充分度: ⭐⭐⭐⭐⭐ CelebA-HQ / FFHQ / ImageNet 三个数据集、6 种剪枝准则、5 种扰动类型,覆盖很全
- 写作质量: ⭐⭐⭐⭐ 公式叙述清晰,图 4 的扰动分类做得很有视觉冲击力
- 价值: ⭐⭐⭐⭐⭐ 工程价值大(直接 2.15× 加速 + 数据剪 50%),且对 LFM 稳定性边界给出 actionable 指导