Efficient Diffusion Transformer with Step-wise Dynamic Attention Mediators¶
会议: ECCV 2024
arXiv: 2408.05710
代码: 有 (https://github.com/LeapLabTHU/Attention-Mediators)
领域: 图像复原
关键词: 扩散Transformer, 注意力中介者, 动态网络, 线性注意力, 去噪冗余
一句话总结¶
发现 Diffusion Transformer 中 query-key 交互存在显著冗余(尤其在去噪早期),提出 Attention Mediator 机制将注意力复杂度降至线性,并设计逐步动态调整策略,在 SiT-XL/2 上实现 SOTA FID 2.01,同时减少计算量。
研究背景与动机¶
Diffusion Transformer 的效率困境¶
Diffusion Transformer(DiT)因其简洁性、有效性和可扩展性,正在取代 U-Net 成为扩散模型的主流骨架,驱动了 Stable Diffusion V3、Pixart-α/Σ/δ、华为 DiT、Sora 等应用。然而,DiT 的广泛批评在于全局注意力机制的高计算消耗——self-attention 的 \(O(N^2C)\) 复杂度成为推理瓶颈,严重阻碍了高分辨率图像和长视频的实际部署。
虽然视觉识别领域已有多种注意力加速方法(窗口注意力、线性注意力等),但扩散生成领域的注意力效率优化几乎是空白。
关键观察:去噪过程中的注意力冗余¶
本文通过定量分析发现了 DiT 中的两个关键现象:
观察1:大量 query-key 冗余普遍存在。在所有 self-attention 层中,不同 query 对 key 的注意力分布高度相似。例如 DiT-S/2 的第10层,在初始几步中所有 query 的内部距离几乎为零——它们完全同质化。
观察2:冗余随去噪进程递减。在去噪初期(纯噪声阶段),注意力冗余最为严重;随着去噪推进,query 变得越来越多样化。这意味着早期的全量一对一注意力交互是不必要的。
量化冗余的度量方法¶
本文使用 Jensen-Shannon Divergence (JSD) 设计冗余度量。将注意力矩阵 \(\mathbf{A}^{(m)}\) 的每一行视为一个概率分布(query 对所有 key 的权重分布),计算第 \(l\) 层的冗余分数:
\(S_l\) 低 → 注意力分布高度相似 → 冗余严重。实验在 DiT-S/2 和 SiT-S/2 上测量了所有层和所有时间步的 \(S_l\),验证了上述两个观察。
方法详解¶
整体框架¶
在标准 self-attention 层中引入一组额外的中介者 tokens(Attention Mediators),数量远少于原始 tokens(如 <10%),分别与 query 和 key 交互。同时,根据去噪时间步的冗余程度动态调整中介者数量——早期少、后期多。
关键设计¶
1. Attention Mediators 机制¶
功能:用一组少量中介者 tokens \(\mathbf{t}^{(m)} \in \mathbb{R}^{n \times d}\)(\(n \ll N\))压缩 query-key 之间的冗余交互。
核心思路:将标准注意力的一步 Q-K-V 交互拆分为两步:
Step 1:中介者聚合 key 信息( \(n \times N\) 交互): $\(\mathbf{v}_{\text{med}}^{(m)} = \text{Softmax}\left(\frac{\mathbf{t}^{(m)} \mathbf{k}^{(m)\top}}{\sqrt{d}}\right) \mathbf{v}^{(m)}\)$
Step 2:query 从中介者提取信息(\(N \times n\) 交互): $\(\mathbf{h}^{(m)} = \text{Softmax}\left(\frac{\mathbf{q}^{(m)} \mathbf{t}^{(m)\top}}{\sqrt{d}}\right) \mathbf{v}_{\text{med}}^{(m)}\)$
中介者 tokens 的生成方式:对 query tokens 进行自适应池化——先 reshape 为 latent 图像形状 \(\mathbb{R}^{H \times W \times d}\),在空间维度池化到 \(\mathbb{R}^{h \times w \times d}\),再 flatten 得到 \(n = h \times w\) 个中介者。
设计动机:(1) 中介者作为信息"瓶颈",压缩了冗余的一对一 Q-K 交互;(2) 由于 Q 和 K 被中介者解耦,可以交换计算顺序——先计算 \(\mathbf{A}_{\text{tk}}^{(m)} \cdot \mathbf{v}^{(m)}\)(\(n \times N \cdot N \times d\)),再与 Q 交互,避免了 \(N \times N\) 的矩阵;(3) 补充 DWConv 弥补线性注意力的特征多样性损失。
2. 复杂度分析¶
标准 self-attention 的复杂度为 \(O(N^2 C)\)。中介者注意力的每一步均为 \(O(Nnd)\),总复杂度为 \(O(nNC)\)。由于 \(n \ll N\),计算量从二次降至线性。
| 操作 | 标准注意力 | 中介者注意力 |
|---|---|---|
| 复杂度 | \(O(N^2C)\) | \(O(nNC)\) |
| 256×256图像(\(N=256\)) | \(\propto 65536\) | \(\propto 256n\)(\(n=64\)时约1/4) |
| 分辨率增长 | 二次增长 | 线性增长 |
高分辨率优势:图像分辨率越高,线性复杂度的优势越突出。
3. 时间步动态中介者调整¶
功能:根据去噪过程中冗余程度的变化,动态增加中介者 tokens 数量。
核心思路:利用相邻去噪步之间的 latent 距离 \(\Delta_t = \|x_t - x_{t+1}\|\) 来量化冗余变化程度。当距离降至初始距离的某个阈值以下时,切换到更多中介者:
每个样本独立调度:阈值切换是样本自适应的,因为不同图像的去噪过程不同,latent 变化速度也不同。
设计动机:(1) 早期冗余高,少量中介者即可充分表达——大幅节省计算;(2) 后期细节丰富,需更多中介者保留多样性;(3) L1 距离比 L2 效果更好(消融验证)。
损失函数 / 训练策略¶
- 训练使用 ImageNet-1k,类条件扩散模型
- AdamW 优化器,无 weight decay,学习率 \(1 \times 10^{-4}\)
- 全局 batch size 256,训练 400K 迭代
- EMA decay 0.9999
- 仅替换前 4 层 self-attention 为中介者注意力(XL 模型)
- 高分辨率(512/1024)通过从 256 模型 finetune 获得
实验关键数据¶
主实验:ImageNet 256×256 类条件生成¶
| 模型 | FID↓ | sFID↓ | IS↑ | Precision↑ | Recall↑ |
|---|---|---|---|---|---|
| ADM | 10.94 | 6.02 | 100.98 | 0.69 | 0.63 |
| StyleGAN-XL | 2.30 | 4.02 | 265.12 | 0.78 | 0.53 |
| VDM++ | 2.12 | - | 267.7 | - | - |
| DiT-XL (cfg=1.5) | 2.27 | 4.60 | 278.24 | 0.83 | 0.57 |
| SiT-XL (cfg=1.5) | 2.06 | 4.50 | 270.27 | 0.82 | 0.59 |
| Ours (cfg=1.5) | 2.01 | 4.49 | 271.04 | 0.82 | 0.60 |
在 SiT-XL/2 基础上,方法取得 FID 2.01 的 SOTA 结果,同时减少了计算量。
消融实验:静态中介者数量对比(SiT-S/2, 256×256)¶
| 配置 | FLOPs(G) | FID↓ | sFID↓ | IS↑ | Precision↑ | Recall↑ |
|---|---|---|---|---|---|---|
| SiT-S/2 baseline | 6.06 | 58.61 | 9.25 | 24.31 | 0.41 | 0.59 |
| + Ours (n=4) | 5.49 (-9.4%) | 57.67 | 10.01 | 26.66 | 0.42 | 0.56 |
| + Ours (n=16) | 5.55 (-8.4%) | 54.55 | 9.28 | 26.55 | 0.43 | 0.59 |
| + Ours (n=64) | 5.78 (-4.6%) | 53.57 | 9.01 | 27.26 | 0.43 | 0.61 |
即使使用最少的 4 个中介者,FID 也优于 baseline;n=64 时 FID 降低 5.04,FLOPs 仍减少 4.6%。
消融实验:对比简单 Q-K 维度压缩¶
| 方法 | FLOPs(G) | FID↓ | Precision↑ | Recall↑ |
|---|---|---|---|---|
| SiT-S/2 baseline | 6.06 | 58.61 | 0.41 | 0.59 |
| Q-K 维度压缩 r=0.875 | 5.91 | 58.98 (+) | 0.40 | 0.60 |
| Q-K 维度压缩 r=0.750 | 5.76 | 59.18 (+) | 0.39 | 0.59 |
| Q-K 维度压缩 r=0.500 | 5.46 | 60.02 (+) | 0.40 | 0.57 |
| Ours (n=64) | 5.78 | 53.57 | 0.43 | 0.61 |
直接降低 Q-K 隐藏维度虽然节省计算,但 FID 持续恶化;而本文方法在节省计算的同时显著提升质量。
关键发现¶
- 中介者不仅降低计算量,还提升生成质量:这是因为压缩冗余交互等价于一种隐式正则化,减少了 attention 输出的同质化
- 高分辨率加速更显著:SiT-B/2 在 512² 分辨率加速 15.7%,在 1024² 分辨率加速 45.4%,线性复杂度的优势随分辨率增长而放大
- 动态策略优于静态:通过时间步自适应调整中介者数量,在相同 FLOPs 预算下始终取得更好的 FID
- L1 距离优于 L2:在 latent 变化量度量中,L1 距离是更好的阈值判据
亮点与洞察¶
- 从冗余分析到方法设计的self-contained逻辑:先用JSD量化冗余→发现时间步变化规律→中介者机制解决冗余→动态调整适应变化规律,整条技术路线一气呵成
- "质量提升+效率提升"的罕见双赢:通常加速方法会牺牲质量,但这里压缩冗余反而改善了特征多样性
- 中介者的语义理解:中介者tokens不仅是计算优化手段,还可以理解为对 latent 信息的语义压缩——用少量代表性表示引导生成过程
- 样本自适应的无训练调度:动态阈值基于 latent 变化量,无需额外训练网络来决定何时切换
局限与展望¶
- 仅替换部分层:XL 模型仅替换前4层为中介者注意力,未全面探索最优替换策略
- 阈值搜索:动态调整的阈值 \(\rho_i\) 通过 grid search 获取,搜索空间有限
- 仅验证了类条件生成:未在 text-to-image(如 Stable Diffusion)等更实际的场景中验证
- 训练成本未大幅降低:方法主要在推理阶段加速,训练阶段的效率提升有限
- 未与其他加速方法组合:如 distillation、step reduction 等方法可能互补
相关工作与启发¶
- Agent Attention [ECCV 2024]:在视觉识别任务中也使用额外tokens作为Q-K桥梁,本文将此思路延伸到扩散生成
- SiT [ICML 2024]:本文的主要基座模型,引入插值框架从离散到连续时间
- DiT [ICCV 2023]:证明了ViT在扩散模型中的可扩展性,本文在其架构上优化注意力效率
评分¶
- 新颖性: ⭐⭐⭐⭐ — 冗余分析驱动的中介者设计+时间步动态策略,思路清晰独特
- 实验充分度: ⭐⭐⭐⭐⭐ — 多尺度模型、多分辨率、详细消融、与多种方法对比、可视化验证
- 写作质量: ⭐⭐⭐⭐ — 从观察到方法的叙事逻辑流畅,复杂度分析清晰
- 价值: ⭐⭐⭐⭐ — SOTA FID + 计算减少,对DiT推理优化有直接实践价值