SPRINT: Sparse-Dense Residual Fusion for Efficient Diffusion Transformers¶
会议: ICLR2026
OpenReview: https://openreview.net/forum?id=aTVollXaaI
代码: 待确认
领域: 扩散模型 / 图像生成 / 高效训练
关键词: 稀疏训练, Diffusion Transformer, token dropping, 残差融合, 高效采样
一句话总结¶
SPRINT 把扩散 Transformer 的浅层密集局部特征和深层稀疏全局特征用残差方式融合起来,使 DiT 能在 75% token dropping 下高效预训练,并进一步用 Path-Drop Guidance 降低采样成本。
研究背景与动机¶
领域现状:Diffusion Transformer(DiT)已经成为高质量图像生成的重要骨干,从 DiT、SiT 到更大的 rectified-flow transformer,都依赖长序列 patch token 上的自注意力来建模图像结构。问题是自注意力对 token 数量近似二次增长,图像分辨率越高、patch 越多,预训练成本和显存压力就越难承受。
现有痛点:一个直接想法是在训练时丢掉部分 token,让中间层只处理更短的序列。可 naive token dropping 会破坏空间覆盖,尤其在扩散模型的噪声输入上,模型既要恢复局部高频细节,又要理解全局语义结构;如果深层看到的 token 太少、浅层信息又没有被可靠保留,最终推理时面对完整 token 序列就容易出现表示退化和 train-inference gap。
核心矛盾:DiT 的所有层都以同样方式处理全量 token,但不同深度的层承担的职责并不一样。浅层更接近 noisy patch,本来就适合保留局部纹理、噪声和边缘等细粒度信息;深层则更适合建模对象形状、类别和跨区域关系。让深层继续对每个局部 token 做昂贵计算,既浪费 FLOPs,也没有显式鼓励深层专注于全局语义。
本文目标:作者想解决三个子问题:第一,在高 drop ratio(典型为 75%)下稳定训练 DiT;第二,让稀疏训练学到的表示能迁移到完整 token 推理;第三,在训练之外,把这种双路径结构也转化成更便宜的 guidance 采样。
切入角度:论文的观察是“浅层密集信息”和“深层稀疏语义”不是互相替代,而是互补。浅层如果保留全量 token,就能为最终 velocity prediction 提供局部证据;深层如果只看结构化采样后的少量 token,反而被迫学习更全局、更 noise-invariant 的上下文。
核心 idea:用“浅层全 token 编码 + 中间层稀疏 token 处理 + 残差融合恢复全 token 解码”替代传统 DiT 的全层密集计算,从架构上把局部细节和全局语义分工开。
方法详解¶
整体框架¶
SPRINT 从一个标准 DiT/SiT 出发,把 transformer block 切成三段:前两层作为密集编码器 \(f_\theta\),中间若干层作为稀疏深层路径 \(g_\theta\),最后两层作为密集解码器 \(h_\theta\)。训练时,输入 noisy latent patch tokens 先经过 \(f_\theta\) 得到全量浅层特征;随后按结构化采样保留约 25% token 给 \(g_\theta\),再把深层输出 pad 回原长度,与浅层特征融合后交给 \(h_\theta\) 预测所有 token 的 velocity。
这套流程不是简单“少算一点 token”,而是把被丢掉 token 的局部信息通过 dense shallow path 保留下来,让 expensive middle blocks 只负责更稀疏、更语义化的上下文建模。预训练结束后,模型再用很短的 full-token fine-tuning 让中间层适应完整输入;推理时还可以利用浅层路径作为天然的弱 unconditional model,构成 Path-Drop Guidance(PDG)。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["noisy latent tokens"] --> B["密集浅层路径<br/>保留局部证据"]
B --> C["结构化 token 子采样<br/>75% dropping"]
C --> D["稀疏深层路径<br/>建模全局语义"]
D --> E["稀疏-密集残差融合<br/>恢复全 token 表示"]
B --> E
E --> F["短程全 token 微调<br/>闭合训练推理差距"]
F --> G["Path-Drop Guidance<br/>低成本采样"]
关键设计¶
1. 密集浅层路径:把局部噪声证据从昂贵深层计算里解耦出来
SPRINT 首先保留一个处理全量 token 的浅层编码器 \(f_\theta\)。给定 noisy tokens \(x_t \in \mathbb{R}^{B \times N \times C}\),模型先计算 \(f_t=f_\theta(x_t)\),这里的 \(f_t\) 仍然覆盖所有空间位置。这样做的关键不是“多接一条 skip connection”这么简单,而是承认扩散 velocity prediction 对局部噪声、纹理和边缘信息非常敏感;这些信息如果在 token dropping 时直接丢掉,最终 decoder 很难凭少量语义 token 补回来。
论文的可视化也支持这个分工:只用 dense shallow path 时,样本的局部纹理相对完整,但全局结构不稳定。这说明浅层路径确实像一条局部证据高速通道,负责把每个位置的低层信息绕过高成本的中间 transformer blocks,送到最终融合处。它解决的是高比例 dropping 下“被丢 token 没有局部依据可恢复”的问题。
2. 稀疏深层路径:让中间 blocks 用少量 token 学全局语义
在 \(f_t\) 之后,SPRINT 对 token 做 dropping 得到 \(f_t^{drop}\),再送入中间 blocks:\(g_t^{drop}=g_\theta(f_t^{drop})\)。默认设置使用 75% drop ratio,也就是只保留约四分之一 token 给计算最贵的深层 transformer。由于 DiT 的注意力成本随序列长度快速增长,这一步直接减少了中间层 FLOPs 和显存压力。
更重要的是,深层路径不再被迫在所有局部 patch 上重复建模细节,而是从稀疏覆盖的 token 中学习对象轮廓、类别语义和长程关系。论文中只用 sparse deep path 的可视化会得到更清晰的全局形状但局部纹理破损,正好说明这条路径学到的是偏全局的结构信息。SPRINT 的关键判断是:深层少看一些 token 不一定更弱,只要浅层局部信息还在,深层稀疏化反而能促成更好的语义表示。
3. 稀疏-密集残差融合:用 mask token 对齐序列并恢复完整预测
因为 \(g_\theta\) 只处理保留 token,输出 \(g_t^{drop}\) 的序列长度小于原始 \(N\)。SPRINT 将被丢位置用固定 [MASK] token pad 回来,得到 \(g_t^{pad}\in\mathbb{R}^{B\times N\times C}\),再把它与密集浅层特征 \(f_t\) 在 channel 维拼接并投影回原维度,最后由解码器 \(h_\theta\) 预测完整 token 的 velocity。
这个融合机制的作用是把“每个位置都有的局部证据”和“少数位置带来的全局语义”重新放到同一个全长序列里。相比只对保留 token 预测或额外训练复杂 decoder,SPRINT 的改动很小:标准 DiT block 基本不变,只增加一个融合投影层,参数量约增加 0.3%。消融结果也很直接:只保留 dense path 或 sparse path 都会让 FID 大幅恶化,二者同时存在时才出现 27.5 的 FID,说明融合不是装饰,而是高比例稀疏训练能工作的核心。
4. 结构化分组采样:避免随机 dropping 留下空间空洞
单纯 uniform random dropping 会出现局部区域完全没 token 被保留的情况,这对图像生成很危险,因为某些区域的语义上下文会在深层路径里彻底缺席。SPRINT 因此按图像 token 的原始二维拓扑做 group-wise subsampling:把 patch grid 划成 \(n\times n\) 小组,每组随机保留 \(k\) 个 token,drop ratio 为 \(r=1-k/n^2\)。主设置取 \(n=2,k=1\),正好得到 75% dropping。
这个策略在局部覆盖和全局随机性之间做了非常具体的约束:每个 \(2\times2\) 局部区域至少有一个 token 进入深层,避免大块空洞;每次迭代又随机选择组内 token,防止模型记住固定采样模式。论文的消融显示,同样 75% drop ratio 下,结构化采样 FID 为 27.5,而随机采样是 30.1,说明 token dropping 的位置选择和 drop 数量一样重要。
5. Path-Drop Guidance:用浅层路径替代 CFG 的完整 unconditional pass
标准 classifier-free guidance(CFG)每个采样步需要 conditional 和 unconditional 两次完整前向,几乎把推理 FLOPs 翻倍。SPRINT 的双路径结构提供了一个天然的弱模型:unconditional 分支可以绕过 \(g_\theta\),只用 \(f_\theta\) 的 dense shallow features 与 [MASK] 填充后的深层占位进行融合。因此 PDG 中 conditional velocity 仍走完整路径,而 unconditional velocity 写作:\(v(x_t,\emptyset)=h_\theta(\mathrm{Fusion}(M,f_\theta(x_t,\emptyset)),\emptyset)\)。
这其实把“坏一点但便宜的 unconditional model”内生到同一个网络里,思想上接近 Auto Guidance,但无需另训弱网络。为让模型适应这种采样方式,训练和微调阶段还会以 10% 概率随机 drop sparse-deep path,用 mask token 替代。最终 PDG 在 ImageNet 256 上把 inference TFLOPs 从约 0.477 降到 0.274,同时 FDD/FID 还从 SPRINT+CFG 的 75.4/1.96 进一步变为 58.4/1.62。
一个完整示例¶
假设一张 256×256 图像被 VAE 编成 latent,再按 patch size 2 切成一串 noisy latent tokens。传统 SiT 会让所有 1024 个位置在 28 层 transformer 中一路密集计算;SPRINT 则先让前 2 层看完全部 1024 个 token,得到每个位置的浅层局部证据。
进入中间 24 层前,结构化分组采样把每个 \(2\times2\) 小组保留 1 个 token,于是深层只处理约 256 个 token。中间层输出后,模型把另外 768 个缺失位置填成 [MASK],恢复成 1024 长度的序列,并与浅层 1024 个 dense features 拼接融合。最后 2 层 decoder 看到的是“每个位置的局部证据 + 稀疏位置传来的全局语义”,所以仍然可以对全部 token 预测 velocity。
到推理采样时,如果要做 guidance,conditional pass 走完整 SPRINT;unconditional pass 不再跑中间 24 层,而是把 deep path 全部替换成 [MASK]。这就是为什么 PDG 能接近砍掉一半 guidance 成本:昂贵的中间 blocks 在每个采样步只为 conditional 分支执行一次。
损失函数 / 训练策略¶
SPRINT 沿用标准 flow matching / diffusion velocity loss。给定真实样本 \(x_0\) 和高斯噪声 \(x_1\),论文采用线性 schedule:\(x_t=(1-t)x_0+t x_1\),目标是让网络预测 velocity \(v\),优化 \(\mathbb{E}\|v(x_t,t)-v_\theta(x_t,t)\|^2\)。也就是说,SPRINT 的主要训练信号没有换成额外重建任务,方法收益主要来自稀疏-密集结构和采样策略。
训练分两阶段。第一阶段是长程 sparse pre-training,默认 75% token dropping,batch size 256、learning rate \(10^{-4}\)、EMA 0.9999。第二阶段是短程 full-token fine-tuning,中间 blocks 改为处理完整 token 序列,常用 100K 到 200K iterations;论文指出 20K steps 已经恢复了 200K fine-tuning 效果的 94% 以上。两阶段都加入 10% path-drop learning,使模型在训练中见过“deep path 被 mask 掉”的情况,从而支撑 PDG 推理。
实验关键数据¶
主实验¶
论文主要在 ImageNet-1K class-conditional generation 上评估,使用 SD VAE / Flux VAE latent、SiT-B/2 和 SiT-XL/2 等配置。指标包括 FDD、FID、IS、Precision/Recall,其中作者强调 FDD 对扩散模型语义质量更可靠。
| 设置 | 方法 | 训练成本 | FDD ↓ | FID ↓ | 备注 |
|---|---|---|---|---|---|
| ImageNet 256, SiT-XL/2, 400K, SD VAE, w/o CFG | Improved SiT-XL/2 | 24.4×10^6 TFLOPs | 351.1 | 12.8 | dense baseline |
| ImageNet 256, SiT-XL/2, 400K, SD VAE, w/o CFG | SPRINT | 18.7×10^6 TFLOPs | 262.6 | 9.30 | 75% token dropping |
| ImageNet 256, SiT-XL/2, 1M, SD VAE, w/ CFG | Improved SiT-XL/2 | 61.2×10^6 TFLOPs | 146.0 | 2.36 | full-token baseline |
| ImageNet 256, SiT-XL/2, 1M, SD VAE, w/ CFG | SPRINT | 31.5×10^6 TFLOPs | 126.1 | 2.29 | 约 1.94× 训练节省 |
| ImageNet 256, SDE 250 steps | SiT-XL† | 122.2×10^6 TFLOPs | 79.5 | 2.04 | 400 epochs |
| ImageNet 256, SDE 250 steps | SPRINT + PDG | 65.1×10^6 TFLOPs | 58.4 | 1.62 | inference 0.274 TFLOPs |
| ImageNet 512, SDE 250 steps | SiT-XL | 366.6×10^6 TFLOPs | - | 2.62 | baseline |
| ImageNet 512, SDE 250 steps | SPRINT + PDG | 184.8×10^6 TFLOPs | 46.9 | 1.96 | 约半训练/推理成本 |
| 兼容性实验 | 方法 | w/o CFG FDD ↓ | w/o CFG FID ↓ | w/ CFG FDD ↓ | w/ CFG FID ↓ | 结论 |
|---|---|---|---|---|---|---|
| REPA, 400K | Improved SiT-XL/2REPA | 279.6 | 10.0 | 146.6 | 2.42 | alignment baseline |
| REPA, 400K | REPA + SPRINT | 234.5 | 8.68 | 125.1 | 2.38 | FDD/FID 均改善 |
| U-ViT, 400K | Improved U-ViT-XL/2 | 335.1 | 12.1 | 193.7 | 3.36 | U 型 transformer baseline |
| U-ViT, 400K | U-ViT + SPRINT | 271.7 | 9.20 | 146.4 | 2.97 | 说明方法不依赖 SiT |
消融实验¶
| 配置 | 关键指标 | 说明 |
|---|---|---|
| 随机 token sampling | FID 30.1 | 同样 75% dropping,但局部覆盖不稳定 |
| 结构化 group-wise sampling | FID 27.5 | 每个局部 group 保留 token,效果更好 |
| 只保留 sparse path | FID 85.1 | 全局结构尚可但局部信息严重不足 |
| 只保留 dense path | FID 81.4 | 局部纹理存在但缺少深层全局语义 |
| dense + sparse 双路径 | FID 27.5 | 两条路径互补,是核心收益来源 |
| drop ratio 0% | FID 54.1 | 没有稀疏压力,收益有限 |
| drop ratio 50% | FID 32.3 | 已明显优于 dense 训练 |
| drop ratio 75% | FID 27.5 | 主设置,性能最好 |
| drop ratio 87.5% | FID 50.2 | 过度稀疏导致容量不足 |
| f/g/h = 2/8/2 | FID 27.5, 7.47G FLOPs/iter | 默认切分,成本最低且效果最好 |
| f/g/h = 3/6/3 | FID 29.1, 9.33G FLOPs/iter | 把层挪出 middle block 后成本上升 |
| f/g/h = 5/2/5 | FID 49.2, 13.1G FLOPs/iter | 中间语义建模层太少,效果变差 |
关键发现¶
- dense shallow path 和 sparse deep path 不是可互换组件。前者保留局部纹理和噪声证据,后者负责对象形状和全局语义;任何一条单独使用都会让 FID 从 27.5 恶化到 80 以上。
- 75% dropping 在 SPRINT 里不是勉强可用,而是主设置下效果最强;但 87.5% 已经过稀疏,说明 token 数量仍需保留最低语义容量。
- 短程 fine-tuning 很关键但不昂贵。论文显示 20K steps 已恢复 200K fine-tuning 收益的 94% 以上,说明 sparse pre-training 学到的表示可以较快迁移到 full-token regime。
- PDG 不只是省 FLOPs,也改善质量。ImageNet 256 上,SPRINT+PDG 把 inference TFLOPs 从约 0.477 降到 0.274,同时把 FDD 从 75.4 降到 58.4。
亮点与洞察¶
- 最巧妙的点是把 token dropping 从“删输入省算力”改写成“层级职责分工”。浅层全量、深层稀疏、末层融合,刚好对应局部细节、全局语义和 dense prediction 三种需求。
- 残差融合非常轻量,没有引入复杂 auxiliary decoder 或额外重建任务,却能支持 75% drop ratio。这让它比一些专门 token-dropping 架构更容易接到现有 SiT、REPA、U-ViT 代码里。
- 结构化 group-wise sampling 是一个容易被低估的细节。它没有学习额外策略,却用空间拓扑约束避免大块区域缺席,说明高效训练里“保留哪些 token”比“保留多少 token”同样重要。
- PDG 把训练时的双路径架构复用到推理 guidance,是很干净的系统设计。它让同一个模型内部天然包含一个弱 unconditional path,不需要额外训练小模型,也不需要做复杂蒸馏。
- 对其他任务的启发是:当 transformer 的浅层和深层职责天然不同,可以考虑让高成本深层只处理语义代表 token,再用低成本 dense residual 保持逐位置预测能力。视频生成、3D token 生成和 VLM diffusion decoder 都可能受益。
局限与展望¶
- 主要实验集中在 ImageNet class-conditional image generation,虽然论文声称方法可扩展到视频等模态,但跨模态长序列生成还需要更直接的大规模验证。
- SPRINT 仍需要一次短程 full-token fine-tuning 来闭合 train-inference gap。如果训练预算极端受限,fine-tuning 的额外复杂度和超参选择仍是工程成本。
- PDG 的最佳 guidance scale、path-drop probability 目前依赖经验设置。不同 backbone、分辨率、采样器和 VAE latent 空间下,是否稳定保持“更便宜且更好”还值得系统研究。
- 结构化采样目前是固定 \(2\times2\) group 中保留 1 个 token。未来可以探索自适应采样,例如根据 timestep、噪声强度、局部纹理复杂度或类别语义动态调整保留比例。
- 方法对超高分辨率、复杂文本条件生成和真实大模型训练的收益可能更大,但论文没有开源训练代码,复现成本和细节可信度仍需要社区验证。
相关工作与启发¶
- vs MaskDiT: MaskDiT 也通过 masked token 训练加速扩散模型,但需要额外 decoder / reconstruction 目标,且主要适用于中等 masking ratio。SPRINT 不额外引入重建任务,而是通过 dense residual path 保留局部信息,因此能更稳地支持 75% dropping。
- vs MicroDiT: MicroDiT 用 patch-mixer 等额外模块缓解高 masking ratio 下的信息损失,但会增加参数和计算。SPRINT 的参数增量很小,核心依赖浅层密集路径与深层稀疏路径的分工。
- vs TREAD: TREAD 让部分 token 绕过中间层,但中间层仍需要承担更多局部 velocity prediction 负担。SPRINT 则给 decoder 一条完整 dense shallow path,使中间层更专注于全局上下文。
- vs REPA: REPA 用 DINOv2 对齐中间表示来加速收敛,SPRINT 用结构化稀疏训练改变计算路径。二者不是竞争关系,论文实验显示 SPRINT 可以叠加到 REPA 上继续改善 FDD/FID。
- vs Progressive Training: Progressive training 先低分辨率再高分辨率,省的是早期分辨率成本;SPRINT 直接在目标 token grid 上减少中间层序列长度,因此更贴近 DiT 的 attention 瓶颈。
评分¶
- 新颖性: ⭐⭐⭐⭐☆ 把浅层密集、深层稀疏和 residual fusion 组合得很简洁,单个组件不复杂,但整体分工清晰且有效。
- 实验充分度: ⭐⭐⭐⭐☆ 覆盖 ImageNet 256/512、SiT/REPA/U-ViT、多组消融和 PDG 分析,但缺少真实文本到图像大模型或视频生成上的主实验。
- 写作质量: ⭐⭐⭐⭐☆ 论文动机、消融和系统收益讲得清楚,figures 能支撑“局部/全局分工”的叙事;部分 compute 表述需要读者仔细对齐不同 sampler/epoch 设置。
- 价值: ⭐⭐⭐⭐⭐ 对 DiT 训练成本问题非常实用,改动小、可叠加、同时影响训练和推理,值得高效扩散模型方向重点关注。