Latent Drifting in Diffusion Models for Counterfactual Medical Image Synthesis¶
会议: CVPR 2025
arXiv: 2412.20651
代码: https://latentdrifting.github.io/ (项目页)
领域: 医学图像 / 扩散模型
关键词: 反事实图像生成, 扩散模型微调, 隐空间漂移, 医学影像合成, 分布迁移
一句话总结¶
本文提出 Latent Drifting (LD),通过在扩散模型的前向和反向过程中引入一个标量偏移参数 δ 来弥合预训练自然图像模型与医学图像目标分布之间的差距,显著提升了多种微调方案下的医学图像生成和反事实图像合成效果。
研究背景与动机¶
- 领域现状:预训练扩散模型(如 Stable Diffusion)在自然图像生成上表现卓越,医学领域希望利用这些模型的强大生成能力。现有微调方法(Textual Inversion、DreamBooth、Custom Diffusion)允许用少量样本为模型引入新概念。
- 现有痛点:医学图像与自然图像的分布差异巨大(如脑 MRI 背景必须全黑、骨性结构必须保持形状),直接微调预训练模型难以适应这种分布偏移。少量医学样本无法有效调整模型学到的自然图像分布。同时,从头训练医学扩散模型面临数据隐私、成本和稀有疾病等限制。
- 核心矛盾:预训练模型的隐空间噪声分布 \(z_T \sim \mathcal{N}(0, I)\) 是为自然图像设计的,但医学图像的最优采样分布可能与 \(\mathcal{N}(0, I)\) 有偏移。微调只调整模型参数 θ,但从不改变隐空间分布。
- 本文目标 (1) 如何高效地将预训练扩散模型适配到医学图像域;(2) 如何实现高质量的医学反事实图像生成(如疾病添加/移除、年龄变化、性别转换)。
- 切入角度:将隐空间的终态变量 \(z_T\) 视为另一个条件因子而非固定假设,通过一个简单的标量偏移 δ 修改隐空间均值来匹配目标分布。
- 核心 idea:在扩散过程的每个时间步为均值添加一个全局偏移 δ,将隐空间分布从自然图像域「漂移」到医学图像域。
方法详解¶
整体框架¶
Latent Drifting 是一个通用的插件方法,可以与任何扩散模型微调方案结合。给定预训练的 Stable Diffusion 和目标医学数据集,LD 在微调时同时修改前向和反向过程的分布,在推理时修改反向过程的分布。方法将反事实图像生成形式化为一个 min-max 优化问题,在保持与原图相似性(Counterfactual Fidelity)的同时最大化结果变化(Desired Outcome Fidelity)。
关键设计¶
-
Latent Drifting 机制:
- 功能:通过引入标量偏移 δ,修改扩散过程的隐空间分布以匹配目标医学图像域。
- 核心思路:在反向过程的转移核中添加偏移 \(p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t) + \delta, \Sigma_\theta(x_t, t))\)。δ 是一个有符号标量,在前向扩散时也相应偏移 \(z_T\)。通过网格搜索(遍历 δ 从 -0.2 到 0.2)找到使生成分布 \(\mathcal{D}_\theta\) 与目标分布 \(\mathcal{D}_{GT}\) 的 L1 距离最小的 δ 值。实验发现 δ=0.1 在脑 MRI 上效果最佳。
- 设计动机:传统微调假设 θ 的更新足以覆盖分布迁移,但实际上隐空间分布 \(\mathcal{N}(\mu, \sigma)\) 从未被调整——未加 LD 时微调后的隐空间分布方差很大且不稳定,加了 LD 后分布达到稳定点,对分布偏移更鲁棒。
-
反事实生成的形式化框架:
- 功能:将医学反事实图像生成统一建模为约束优化问题。
- 核心思路:目标函数 \(L(x, x', y', \lambda) = \min_{\ell_o}[\lambda \cdot \ell_o(\hat{f}(x'), y')] + \min_{\ell_{in}}[\ell_{in}(x, x')]\)。其中 \(\ell_{in}\) 保证反事实图像 \(x'\) 与原图 \(x\) 的相似性,\(\ell_o\) 保证反事实结果符合目标标签 \(y'\)。两者互为约束:\(\ell_{in} \propto 1/\ell_o\)。当 \(\lambda=0\) 退化为标准微调(\(z=z'\)),当 \(\lambda>0\) 引入 LD 的 δ 偏移来增强条件控制。
- 设计动机:反事实图像生成本质上是在"改变目标特征"和"保留原始特征"之间的平衡问题,这个优化框架自然地将 LD 纳入条件控制。
-
与多种微调方案的结合:
- 功能:证明 LD 作为通用插件可以适配不同的微调策略。
- 核心思路:对四种微调方法分别结合 LD:(1) Textual Inversion——仅微调文本编码器的嵌入空间;(2) DreamBooth——微调去噪 U-Net 并用类先验保留损失;(3) Custom Diffusion——仅微调 U-Net 中的交叉注意力层权重;(4) Basic FT——微调整个去噪 U-Net。每种方法都简单地在扩散过程中加入 δ 偏移即可。对于 image-to-image 的反事实生成,还与 Pix2Pix Zero 和 InstructPix2Pix 结合。
- 设计动机:LD 是在扩散过程层面的修改(改变均值),与模型参数的微调方式正交,因此可以无缝嵌入任何微调方案,无需修改其内部机制。
损失函数 / 训练策略¶
基础训练损失为标准去噪目标 \(\mathbb{E}_{x,c,\epsilon,t}[w_t\|\hat{x}_\theta(\alpha_t x + \sigma_t \epsilon, c) - x\|_2^2]\),在此基础上 LD 仅修改采样分布。使用 SD-v1.4 预训练模型,δ 通过网格搜索在 [-0.2, 0.2] 范围内确定,使用 L1 归一化距离作为评价指标。Text-to-image 用 200 样本评估,image-to-image 用纵向数据集评估。
实验关键数据¶
主实验¶
| 微调方法 | LD | FID (脑MR)↓ | KID (脑MR)↓ | AUC (脑MR)↑ | FID (胸片)↓ | AUC (胸片)↑ |
|---|---|---|---|---|---|---|
| SD + Basic FT | ✗ | 92.13 | 0.071 | 0.704 | 112 | 0.672 |
| SD + Basic FT | ✓ | 49.68 | 0.035 | 0.724 | 84 | 0.746 |
| Textual Inversion | ✗ | 120.63 | 0.098 | 0.600 | 171.77 | 0.600 |
| Textual Inversion | ✓ | 67.56 | 0.065 | 0.670 | 133.18 | 0.640 |
| DreamBooth | ✗ | 130.92 | 0.125 | 0.500 | 188 | 0.567 |
| DreamBooth | ✓ | 92.37 | 0.099 | 0.512 | 177 | 0.582 |
| Real + Synthetic | ✓ | - | - | 0.883 | - | 0.892 |
LD 在所有微调方法上都带来显著改进,Basic FT + LD 在脑 MRI 上 FID 从 92.13 降至 49.68(降 46%),且合成+真实数据训练的分类器 AUC 甚至超过纯真实数据(0.883 vs 0.870)。
消融实验¶
| 配置 | FID (aging)↓ | SSIM↑ | LPIPS↓ | PSNR↑ |
|---|---|---|---|---|
| InstructPix2Pix (Binned) + SD + Basic FT + LD | 15.39 | 0.74 | 0.13 | 32.77 |
| InstructPix2Pix (Word) + SD + Basic FT + LD | 15.25 | 0.75 | 0.13 | 32.78 |
| InstructPix2Pix (Numerical) + SD + Basic FT + LD | 15.37 | 0.76 | 0.12 | 32.83 |
| InstructPix2Pix + SD + CD + LD (Numerical) | 24.05 | 0.32 | 0.23 | 30.70 |
Prompt 格式对照实验表明简单的 Diverse + Patient Info 组合最佳(FID 51.35, KID 0.0351),数值型年龄条件在 image-to-image 任务中综合最优。
关键发现¶
- LD 在所有微调方案中一致有效:无论是只调文本嵌入(Textual Inversion)还是调整 U-Net(Basic FT),LD 都能大幅降低 FID/KID。效果最好的是 Basic FT + LD。
- 合成数据增强超越真实数据:用 50% LD 合成 + 50% 真实数据训练分类器,AUC 超过 100% 真实数据(脑 MRI: 0.883 vs 0.870),验证了合成数据的实用价值。
- 视觉改善明显:加 LD 后脑 MRI 背景从灰色杂质变为纯黑,脑部结构更逼真,白质灰质边界更清晰。
- Prompt 中包含患者信息(年龄、性别、诊断)显著优于通用 prompt。
- δ 的最优值在 0.05-0.1 范围内,对不同微调方法较为稳定。
亮点与洞察¶
- 极简但有效:仅一个标量参数 δ 就能弥合自然图像和医学图像的分布差距,实现成本几乎为零。这个发现揭示了扩散模型中隐空间分布是一个被忽视的关键自由度。
- 方法无关性:作为插件可以嵌入任何微调方案,且都有效——这种正交于模型架构的改进方式非常优雅,可以直接应用于未来新出现的微调方法。
- 反事实生成的统一框架:将疾病添加/移除、年龄变化、性别转换等多种医学场景统一到一个反事实优化框架下,从 min-max 的角度理解条件生成,对该领域有理论贡献。
- 合成数据增强的证据:成功证明了 LD 生成的合成数据可以作为数据增强手段提升下游分类性能,为数据稀缺的医学 AI 提供了可行路径。
局限与展望¶
- δ 的确定方式:目前通过网格搜索确定 δ,对新的目标域需要重新搜索。可以考虑自动化地根据源域和目标域的分布差异估计 δ。
- 全局标量的局限:δ 是全局的、各通道相同的偏移,对于不同空间区域或通道可能需要不同的偏移量。可以探索空间自适应或通道自适应的 LD。
- 2D 切片处理:实验仅在 2D 脑 MRI 切片上进行,未处理 3D 体积数据。扩展到 3D 扩散模型需要验证 LD 在更高维空间的有效性。
- 反事实评估困难:缺乏真正的反事实 ground-truth(如"这个人如果得了阿尔茨海默病,MRI 应该长什么样"),评估主要依赖 FID/KID 等分布指标和下游分类 AUC。
- 可以尝试将 LD 与 ControlNet 等条件控制方法结合,实现更精细的医学图像编辑。
相关工作与启发¶
- vs DreamBooth: DreamBooth 通过类先验保留损失微调 U-Net引入新概念,但在脑 MRI 上 FID 高达 130.92;加上 LD 后降至 92.37,说明仅微调参数不够,还需要调整隐空间分布。
- vs Textual Inversion: TI 仅调文本嵌入是最轻量的方案,但对医学图像理解不足;LD 将其 FID 从 120.63 降至 67.56,且 AUC 从 0.600 提升到 0.670。
- vs 从头训练的医学扩散模型(如 Khader et al., Pinaya et al.): 这些方法需要大量医学数据训练,LD 利用预训练模型的先验知识仅需少量样本微调。
- LD 的思路可以推广到其他存在分布偏移的领域适配场景,如遥感图像、工业检测等。
评分¶
- 新颖性: ⭐⭐⭐⭐ 隐空间偏移的思路简洁独到,将隐变量视为可调条件是新视角
- 实验充分度: ⭐⭐⭐⭐ 覆盖多种微调方法、多个医学数据集、多种生成任务
- 写作质量: ⭐⭐⭐⭐ 理论推导清晰,可视化丰富
- 价值: ⭐⭐⭐⭐ 简单有效的插件方法对医学图像合成有直接实用价值