Integrating Intermediate Layer Optimization and Projected Gradient Descent for Solving Inverse Problems with Diffusion Models¶
会议: ICML2025
arXiv: 2505.20789
代码: StarNextDay/DMILO
领域: 逆问题求解 / 扩散模型
关键词: 逆问题, 扩散模型, 中间层优化, 投影梯度下降, 图像重建, ILO, PGD
一句话总结¶
提出 DMILO 和 DMILO-PGD 两种方法,通过中间层优化(ILO)分解扩散模型采样过程以大幅降低显存,并结合投影梯度下降(PGD)避免次优收敛,在线性和非线性逆问题上全面超越 DMPlug 等 SOTA 方法。
研究背景与动机¶
逆问题(Inverse Problems)的目标是从噪声观测 \(\boldsymbol{y} = \mathcal{A}(\boldsymbol{x}^*) + \boldsymbol{\epsilon}\) 中恢复信号 \(\boldsymbol{x}^*\),广泛应用于医学成像、压缩感知和遥感等领域。扩散模型(DM)作为强大的生成先验已在逆问题求解中取得 SOTA 表现,但现有方法存在两大痛点:
显存瓶颈:DMPlug 等 CSGM 类方法需要在整个采样过程中保留完整计算图用于反向传播,随采样步数增加显存线性增长,4 步时即超出 RTX 4090 显存
次优收敛:依赖初始向量选择,容易陷入局部最优
此前的 ILO 方法虽在 GAN 上有效,但其层分解依赖特定网络架构,难以泛化。本文发现,扩散模型的采样过程天然由多个函数复合而成,非常适合 ILO 分解。
方法详解¶
核心思想:将扩散采样视为函数复合¶
扩散模型的生成过程可表示为 \(N\) 个采样步的函数复合:
每个 \(g_i\) 对应 DDIM 采样的一步:
这种分解与去噪网络架构无关,可无缝适配任何 DM。
DMILO:中间层优化 + 稀疏偏差¶
对第一层(与观测直接关联),优化:
对后续每层,以上一层优化结果为目标:
其中 \(\boldsymbol{\nu}\) 为稀疏偏差项(\(\ell_1\) 正则),用于探索扩散模型范围之外的信号。每次只需保留单步梯度信息,显存恒定。
DMILO-PGD:引入投影梯度下降¶
在 DMILO 基础上交替执行:
- 梯度下降步:\(\boldsymbol{x}_{t_0}^{(e)} = \boldsymbol{x}_{t_0}^{(e-1)} - \eta \nabla \|\boldsymbol{y} - \mathcal{A}(\boldsymbol{x}_{t_0}^{(e-1)})\|_2^2\)
- 投影步:用 DMILO 将更新后的信号投影回扩散模型的扩展范围
关键区别:投影时最小化 \(\|\mathcal{A}(\mathcal{G}(\boldsymbol{x}_{t_N})) - \mathcal{A}(\hat{\boldsymbol{x}}_{t_0})\|_2^2\),利用前向算子 \(\mathcal{A}\) 引导投影方向,而非传统 PGD 的纯距离投影,理论上保证更优重建。
理论保证¶
在 Lipschitz 连续和低维流形假设下,Theorem 4.4 保证当测量数 \(m = \Omega(k_2 \log \frac{L_1 n}{\delta} + k^2 \log(3n))\) 时,利用前向算子的测量最优解接近真实最优解:
实验关键数据¶
实验覆盖 CelebA、FFHQ、LSUN-bedroom、ImageNet,含 4 种线性任务和 2 种非线性任务。
显存对比(RTX 4090,模型 2.75GB)¶
| 采样步数 | DMPlug | DMILO | DMILO-PGD |
|---|---|---|---|
| 1 | 10.53 GB | 10.53 GB | 10.53 GB |
| 2 | 15.72 GB | 10.53 GB | 10.54 GB |
| 3 | 20.83 GB | 10.53 GB | 10.54 GB |
| 4 | N/A (OOM) | 10.54 GB | 10.54 GB |
超分辨率 & 修复(CelebA, σ=0.01)¶
| 方法 | SR PSNR↑ | SR SSIM↑ | Inpaint PSNR↑ | Inpaint SSIM↑ |
|---|---|---|---|---|
| DMPlug | 32.38 | 0.875 | 35.51 | 0.935 |
| DCPS | 29.47 | 0.834 | 35.42 | 0.940 |
| DMILO-PGD | 33.58 | 0.906 | 36.42 | 0.952 |
运动去模糊(CelebA, σ=0.01)¶
| 方法 | FID↓ | LPIPS↓ | PSNR↑ | SSIM↑ |
|---|---|---|---|---|
| DMPlug | 78.57 | 0.164 | 30.25 | 0.824 |
| DCPS | 35.19 | 0.054 | 31.05 | 0.856 |
| DMILO | 31.08 | 0.044 | 34.15 | 0.908 |
非线性去模糊(FFHQ, σ=0.01)¶
| 方法 | LPIPS↓ | PSNR↑ | SSIM↑ |
|---|---|---|---|
| DMPlug | 0.099 | 31.37 | 0.866 |
| DMILO-PGD | 0.047 | 34.02 | 0.919 |
亮点与洞察¶
- 自然分解:扩散采样过程的函数复合结构天然适配 ILO,无需依赖特定架构,这一观察非常优雅
- 显存恒定:无论采样步数如何增加,DMILO 显存几乎不变(~10.5 GB),而 DMPlug 线性增长直至 OOM
- 稀疏偏差扩展范围:\(\ell_1\) 正则允许探索 DM 生成范围之外的信号,对于真实信号不完全在生成分布内的场景很关键
- 前向算子引导投影:DMILO-PGD 中利用 \(\mathcal{A}\) 引导投影比纯距离投影效果更好,有直观的理论支撑
- 非线性任务大幅领先:PSNR 提升 2.5–3.5 dB,显示方法在复杂退化下的优势
局限与展望¶
- PGD 对盲去模糊效果有限:DMILO-PGD 在 BID 任务上不如 DMILO,作者推测是朴素梯度更新不适合核估计,需要设计专门的核更新策略
- 计算效率未充分讨论:虽然显存降低,但多轮外迭代 × 多轮内迭代的计算量可能很大(如超分辨率需 400 内迭代 × 10 外迭代)
- 高斯去模糊在 ImageNet 上表现一般:FID 和 LPIPS 在某些配置下劣于 DCPS,泛化性有待加强
- 超参数较多:\(\lambda\)、内外学习率、内外迭代次数等超参数需针对不同任务分别调优
- 仅验证了 DDIM 采样:未探索与其他采样器(DPM-Solver 等)的结合
相关工作与启发¶
- DMPlug (Wang et al., 2024):本文直接改进的基线,优化初始隐变量的 CSGM 方法
- ILO (Daras et al., 2021):中间层优化最初为 GAN 设计,本文将其自然推广到 DM
- 稀疏偏差 (Dhar et al., 2018):允许生成范围外的信号探索
- PGD for IP (Shah & Hegde, 2018):投影梯度下降框架
- DCPS (Janati et al., 2024):另一个强竞争对手,在某些任务上与本文方法各有胜负
评分¶
- 新颖性: ⭐⭐⭐⭐ — ILO 与 DM 的结合自然且有效,前向算子引导投影有理论新意
- 实验充分度: ⭐⭐⭐⭐⭐ — 6 种任务、4 个数据集、9+ 个 baseline,非常全面
- 写作质量: ⭐⭐⭐⭐ — 结构清晰,理论与实验衔接好
- 价值: ⭐⭐⭐⭐ — 显存问题的解决对实际部署有重要意义,但计算效率仍有提升空间