跳转至

Pixel-Level Residual Diffusion Transformer: Scalable 3D CT Volume Generation

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=bWtRZQ1rm2
代码: https://github.com/Fredy-Zhang/PRDiT
领域: 医学图像 / 扩散模型 / 3D生成
关键词: 3D CT生成, 残差扩散, Diffusion Transformer, 体素级生成, 高分辨率扩展

一句话总结

PRDiT 提出一种直接在体素级生成高分辨率 3D CT 体的两阶段残差扩散框架:先用轻量 MLP「局部去噪器」从重叠 3D patch 里估出低频粗结构,再用「全局残差 DiT」用全卷积视野补回高频残差,配合 hot 预测-校正采样和「复用低分辨率主干」的扩展策略,在 LIDC-IDRI / RAD-ChestCT 上 3D FID、MMD、Wasserstein 全面超过 HA-GAN、3D LDM、WDM-3D,且 256³ 训练成本仅为对手的 1/4 ~ 1/6。

研究背景与动机

领域现状:3D 医学体(尤其 CT)生成对诊断、分割、异常检测都很有用,主流路线分两类——GAN(HA-GAN、3D-StyleGAN)能出逼真局部细节,扩散模型(3D-DDPM、3D LDM、WDM-3D、triplane 扩散)训练更稳、保真度更高,近年成为主力。

现有痛点:体素特征图的体积随分辨率立方级膨胀,直接在高分辨率 3D 体上跑深层 U-Net 内存/算力都爆炸,于是现有方法被迫用各种「妥协」:patch 分块、降采样、或用 VAE/VQ-VAE 压到隐空间。但 patch 和降采样会截断有效感受野,丢掉全局解剖一致性;隐空间压缩在 3D 医学这种样本稀缺场景下又训不出鲁棒的编码器,重建质量差、关键解剖细节流失。

核心矛盾:局部细节保真度 ↔ 全局结构一致性 ↔ 计算可行性,三者难以兼得。卷积 U-Net 擅长局部却建模不了长程依赖;而把 2D 成功的 DiT 直接搬到稠密 3D 又会遇到训练不稳、优化困难,以及分辨率翻倍时 token 数 ×8、注意力开销 ×64 的天价成本。

本文目标:在不引入自编码器瓶颈的前提下,做到(1)体素级高保真合成、(2)全局结构一致、(3)能廉价扩展到 256³ 高分辨率。

切入角度:把生成任务按频率分解——低频粗结构其实局部 patch 内部就能估出来,真正难、需要全局上下文的是跨 patch 边界的高频残差。于是不必让一个大模型同时扛低频和高频。

核心 idea:用「轻量局部去噪器估低频 + 全局 DiT 只学高频残差」的两阶段 coarse-to-fine 残差学习,直接在体素级生成,并通过冻结复用低分辨率模型来近乎免费地扩到高分辨率。

方法详解

整体框架

PRDiT 把 3D 扩散过程拆成两条互补的支路,串成一个 coarse-to-fine 的残差流水线。给定体 \(X \in \mathbb{R}^{C\times H\times W\times D}\),先用滑窗(窗口 \(p\)、步长 \(s<p\)、带重叠)切成 \(N\) 个 3D patch 并展平。第一支路是 Local Denoiser:一个 MLP 盲估计器,对每个 patch 独立地、只凭 patch 内信息预测出干净信号 \(\hat{x}_i\) 和噪声 \(\hat{\epsilon}_i\),给整个体提供一份「先验粗估」。第二支路是 Global Residual DiT:把所有 patch embedding 通过多头自注意力联合起来,用全局视野算出残差修正 \(\Delta\hat{x}_i,\Delta\hat{\epsilon}_i\),把局部估计在 patch 边界处的误差补回去。采样时用 predictor-corrector(hot 扩散) 方案在确定性引导与可控随机性之间取平衡。最后,要扩到更高分辨率时,冻结复用已训好的低分辨率 PRDiT 当结构先验,只额外训一个「高分辨率残差细化模块」补回降/升采样丢掉的高频。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["输入体 X<br/>滑窗切重叠 3D patch"] --> B["局部去噪器<br/>MLP 盲估计每 patch<br/>信号 x̂ᵢ + 噪声 ε̂ᵢ"]
    B --> C["全局残差 DiT<br/>多头自注意力<br/>只学跨 patch 高频残差 Δ"]
    C --> D["预测-校正采样<br/>cold 预测 + hot 校正"]
    D -->|低分辨率| E["输出 3D CT 体"]
    D -->|扩到 256³| F["高分辨率残差细化<br/>冻结复用低分辨率主干"]
    F --> E

关键设计

1. 局部去噪器:用重叠 patch 上的 MLP 盲估计器先吃掉低频

针对「直接在高分辨率 3D 体上跑深模型内存爆炸」这个痛点,作者把低频粗结构的活儿交给一个极轻量的逐 patch MLP,而不是让大 Transformer 全包。前向扩散用 Zhang et al. (2023) 的角度调度:在时刻 \(t\) 把干净 patch \(v_i\) 加噪成 \(v_i^t = \cos(\tfrac{t}{T}\tfrac{\pi}{2})\,v_i + \sin(\tfrac{t}{T}\tfrac{\pi}{2})\,\epsilon_i\)。MLP 由两层 Adaptive SwiGLU(用时间嵌入经 AdaLN 调制出 \(\gamma,\beta\) 来 shift/scale)加残差跳连和线性投影组成,同时输出去噪 patch \(\hat{x}_i\) 和噪声估计 \(\hat{\epsilon}_i\),对齐标准扩散「联合估清晰信号与噪声」的目标,损失为 \(L_{local}=\mathbb{E}_i[\|\hat{x}_i-x_i\|_2^2+\|\hat{\epsilon}_i-\epsilon_i\|_2^2]\)

两个细节是关键:一是滑窗带重叠\(s<p\)),让相邻 patch 在边界共享上下文,缓解拼接伪影、提升连续性(消融里去掉重叠 FID 从 2.04 劣化到 3.27);二是它本质是个盲估计器——只看 patch 内信息,所以注定在跨 patch 的全局结构上有误差,这恰好把「难的部分」干净地留给下一阶段。

2. 全局残差 DiT:冻结局部支路,只让 Transformer 学跨 patch 的高频修正

局部去噪器逐 patch 独立、看不到全局,patch 边界和长程解剖一致性会出错。这里冻结局部去噪器,再训一个 DiT 用多头自注意力同时 attend 到所有 patch embedding,输出的不是从头重建,而是残差 \(\Delta\):精修后的 patch 为 \(\tilde{x}_i=\hat{x}_i+\Delta\hat{x}_i\)\(\tilde{\epsilon}_i=\hat{\epsilon}_i+\Delta\hat{\epsilon}_i\),用同样的逐 patch 损失 \(L_{global}=\mathbb{E}_i[\|\tilde{x}_i-x_i\|_2^2+\|\tilde{\epsilon}_i-\epsilon_i\|_2^2]\) 训练。

这样设计有两层好处:其一,DiT 只需补「残差」而非学整个分布,学习复杂度和训练时间都降下来;其二,因为局部支路只训一次,扩展 DiT 深度(4/8/12 层)时可以复用同一个局部去噪器,把规模化的开销摊薄。消融非常说明问题——去掉全局 DiT(只剩局部)FID 直接崩到 41.92,证明跨 patch 的全局残差才是高保真的命门;去掉局部去噪器(DiT-only)也明显变差,说明两阶段分工缺一不可。位置编码 PE 只喂给全局 DiT,让它独占建模空间关系的能力。

3. 预测-校正采样:从 cold 升级到 hot 扩散注入可控随机性

确定性的 cold 采样(单步梯度更新)多样性和细节都不够。作者把 Zhang et al. (2023) 的梯度生成路径改造成分离的 cold 预测步 + hot 校正步:预测步先向前跳 \(k\)\(x_{t-k}=x_t-k\cdot\nabla(\cos(\beta_t)\hat{x}_0+\sin(\beta_t)\hat{\epsilon})\)\(\beta_t=\tfrac{t}{T}\tfrac{\pi}{2}\));校正步再回退 \(k-1\)并注入新鲜噪声 \(x_{t-1}=\Gamma_t^{(k)}x_{t-k}+\sqrt{1-(\Gamma_t^{(k)})^2}\,\epsilon'\),其中 \(\Gamma_t^{(k)}=\cos(\beta_{t-1})/\cos(\beta_{t-k})\) 做方差保持。

\(k=1\) 退化为标准 cold 单步更新;\(k>1\) 配上校正器就变成hot 扩散,每步注入受控随机性增加探索。这套预测-校正在「确定性引导」和「自适应随机性」之间取平衡,既保细节又保全局连贯。消融显示 \(k=1\)(cold)FID 高达 8.889,升到 \(k=2\)(hot)骤降到 2.173,\(k\) 不必是整数但实验里取 2 最好。

4. 高分辨率扩展:冻结复用低分辨率主干,只训一个高频残差模块

从头在 256³ 训 DiT 几乎不可行——比 128³ token 数 ×8、注意力开销约 ×64,逼出极小 batch、频繁 OOM、优化崩坏。作者不重训,而是复用并冻结已训好的低分辨率 PRDiT 当结构先验:给定高分辨率噪声体 \(X_{HR}\),先降采样→过低分辨率模型→升采样得到一份「粗」的信号与噪声初值,再用一个基于局部去噪器结构的高分辨率残差细化模块补回高频,且集成进采样循环内(每步降采样查询低分辨率模型、升采样、立刻细化高频),而非事后级联超分(后者常破坏扩散学到的低频结构)。

这里有个反直觉但关键的采样选择:信号升采样用三线性插值(更平滑、保解剖结构连续),但噪声的降/升采样都用最近邻——因为三线性会平均像素、衰减高频噪声能量,破坏预训练模型期望的噪声统计;最近邻保住噪声能量,让降采样输入对齐低分辨率模型的训练分布。训练时只更新细化模块、冻结主干,目标用 patch 级信号-噪声估计损失加一个低频一致性项(约束高分辨率预测降采样后要匹配低分辨率输出),让模块专注高频。结果 256³ 上 FID 2.28 远超 HA-GAN(3.98)和 WDM-3D(5.60),训练只要 36 GPUh vs 对手 120~140 GPUh。

损失函数 / 训练策略

  • 局部去噪器:\(L_{local}=\mathbb{E}_i[\|\hat{x}_i-x_i\|_2^2+\|\hat{\epsilon}_i-\epsilon_i\|_2^2]\),联合估信号与噪声。
  • 全局 DiT:冻结局部支路后用 \(L_{global}=\mathbb{E}_i[\|\tilde{x}_i-x_i\|_2^2+\|\tilde{\epsilon}_i-\epsilon_i\|_2^2]\) 训残差。
  • 高分辨率细化:patch 级信号-噪声估计损失 + 低频一致性项;主干全程冻结。
  • 数据预处理:裁到肺窗、重采样到 1mm 各向同性、crop/pad 到 256³,低分辨率实验用平均池化降采样并归一化到 \([-1,1]\)

实验关键数据

主实验

LIDC-IDRI 与 RAD-ChestCT,128³ 分辨率,FID 已 ×10³,三个随机种子均值±标准差,W-Score 以 PRDiT-12L 为参照(越接近 1 越好)。

数据集 指标 HA-GAN 3D-LDM WDM-3D PRDiT-4L PRDiT-12L
LIDC-IDRI FID ↓ 3.26 7.62 3.67 2.04 1.41
LIDC-IDRI MMD ↓ 0.2071 0.3458 0.1885 0.1852 0.1501
RAD-ChestCT FID ↓ 3.92 4.14 4.11 1.92 1.45
RAD-ChestCT MMD ↓ 0.183 0.228 0.213 0.169 0.159

即便最浅的 PRDiT-4L 已全面超过所有 baseline;加深到 8/12 层稳定提升,体现可扩展性,而所有变体复用同一个只训一次的局部去噪器。

256³ 高分辨率(LIDC-IDRI,FID ×10³,A100 GPU 小时):

模型 FID ↓ MMD ↓ 训练成本
HA-GAN 3.98 0.2237 140 GPUh
3D-LDM OOM OOM
WDM-3D 5.60 0.2590 120 GPUh
PRDiT-4L↑256 2.28 0.1370 36 GPUh

消融实验

PRDiT-4L,LIDC-IDRI,FID ×10³。

配置 FID ↓ MMD ↓ 说明
Full model 2.04 0.1853 完整模型
w/o overlap 3.27 0.2304 去掉 patch 重叠,边界不一致
w/o local denoiser 3.10 0.2174 只剩全局 DiT
w/o global DiT 41.92 0.7795 只剩局部,FID 崩盘

预测步 \(k\) 的影响(hot vs cold):

\(k\) FID ↓ MMD ↓ 说明
1.0 (cold) 8.889 0.3490 纯确定性,最差
2.0 2.173 0.1849 最优
3.0 3.112 0.2112 过多随机性
4.0 4.184 0.2425 继续劣化

高分辨率扩展策略(LIDC-IDRI):从头训 PRDiT-128 得 FID 2.04 / 80 GPUh,而 PRDiT-64↑128 升采样方案 FID 2.89 / 仅 12 GPUh,快 6 倍以上,质量-算力权衡很划算。

关键发现

  • 全局 DiT 是命门:去掉它 FID 从 2.04 暴涨到 41.92,说明跨 patch 的全局高频残差才是高保真的根本;两阶段分工缺一不可。
  • hot 优于 cold\(k\) 从 1 到 2,FID 从 8.889 骤降到 2.173,但过大(\(k=3,4\))反而变差——随机性要适度。
  • 重叠 patch 必要:去重叠 FID 劣化 60%,验证滑窗重叠对边界连续性的作用。
  • 复用主干极省算力:256³ 训练 36 GPUh vs 对手 120~140 GPUh,且 3D-LDM 在该分辨率直接 OOM、WDM-3D 对随机种子敏感(少数退化样本拉高方差)。

亮点与洞察

  • 按频率分工,而非按模块堆叠:把「局部 patch 内能估的低频」和「需要全局视野的高频残差」拆给轻量 MLP 和 DiT,这个频率视角让小模型也能打——4 层就 SOTA。
  • 残差学习降低 DiT 负担:DiT 只学 \(\Delta\) 而非从头重建,既稳了训练又省了算力,是把 2D DiT 搬进稠密 3D 的关键工程解法。
  • 噪声用最近邻、信号用三线性:这个采样选择背后是「保住噪声能量统计以对齐预训练分布」的洞察,值得迁移到任何「复用低分辨率扩散模型做超分」的场景。
  • 冻结复用 + in-loop 细化:不重训主干、把高分辨率细化嵌进采样循环而非事后级联,避免破坏扩散学到的低频结构——这套「先验冻结、残差补高频」范式可推广到其他 3D/视频扩散扩展。

局限与展望

  • 评测集中在无条件生成与分布距离指标(3D FID / MMD / Wasserstein),未直接验证生成体对下游诊断/分割任务的实际增益。
  • 仅在胸部 CT(LIDC-IDRI、RAD-ChestCT)上验证,跨模态(MRI、PET)和其他解剖部位的泛化未知。
  • \(k\) 固定取 2 是经验最优,作者也承认 \(k\) 可非整数、还有进一步优化空间,但未系统搜索;hot 扩散对不同数据集是否都以 \(k=2\) 最优待考。
  • 体素级直接生成虽避开自编码器瓶颈,但对超大体(>256³)的内存可行性、以及条件生成(病灶可控合成)尚未展开。

相关工作与启发

  • vs HA-GAN(GAN 路线):HA-GAN 用分层 patch 生成器/判别器出高分辨率体,但有 mode collapse、训练不稳、显存大;PRDiT 走扩散更稳,且 FID/MMD 全面更优、训练更省。
  • vs 3D LDM(隐空间扩散):3D LDM 用 VQ-GAN 压到隐空间再扩散,但 3D 医学样本稀缺导致编码器训不好、丢解剖细节,256³ 还直接 OOM;PRDiT 不要自编码器瓶颈,直接体素级保细节。
  • vs WDM-3D(小波 U-Net 扩散):WDM-3D 在小波子带上用 3D U-Net 扩散提效,但仍是卷积、长程依赖弱,对种子敏感;PRDiT 用 DiT 拿全局视野,质量和鲁棒性都更好。
  • vs 原始 2D DiT / TCAM-Diff:DiT 在 2D 和稀疏 3D(点云)有效,但作者发现直接搬到稠密 3D 会训练不稳、扩展昂贵;PRDiT 的两阶段残差分解正是为驯服稠密 3D DiT 而生。

评分

  • 新颖性: ⭐⭐⭐⭐ 两阶段频率分解 + 残差 DiT + 冻结复用扩展,组合新颖且针对 3D 医学痛点。
  • 实验充分度: ⭐⭐⭐⭐ 两数据集、多深度、128³/256³、完整消融与算力对比;但缺下游任务与跨模态验证。
  • 写作质量: ⭐⭐⭐⭐ 动机清晰、图文(Fig 1-3)对照到位、消融说服力强。
  • 价值: ⭐⭐⭐⭐ 体素级高分辨率 3D CT 生成且训练成本骤降,对数据稀缺的医学影像有实用价值。