MPDiT: Multi-Patch Global-to-Local Transformer Architecture for Efficient Flow Matching¶
会议: CVPR 2026
arXiv: 2603.26357
代码: https://github.com/quandao10/MPDiT
领域: 扩散模型
关键词: 扩散Transformer, 流匹配, 多尺度patch, 高效架构, 图像生成
一句话总结¶
提出 MPDiT,一个多尺度 patch 的全局到局部扩散 Transformer 架构,前期用大 patch(4×4)处理全局上下文仅需 64 个 token,后期上采样到小 patch(2×2)的 256 个 token 精修局部细节,将 GFLOPs 降低高达 50%,且 XL 模型在 240 epoch 即达到 FID 2.05(cfg)。
研究背景与动机¶
-
领域现状:扩散模型/流匹配模型已成为视觉生成的主流范式,Transformer 架构(DiT/SiT)由于出色的可扩展性逐渐取代 UNet 成为主流骨干。但 DiT 的等距设计在每层都处理相同数量的 patch token,计算成本很高。
-
现有痛点:训练效率是核心瓶颈。线性注意力(如 SANA、LIT)虽然减少了计算量但性能显著下降。Mamba/SSM 在扩散模型的 token 数量级(<1K)上优势不明显。MaskDiT 在高掩码率时性能急剧恶化(75%掩码率 FID≈100)。
-
核心矛盾:MaskDiT 的失败和 DiT-XL/4 的相对成功提供了关键观察——大 patch 的少量 token 虽然缺乏局部细节,但能有效捕获全局结构信息。而 MaskDiT 的随机掩码让每个训练样本只学到部分 token 之间的关系,全局和局部信息都建模不好。
-
本文目标 如何在保持生成质量的前提下显著降低扩散 Transformer 的计算量和训练成本?
-
切入角度:受全局-局部注意力的启发,但不是在注意力层级实现(效果不好且增益可忽略),而是在整个架构层级实现——前面的层看"粗粒度"全局,后面的层看"细粒度"局部。
-
核心 idea:将等距 DiT 改为"先粗后细"的层次化结构——多数 Transformer 块在大 patch(64 token)上高效获取全局语义,少量尾部块在小 patch(256 token)上精修局部细节。
方法详解¶
整体框架¶
MPDiT 想解决的问题很直接:标准 DiT 在每一层都处理同样多的 patch token,计算成本居高不下,但其实绝大部分计算并不需要那么细的粒度。它的做法是把等距的 DiT 改成「先粗后细」的层次结构。输入是 VAE 编码后的潜变量 \(z \in \mathbb{R}^{32 \times 32 \times 4}\)(ImageNet 256×256),标准 DiT 用 patch size=2 切成 256 个 token,MPDiT 则让前 \(N-k\) 个 Transformer 块用 patch size=4 只处理 64 个 token,在低分辨率上把全局结构建模好;随后一个上采样模块把 64 个 token 展开成 256 个,交给最后 \(k\) 个块精修局部细节;输出再经反向 patchify 和 VAE 解码成图像。换句话说,多数块在「缩略图」上跑,只有尾部少量块在「全图」上收尾。条件信号(时间步、类别)则通过 FNO 时间嵌入与多 token 类别嵌入注入到所有块。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["VAE 潜变量 z (32×32×4)"] --> B["大 patch 切块<br/>patch=4 → 64 token"]
B --> C["前 N−k 块:全局建模<br/>64 token 扛全局结构(多尺度 Patch 架构)"]
C --> D["上采样模块<br/>pixel-unshuffle 64→256 + skip 引回细节"]
D --> E["尾部 k 块:局部精修<br/>256 token 补细节(多尺度 Patch 架构)"]
E --> F["反 patchify + VAE 解码 → 图像"]
G["FNO 时间嵌入 + 多 token 类别嵌入"] -->|条件信号| C
G -->|条件信号| E
关键设计¶
1. 多尺度 Patch 架构:用大 patch 扛全局、小 patch 补细节
整篇方法的出发点来自一个反差观察——MaskDiT 在 75% 掩码率下 FID≈100,几乎崩掉,而处理 token 数量相近的 DiT-XL/4 却能到 FID≈40。差别在于:随机掩码让每个样本只看到零散的局部关系,全局和局部都没学好;而大 patch 是一种结构化的降采样,token 虽少却完整覆盖整张图,能把全局语义建模得很扎实。MPDiT 顺着这个观察走:总共 \(N\) 个 Transformer 块,前 \(N-k\) 个接收 patch size=4 的 64 个 token。由于自注意力计算量正比于 token 数的平方,这部分开销只有标准 DiT 的 \(\frac{1}{16}\);大 patch 唯一的短板是缺局部细节,而这恰好可以靠尾部少量精修块补回来(实验里 \(k=4\sim6\) 就够)。因为绝大多数块只在 64 个 token 上跑,MPDiT-XL 的 GFLOPs 从 118.66 降到 59.30,直接砍掉一半。这套层次还能往上叠:面对 512² 这种更高分辨率,可以扩成三级 patch \(\{8, 4, 2\}\)。
2. 上采样模块:把 64 个粗 token 干净地展开成 256 个细 token
大 patch 跑完全局后,需要把 token 数从 64 拉回 256 才能精修,难点在于这次扩展不能破坏前面已经建好的条件关系。模块先把 image token 和 class token 分开,对 image token 做线性投影再用 pixel-unshuffle 实现 4× 空间展开(64→256),经 GELU 后与 class token 重新拼接,最后用 LayerNorm + 线性层把 class–image 的关系重新对齐。之所以要这一步对齐,是因为前面的块是在 64 个 token 的尺度上学到的 class–image 交互,token 数量一变两者就会错位,必须有个线性层来重建。另一个关键是一路 skip connection:它从原始 patch size=2 的嵌入直接加到上采样结果上,把大 patch 阶段丢掉的细粒度空间信息重新引回来。消融里这个设计的选择很敏感——线性投影(默认)FID=24.74,换成 ConvTranspose 直接掉到 29.45。
3. FNO 时间嵌入 + 多 Token 类别嵌入:把条件信号喂得更足
这两点都在解决「条件信号太单薄」的问题。FNO 时间嵌入针对传统正弦+MLP 表达力有限:它把标量时间步 \(t\) 加到一个 32 点的 1D 均匀网格上构成一段 1D 信号,经线性层升到 32 通道,再过 3 个 MixedFNO 块(SpectralConv1D 与 Conv1D 混合)学习平滑的时间结构,最后全局平均池化 + 线性投影。这个设计借鉴 Neural Operator,与流匹配本身是 ODE/SDE 连续动态的直觉吻合,实测带来约 4 点 FID 提升(3 个 MixedFNO 块最优,2 个略差、4 个反而不稳)。多 Token 类别嵌入则针对单个 class token 信息过度压缩的问题:每个类别改用 \(m=16\) 个可学习 token 表示,作为前缀拼到 image token 前,替代原来的 AdaIN 调制——更分布式的语义让收敛快了约 7 点 FID,而 \(m=32\) 几乎不再有额外收益,说明 16 个 token 已基本编码完类别语义。
损失函数 / 训练策略¶
- 使用标准流匹配目标:\(L_{FM} = \|f_\theta(z_t, t, c) - (n - z)\|_2^2\)
- AdaIN 参数跨所有 Transformer 块共享(参数从 130M 降到约 90M,FID 仅上升 0.4)
- 训练设配:8×A100-40GB,固定学习率 \(2 \times 10^{-4}\),batch size 1024,EMA 0.9999
- 采样使用 250 步 Euler 求解器
实验关键数据¶
主实验¶
| 模型 | Epochs | GFLOPs | FID↓ (non-cfg) | FID↓ (cfg) | IS↑ (cfg) |
|---|---|---|---|---|---|
| DiT-XL/2 | 1400 | 118.66 | 9.62 | 2.27 | 278.24 |
| SiT-XL/2 | 1400 | 118.66 | 9.35 | 2.15 | 258.09 |
| DiG-XL/2 | 240 | 89.40 | 8.60 | 2.07 | 278.95 |
| DiCo-XL | 80 | 87.30 | 11.67 | - | - |
| MPDiT-XL | 240 | 59.30 | 7.36 | 2.05 | 278.73 |
消融实验¶
| 组件 | Params(M) | GFLOPs | FID↓ |
|---|---|---|---|
| DiT-B/2 baseline | 130.0 | 23.0 | 34.84 |
| + Shared AdaIN | 90.3 | 22.9 | 35.31 |
| + Multi-token Class (m=16) | 101.9 | 24.3 | 28.56 |
| + FNO Time Embedding | 101.2 | 24.3 | 24.52 |
| + MPDiT (k=6) | 104.8 | 16.6 | 24.74 |
| k 值 (XL) | GFLOPs | FID↓ |
|---|---|---|
| k=4 | 53.2 | 11.11 |
| k=6 (默认) | 59.3 | 9.92 |
| k=8 | 65.4 | 9.73 |
| Class Token 数 m | FID↓ |
|---|---|
| m=1 | 32.31 |
| m=4 | 30.91 |
| m=8 | 28.12 |
| m=16 (默认) | 24.74 |
| m=32 | 24.47 |
关键发现¶
- k=6 是最优平衡点:仅 6 个精修块即可在效率和质量间取得最佳折中。k=4 太少导致 FID 明显上升(XL: 11.11 vs 9.92),k=8 改善极小但 GFLOPs 增加 10%
- 多 token 类别嵌入收益巨大:从 m=1 到 m=16,FID 从 32.31 降到 24.74(降 7.5 点!),且 m=32 几乎不再有提升,说明 16 个 token 已充分编码类别语义
- FNO 时间嵌入稳定提升 4 点 FID:3 个 MixedFNO 块是最优(2 个略差,4 个反而不稳定)
- 上采样模块设计关键:Linear+Linear(默认)FID=24.74 vs ConvTranspose=29.45,选对上采样方式影响很大
- 训练吞吐量翻倍:MPDiT-XL 的采样速度是 DiT-XL/2 的 2 倍以上
亮点与洞察¶
- "先粗后细"的架构设计简洁而有效:与 MaskDiT 的失败对比特别有说服力——有结构的降采样(大patch)远优于随机的降采样(掩码)。这个洞察可以迁移到任何需要减少 token 数量的 Transformer 架构
- FNO 时间嵌入是一个有趣的尝试:用 Neural Operator 的思路来建模扩散过程中的连续时间动态,既新颖又有直觉上的合理性(流匹配本身就是 ODE/SDE 问题)
- Shared AdaIN 的发现有实用价值:直接共享时间/类别调制层可以减少 30% 参数、FID 仅上升 0.4,这在资源受限场景下非常实用
局限与展望¶
- 仅在 ImageNet 256×256 上验证,缺乏文本到图像或更高分辨率的实验
- 三级 patch 层次(用于 512²)只是提出了思路但没有实验验证
- 上采样模块的设计比较简单(线性投影),更复杂的设计可能进一步提升效果
- FNO 时间嵌入中维度 128 不稳定的原因未深入分析
- 与 REPA 等表示对齐方法的结合未探索,可能带来进一步加速
相关工作与启发¶
- vs DiT/SiT:标准等距设计,每层 256 tokens。MPDiT 通过分层 patch 将大部分计算压缩到 64 tokens,GFLOPs 减半但 FID 更优
- vs MaskDiT:同样是减少处理 token 数量的思路,但 MaskDiT 的随机掩码在高比例时严重失效(75% mask → FID≈100),而 MPDiT 的结构化降采样效果好得多
- vs DiCo/DiC:卷积重构的扩散模型,GFLOPs 相近但 MPDiT 在相同训练 epoch 下 FID 更优,说明 Transformer 在全局建模上仍有优势
- vs SANA/LIT:线性注意力方案需要从预训练全注意力模型初始化,MPDiT 可以从头训练
评分¶
- 新颖性: ⭐⭐⭐⭐ 多尺度 patch 的思路并非全新(灵感来自全局-局部注意力),但在扩散 Transformer 中的应用和效果验证有价值
- 实验充分度: ⭐⭐⭐⭐ ImageNet 上的消融非常详尽,但缺乏其他领域/分辨率的验证
- 写作质量: ⭐⭐⭐⭐ 动机推导清晰,与 MaskDiT 的对比分析有说服力
- 价值: ⭐⭐⭐⭐ 50% GFLOPs 减少且质量不降,对扩散模型训练效率有实际推动