DiffusionBlocks: Block-wise Neural Network Training via Diffusion Interpretation¶
会议: ICLR2026
arXiv: 2506.14202
代码: SakanaAI/DiffusionBlocks
领域: 图像复原
关键词: block-wise training, diffusion models, score matching, memory efficiency, residual networks
一句话总结¶
提出 DiffusionBlocks,将残差网络的逐层更新解释为连续时间扩散过程的离散化步骤,从而将网络切分为可完全独立训练的 block,在保持端到端训练性能的同时按 block 数 B 倍减少训练显存。
背景与动机¶
- 端到端反向传播需要存储所有层的中间激活值,显存随网络深度线性增长,严重制约模型规模和实际部署
- 已有的 block-wise 训练方法(如 Forward-Forward、greedy layer-wise training)依赖临时性的局部目标函数,缺乏理论保证,且基本仅在分类任务上验证,无法自然扩展到生成任务
- Score-based diffusion models 的去噪目标天然具有"各噪声级别可独立优化"的性质——这恰好为 block-wise 训练提供了缺失的理论基础
- 残差连接(ResNet、Transformer 等)的更新规则 \(\mathbf{z}_{\ell+1} = \mathbf{z}_\ell + f_{\theta_\ell}(\mathbf{z}_\ell)\) 可对应扩散过程 probability flow ODE 的 Euler 离散化
核心问题¶
如何为基于 Transformer 的网络设计一种有理论根据的 block-wise 训练框架,使得:
- 每个 block 可以完全独立训练(不需要其他 block 的梯度或激活值)
- 与端到端训练保持竞争力
- 能通用于分类和生成等多种任务/架构
方法详解¶
整体框架¶
DiffusionBlocks 的出发点是一个被忽视的对应关系:残差网络逐层叠加的更新,本质上就是一条连续时间扩散过程的离散化轨迹。顺着这个视角,它把 \(L\) 层网络切成 \(B\) 个 block,让每个 block 只负责扩散过程中某一段噪声级别的去噪。具体落地分三步:先把网络切成 \(B\) 个连续 block,再把噪声区间 \([\sigma_{\min}, \sigma_{\max}]\) 按等概率切成 \(B\) 段一一分配给各 block,最后给每个 block 做噪声条件化改造,使它能在自己那段噪声里独立去噪。改造完成后,\(B\) 个 block 各自拥有一个完整、互不通信的去噪损失——训练时无需在 block 之间传递梯度或激活,任何时刻只缓存一个 block 的激活,显存随之按 \(B\) 倍下降;推理时再从高噪声到低噪声依次调用各 block 还原目标。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400, 'subGraphTitleMargin': {'top': 8, 'bottom': 16}}}}%%
flowchart TD
IN["残差网络 L 层<br/>z(l+1)=z(l)+f(z(l))"]
LENS["残差连接 = 扩散离散步<br/>(理论根:各噪声级别可独立优化)"]
subgraph CONV["三步转换流程"]
direction TB
S1["block 划分<br/>L 层 切成 B 个 block"]
S2["噪声范围分配<br/>噪声区间切成 B 段"]
S3["噪声条件化改造<br/>输入加 z_σ,AdaLN 注入 σ"]
S1 --> S2 --> S3
end
EQ["等概率划分<br/>按 log-normal 质量等分每段"]
LOSS["block 独立去噪损失<br/>各 block 一张卡,互不通信<br/>→ 训练显存按 B 倍缩减"]
OUT["顺序推理<br/>高噪声 → 低噪声依次调用 block"]
IN --> LENS --> CONV
EQ -.->|约束切分| S2
CONV --> LOSS --> OUT
关键设计¶
1. 残差连接 = 扩散过程离散步:为 block 独立训练找到理论根
block-wise 训练之所以一直缺乏理论保证,是因为没人能说清"局部目标"为什么等价于全局目标。本文在 Variance Exploding (VE) 扩散框架下给出了答案:给定噪声级别序列 \(\sigma_0 > \sigma_1 > \cdots > \sigma_T\),对 probability flow ODE 做 Euler 离散化得到更新式
它与残差网络的 skip connection 更新规则 \(\mathbf{z}_\ell = \mathbf{z}_{\ell-1} + f_{\theta_\ell}(\mathbf{z}_{\ell-1})\) 形式完全一致。一旦把每一层(每个 block)看成扩散轨迹上的一段,score matching 自带的"各噪声级别可独立优化"性质就被继承下来——block 独立训练不再是启发式拼凑,而是有去噪目标背书的等价分解。后面所有设计都是把这个对应关系落到具体架构上的工程化步骤。
2. 三步转换流程:把任意残差网络改造成可分块的扩散去噪器
在上面的对应关系之上,本文用三步把现成架构改造成 DiffusionBlocks,对应框架图里的 CONV 子图。先做 block 划分,把 \(L\) 层网络切成 \(B\) 个连续 block \(\mathcal{F}_1, \ldots, \mathcal{F}_B\);再做 噪声范围分配,定义噪声分布 \(p_{\text{noise}}\)(推荐 log-normal),把区间 \([\sigma_{\min}, \sigma_{\max}]\) 切成 \(B\) 段 \(\{[\sigma_b, \sigma_{b-1}]\}_{b=1}^B\),让第 \(b\) 个 block 专门负责这一段噪声的去噪;最后做 噪声条件化改造,把每个 block 的输入扩展为 \(\tilde{\mathbf{x}} = (\mathbf{x}, \mathbf{z}_\sigma)\),其中 \(\mathbf{z}_\sigma = \mathbf{y} + \sigma\epsilon\),并通过 AdaLN 等机制注入当前噪声级别条件,使同一组参数能在所负责的噪声区间内连续工作。三步走完后,每个 block 都变成一个只看自己那段噪声、独立预测干净目标 \(\mathbf{y}\) 的小去噪器。
3. block 独立的去噪损失:显存与推理同时按 \(B\) 倍缩减
改造后第 \(b\) 个 block 的损失写成
即在自己负责的噪声子分布 \(p_{\text{noise}}^{(b)}\) 上做加权去噪回归。关键在于期望只对本 block 的噪声区间取,所以 \(B\) 个损失彼此完全解耦:可以各放一张卡、各自反向传播、互不等待,而它们的噪声区间拼起来又恰好覆盖整条扩散轨迹,合起来等价于训练完整网络。这正是训练显存按 \(B\) 倍缩减的来源——任何时刻只需缓存一个 block 的激活。推理时则从高噪声到低噪声依次调用各 block,对 diffusion model 而言每个去噪步只需把对应 block 载入显存,于是推理也拿到 \(B\) 倍的显存/调度优势;代价是各 block 必须按序执行、无法并行化去噪步骤。
4. 等概率划分:让每个 block 承担等量去噪难度
噪声区间怎么切,直接决定各 block 的负载是否均衡,因此框架图里用 等概率划分 这条约束去指导"噪声范围分配"那一步。若按 \(\sigma\) 均匀划分,高噪声和低噪声两端会分到大量样本稀疏、信息量低的区域,白白浪费容量。本文改用按 log-normal 累积概率质量等分,即让每段满足 \(\int_{\sigma_{b-1}}^{\sigma_b} p_{\text{noise}}(\sigma)\,d\sigma = 1/B\)。这样每个 block 都处理等量的训练分布,自动在去噪最难、样本最密的中间噪声级别切出更细的区间,把容量用在刀刃上。实验里这一步贡献明显:CIFAR-10 上等概率划分把 FID 从均匀划分的 43.53 压到 38.03。
实验关键数据¶
| 任务 / 架构 | 数据集 | 端到端基线 | DiffusionBlocks | Block 数 / 显存缩减 |
|---|---|---|---|---|
| ViT 分类 | CIFAR-100 | 60.25% Acc | 59.30% Acc | B=3 / 3× |
| DiT 图像生成 | CIFAR-10 | 32.84 FID | 30.59 FID | B=3 / 3× |
| DiT 图像生成 | ImageNet 256 | 12.09 FID | 10.63 FID | B=3 / 3× |
| Masked Diffusion 文本 | text8 | 1.56 BPC | 1.45 BPC | B=3 / 3× |
| AR Transformer 文本 | LM1B | 0.50 MAUVE | 0.71 MAUVE | B=4 / 4× |
| AR Transformer 文本 | OpenWebText | 0.85 MAUVE | 0.82 MAUVE | B=4 / 4× |
| Huginn (recurrent-depth) | LM1B | 0.49 MAUVE | 0.70 MAUVE | 消除 32 次迭代 |
- Forward-Forward 在 CIFAR-100 上仅达 7.85% 准确率,远逊于 DiffusionBlocks
- ImageNet 上 B=2 时 FID=9.90,优于端到端训练 (12.09),适度划分反而提升性能
- 等概率划分在所有层分配方案下均显著优于均匀划分(CIFAR-10 FID: 38.03 vs 43.53)
亮点¶
- 理论基础扎实:从 score matching 的噪声级别独立性出发,自然推导出 block 独立训练目标,非启发式拼凑
- 通用性极强:一套三步转换流程适用于 ViT、DiT、AR Transformer、Masked Diffusion、Recurrent-depth 共五类架构
- 等概率划分是简洁而关键的设计——让每个 block 承担等量去噪难度,无需手工调整层分配
- 多重效率收益:训练 \(B\) 倍显存缩减;diffusion model 推理 \(B\) 倍加速;recurrent-depth 模型省去 BPTT
- 部分场景超越端到端:ImageNet B=2/3 的 FID 优于不分 block 的端到端训练,说明适度专业化有正收益
局限与展望¶
- 实验中 ViT 分类仅在 CIFAR-100 上验证(60.25→59.30),大规模 ImageNet 分类未测试
- 推理时仍需按序调用各 block,无法并行化推理步骤
- 噪声条件化改造(AdaLN 等)增加了少量参数和工程复杂度
- B 过大时性能下降(B=6 时 FID 14.43),block 粒度有下限
- 主要面向 Transformer 类残差架构,对无残差连接的网络适用性未讨论
与相关工作的对比¶
| 方法 | 理论基础 | 任务通用性 | 连续时间 | Block 独立 |
|---|---|---|---|---|
| Forward-Forward | 对比目标 | 仅分类 | ✗ | ✓ |
| NoProp | 扩散相关 | 仅分类 | ✓(CT) 或 ✗(DT) | ✗(CT) 或 ✓(DT) |
| DiffusionBlocks | Score matching | 分类+生成 | ✓ | ✓ |
- NoProp 与自定义 CNN 架构捆绑,无法直接迁移到 Transformer;DiffusionBlocks 在 NoProp 的架构上也优于其所有变体(46.88 vs 46.06/21.31/37.57)
- 与 stage-specific diffusion models (eDiff-I 等) 的区别在于:后者是联合训练或从共享参数微调,DiffusionBlocks 各 block 完全隔离
启发与关联¶
- "残差连接 ≈ 扩散离散化步骤"的视角可进一步推广:任何具有残差结构的深层模型都可能受益于这种分块独立训练
- 等概率划分思想可迁移到其他需要"分段处理不同难度子任务"的场景(如课程学习、多尺度训练)
- 对 recurrent-depth 模型消除 BPTT 的能力值得关注——随着 universal transformer / Huginn 等模型兴起,该方法可降低其训练成本
- 结合模型并行(每个 block 放不同 GPU),可实现更激进的深度扩展
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ — 将扩散独立性引入 block-wise 训练是原创性极高的理论贡献
- 实验充分度: ⭐⭐⭐⭐ — 五类架构覆盖面广,但分类任务规模偏小
- 写作质量: ⭐⭐⭐⭐⭐ — 数学推导清晰,三步流程直观易懂
- 价值: ⭐⭐⭐⭐ — 为大模型训练显存瓶颈提供了有理论保证的新范式