VEDA: Scalable Video Diffusion via Distilled Sparse Attention¶
会议: ICML 2026
arXiv: 2605.30325
代码: 待确认
领域: 视频生成 / 扩散模型 / 模型加速
关键词: 稀疏注意力, 视频扩散 Transformer, 蒸馏学习, 硬件优化
一句话总结¶
VEDA 把视频 DiT 的稀疏注意力问题重新表述为"对全注意力结构的显式蒸馏"——通过统计感知的瓦片评分 + 头感知分组搜索 + 硬件高效内核,在 90-95% 极端稀疏度下保持生成质量,给 Waver-12B 720P 10 秒视频带来 5.1× 端到端加速、10.5× 注意力加速。
研究背景与动机¶
领域现状:视频扩散 Transformer(DiT)已是高保真视频合成主流,但自注意力 \(O(N^2)\) 计算瓶颈在高分辨率长时序生成时极严重。
现有痛点:现有稀疏注意力方法在高度剪枝(≥ 90%)下有两个根本问题: - 静态方法(SVG、STA)依赖预定义时空掩膜,缺对头部特异性注意力几何的自适应性。 - 动态方法(VSA、VMOBA)通过隐式学习,缺显式监督;使用均值池化等粗糙统计量会忽略关键的信号峰值。
核心矛盾:高度稀疏剪枝导致"水纹畸变 / 空间翘曲 / 时间闪烁"等结构性伪影。但实验发现这不是稀疏比例本身造成的,而是稀疏掩膜与全注意力的瓦片级结构对齐度不足导致。
本文目标:在保持生成质量前提下实现视频 DiT 的激进稀疏化与实际加速。
切入角度:关键观察——"神谕级"掩膜(从全注意力 Top-k 得到)即便在 90% 稀疏度下也能保持高质量。这启发显式监督瓦片选择目标,而非依赖扩散目标的隐式学习。
核心 idea:把稀疏瓦片选择重新表述为对全注意力结构的显式蒸馏,加上头感知分组应对头部异质性,结合硬件高效内核实现真实加速。
方法详解¶
整体框架¶
VEDA 三个核心模块: - 蒸馏瓦片评分:用轻量级估计器学重建全注意力的瓦片级评分,把 token 级密集注意力映射为稀疏瓦片掩膜。 - 头感知分组搜索:为每个注意力头搜索最优瓦片分组 \((p_t, p_h, p_w)\)。 - 硬件高效内核:通过 ThunderKittens DSL 和 NVIDIA Hopper TMA 实现瓦片跳过注意力内核,达 FlashAttention-3 80% 运算效率。
关键设计¶
-
统计感知的瓦片评分估计器(TripPool):
- 功能:通过压缩的瓦片表示重建全注意力的瓦片级评分,用以生成稀疏掩膜。
- 核心思路:对每个查询 / 键瓦片构造 TripPool 描述子——均值 / 最大值 / 最小值的连接 \(\text{TripPool}[\cdot] = \text{Avg}[\cdot] \oplus \text{Max}[\cdot] \oplus \text{Min}[\cdot]\)。再通过头特异性 MLP 投影 \(\phi_q, \phi_k\) 映射至共享潜在空间,计算预测评分 \(S_{ij}^{\text{pred}} = \frac{\phi_q(\text{TripPool}[\tilde{Q}_i]) \cdot \phi_k(\text{TripPool}[\tilde{K}_j])^\top}{\sqrt{d'}}\)。最后用 KL 散度损失 \(\mathcal{L}_{\text{distill}} = \mathcal{D}_{KL}(A^{\text{tgt}} \| A^{\text{pred}})\) 对齐预测与全注意力。
- 设计动机:相比平均池化忽略信号峰值,TripPool 的最大 / 最小统计保留关键依赖;显式蒸馏目标避免隐式学习的漂移;关键的停梯度操作解耦掩膜学习与特征学习,防止扰动预训练生成流形。
-
头感知分组搜索:
- 功能:为每一层每个注意力头离线搜索最优时空瓦片分组配置。
- 核心思路:把瓦片配置限制在硬件瓦片大小 \(B\) 的因子分解 \(\Omega = \{(p_t, p_h, p_w) \in \mathbb{N}^3 \mid p_t p_h p_w = B\}\)。对每个候选 \(\pi\),在校准集上最小化全注意力输出的稀疏近似误差 \(\pi^*_{l, h} = \arg\min_{\pi \in \Omega} \mathbb{E}_{x \sim \mathcal{D}_{\text{cal}}} \|O^{\text{fu}}_{l, h}(x) - O^{\text{sp}}_{l, h}(x; \pi)\|_F^2\)。
- 设计动机:注意力头在空间 / 时间依赖上存在显著异质性;统一分组在高稀疏度下导致瓦片回忆率下降;针对性配置能保留不同头的关键信息。
-
瓦片跳过硬件内核 + 两阶段训练:
- 功能:把稀疏掩膜高效执行在 GPU 上;通过稳定的两阶段训练避免收敛问题。
- 核心思路:第一阶段冻结骨干只训投影器 1000 步对齐稀疏预测;第二阶段解冻所有参数在目标稀疏度下微调。利用异步 TMA + Warp 特化:生产者 warp 从全局内存非连续抓取选定的键 / 值瓦片到共享内存,消费者 warp 同时执行张量核心运算,达约 80% FlashAttention-3 效率。
- 设计动机:两阶段解耦避免梯度反传破坏预训练流形;硬件优化确保算法稀疏性转化为实际端到端加速而非内核开销。
训练目标¶
\(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{diff}} + \lambda \mathcal{L}_{\text{distill}}\),KL 散度行级蒸馏 + 标准扩散去噪。停梯度确保骨干特征不接收掩膜估计器的梯度反传——实验证明允许梯度反传会导致显著生成质量下降。
实验关键数据¶
主实验(Waver-1B 与 Wan2.1-1.3B 上对比全注意力与 VSA)¶
| 模型 | 方法 | 稀疏度 | 主体一致性 | 背景一致性 | 运动平滑 | 美学质量 | 端到端时间 |
|---|---|---|---|---|---|---|---|
| Waver-1B | 全注意力 | 0% | 0.938 | 0.955 | 0.979 | 0.693 | 69.3s |
| Waver-1B | VSA | 87.5% | 0.933 | 0.949 | 0.978 | 0.692 | 34.3s |
| Waver-1B | VEDA | 90% | 0.940 | 0.954 | 0.980 | 0.699 | 31.9s |
| Waver-1B | VEDA | 95% | 0.934 | 0.951 | 0.978 | 0.698 | 30.6s |
| Wan2.1-1.3B | 全注意力 | 0% | 0.940 | 0.969 | 0.977 | 0.670 | 58.5s |
| Wan2.1-1.3B | VEDA | 90% | 0.887 | 0.941 | 0.972 | 0.663 | 37.6s |
消融实验¶
| 组件 | 配置 | 指标 ↓ | 说明 |
|---|---|---|---|
| 瓦片统计 | 平均池化 | 0.965 | 忽略峰值 |
| 瓦片统计 | 最大 / 最小 | 0.982 | 遗漏中等重要性 |
| 瓦片统计 | TripPool | 0.912 | 保留关键依赖 |
| 分组策略 | 静态 [8, 8, 2] | +3.2% 运动质量损失 | 偏空间 |
| 分组策略 | 静态 [4, 4, 8] | 基准 | 均衡配置 |
| 分组策略 | 头感知动态 | +7.2% 运动 / +9.6% 总体 | 适应头部异质性 |
关键发现¶
- 掩膜精度主导性能:90% 固定稀疏度下"神谕"掩膜的生成质量远优于平均池化掩膜——问题根源不在稀疏比例而在对齐质量。
- 头部异质性显著:不同层不同头的空间 / 时间依赖模式差异大,统一分组在高稀疏度下不行。
- 可扩展性:Waver-12B 720P 10 秒视频生成实现 5.1× 端到端加速 + 10.5× 注意力加速,注意力开销从 92% 降到 50%;序列越长 VEDA 加速越大。
亮点与洞察¶
- 实验性的根本观察:"神谕掩膜"实验精准定位真正瓶颈是结构对齐度而非稀疏比例,推翻既往假设并奠定方法设计基础。
- 显式监督的范式转变:相比让扩散目标隐式形塑稀疏结构,显式蒸馏直接监督瓦片评分,避免隐式学习的漂移;停梯度操作的设计巧妙保护预训练生成流形。
- 头感知分组的精细化设计:识别头部异质性并针对性搜索时空分组配置,比同期 VSA 等静态 / 全局动态方法更细粒度,可迁移到其他多头 Transformer 加速任务。
- 硬件-算法协设计:从 TMA 异步传输到 Warp 特化的完整内核实现,把 FLOPs 理论减少转化为真实端到端加速,工程闭环完整。
局限与展望¶
- 两阶段训练虽稳定但需手工设计学习率 / 步数,通用性待提升。
- 95%+ 稀疏度下仍需更多 kernel 融合以提升 MFU。
- 头感知分组依赖离线校准集,不同数据分布下可能需重新搜索。
- TripPool 对异常分布的鲁棒性未充分讨论(最大 / 最小值易被离群值影响)。
相关工作与启发¶
- vs SVG / STA(静态稀疏):依赖预定义模式缺自适应性;本文通过显式蒸馏实现内容与头部敏感的动态选择。
- vs VSA / VMOBA(动态稀疏):依赖隐式扩散目标 + 粗糙池化;本文显式蒸馏 + 精细统计量更准确捕捉全注意力结构。
- vs 其他加速(缓存复用 PAB / TeaCache、蒸馏 CausVid):VEDA 与它们正交,可叠加使用。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 在视频 DiT 稀疏化上首次系统引入显式监督 + 头感知分组;"掩膜精度主导" 的实验性发现改变了对稀疏注意力瓶颈的理解。
- 实验充分度: ⭐⭐⭐⭐⭐ 多模型规模(1B / 12B)、多分辨率(480P / 720P)、长序列(34K-245K)、人类评估 + VBench、消融细致。
- 写作质量: ⭐⭐⭐⭐⭐ 逻辑清晰层层递进,实验驱动的发现说服力强,方法各模块独立贡献明确。
- 价值: ⭐⭐⭐⭐⭐ 5.1× 加速对工业应用意义重大;稀疏注意力设计思路对 LLM 加速也有参考价值。