跳转至

PPCL: Pluggable Pruning with Contiguous Layer Distillation for Diffusion Transformers

会议: CVPR 2026
arXiv: 2511.16156
代码: GitHub
领域: 模型压缩 / 扩散模型
关键词: diffusion transformer, structured pruning, contiguous layer redundancy, knowledge distillation, MMDiT

一句话总结

提出 PPCL 框架,针对超大规模 Multi-Modal Diffusion Transformer (MMDiT, 8–20B 参数) 设计结构化剪枝方案:通过线性探针 (Linear Probe) 学习每层的可替代性,结合 CKA 一阶差分自动定位连续冗余层区间,再以非顺序交替蒸馏实现深度+宽度双轴剪枝,最终在 Qwen-Image 20B 上实现 50% 参数缩减、1.8× 推理加速,平均性能仅下降 2.61%。

研究背景与动机

领域现状: 最新的文本到图像 (T2I) 扩散模型已从 UNet 架构全面转向 Multi-Modal Diffusion Transformer (MMDiT)。SDXL 为 2.6B 参数,而 FLUX.1 达 12B、Qwen-Image 达 20B (60 层 MMDiT block),生成质量大幅提升,但推理成本急剧增加。

现有痛点: (a) 已有的结构化剪枝方法 (如 TinyFusion、SnapFusion) 主要面向 UNet 架构,难以直接迁移到 MMDiT 的双流结构;(b) 现有方法逐层独立评估冗余度 (如敏感度分析),忽略了 DiT 中相邻层之间的功能耦合关系;(c) 传统顺序蒸馏中,早期层的压缩误差沿网络传播并累积,导致学生模型的表示严重偏离教师模型。

核心矛盾: 作者通过实验发现 DiT 的冗余呈现深度连续性——移除连续的层比移除等量的非连续层对性能的影响更小。现有剪枝方法未利用这一特性。

本文目标: 如何系统性识别 MMDiT 中的连续冗余层区间,并设计不累积误差的蒸馏方案实现高压缩比下的质量保持。

切入角度: 以"层的可替代性"(substitutability) 替代传统的层重要性评估——若一层的输入-输出映射可被线性变换近似,则该层对其相邻层是功能冗余的。

核心 idea: 在 MMDiT 中,冗余层沿深度方向连续分布,可通过线性探针+CKA 差分自动定位并整段移除,配合非顺序蒸馏消除误差累积。

方法详解

整体框架

PPCL 分两个阶段、三个步骤:

  • 阶段一 (深度剪枝):(1) 为每个 MMDiT block 训练线性探针,评估层的可替代性;(2) 通过 CKA 一阶差分的拐点检测,自动划分连续冗余层区间集合 \(\mathcal{I}\);(3) 以非顺序蒸馏方式训练学生模型,每个区间独立优化。
  • 阶段二 (宽度剪枝):识别文本流 (text stream) 和 FFN 中的冗余,替换为轻量线性投影器,进一步压缩参数。
  • 最终执行短时全参数 fine-tuning。

关键设计

1. 基于线性探针的连续冗余层检测

  • 功能: 自动识别一组不重叠的连续冗余层区间 \(\mathcal{I} = \{[u_i, v_i]\}\)
  • 核心思路: 为教师模型的每层 \(T_i\) 构建带残差结构的线性探针 \(l_i\),通过最小二乘初始化后用对齐损失 \(\mathcal{L}_{fit}(i) = \|l_i(T_{i-1}^D) + T_{i-1}^D - T_i(T_{i-1}^D)\|_2^2\) 训练。训练完成后,在校准集上计算代理模型 (用线性探针替换连续层) 输出与教师层输出的 CKA 相似度,定义一阶差分 \(\Delta(u,k) = -(\text{cka}(u,k) - \text{cka}(u,k-1))\)。当 \(\Delta\) 先减后增时,拐点 \(v\) 标志冗余区间的右端点
  • 设计动机: (a) 线性探针的训练输入与对应层的实际输入一致,保证每层评估独立不干扰;(b) 有限个线性变换的叠加仍满足线性,从而保证连续层替代性的传递;(c) 用 CKA 差分的拐点而非固定阈值,能自适应地定位冗余区间长度
  • 关键公式: 线性探针初始化闭合解 \(W_i^* = (T_i(T_{i-1}^D) - T_{i-1}^D)(T_{i-1}^D)^\top(T_{i-1}^D(T_{i-1}^D)^\top)^{-1}\)

2. 非顺序深度剪枝蒸馏

  • 功能: 对检测到的每个冗余区间 \([u,v]\),用单层学生模型替代,避免误差累积
  • 核心思路: 每个区间独立优化——学生层 \(S^u\) 接收教师模型第 \(u-1\) 层的输出作为输入,对齐教师第 \(v\) 层的输出。损失函数 \(\mathcal{L}_{depth}^{[u,v]} = \|\text{Norm}(S^u(T_{u-1}^D)) - \text{Norm}(T_v^D)\|_2^2\),其中 Norm 为 L2 归一化,强调方向对齐
  • 设计动机: 传统顺序蒸馏中学生模型第 \(k\) 层接收自身第 \(k-1\) 层的输出,早期误差会逐层累积。非顺序方案让每个区间直接从教师模型获取输入,切断误差传播链
  • 即插即用特性: 由于每个区间独立训练,推理时可灵活选择启用或跳过某些区间——例如训练一个 10B 模型后,将部分学生层替换回教师层即可得到 12B 或 14B 变体,无需重训

3. 宽度剪枝:文本流与 FFN 压缩

  • 功能: 在保留层内部进一步压缩参数——替换冗余的文本流整体结构和 FFN
  • 核心思路: (a) 文本流剪枝:CKA 热力图揭示文本流跨层表示高度相似,将冗余层的文本流 (除 QKV 投影外) 替换为两个轻量线性投影 \(l_p^z\)\(l_p^h\);(b) FFN 剪枝:测量用线性投影替代 FFN 的 MSE 极小的层,将其 FFN 替换为线性投影 \(l_q^{img}\)\(l_q^{txt}\)
  • 设计动机: MMDiT 的文本流 token 相似度高、层间变化小,可大幅压缩;FFN 显著过参数化,许多层的 FFN 功能接近线性变换
  • 损失函数: 宽度蒸馏损失由两部分组成——层级对齐损失 \(\mathcal{L}_{width}^j\) (与深度蒸馏格式一致) 加上线性投影对齐损失 \(\mathcal{L}_{linear}^j\) (约束线性投影输出逼近教师对应中间表示)

训练策略

  • 数据: 从 LAION-2B-en 采样 10 万张图,用 Qwen2.5-VL 生成精细描述
  • 训练三阶段: 深度剪枝 6k steps → 宽度剪枝 2k steps → 全参数 fine-tuning 1k steps (8 × H20 GPU)
  • 优化器: AdamW (\(\beta_1\)=0.9, \(\beta_2\)=0.95, weight decay=0.02),BF16 混合精度 + 梯度检查点

实验关键数据

主实验:FLUX.1-dev 上的对比

方法 参数(B) 显存(%) 延迟(ms) DPG↑ GenEval↑ B-VQA↑ UniDet↑ 平均下降(%)↓
Base model 12 100 715 83.8 0.665 0.640 0.426 0
TinyFusion 8 74.4 534 77.2 0.511 0.584 0.369 13.80
HierarchicalPrune 8 74.4 543 75.7 0.503 0.579 0.371 13.38
Dense2MoE 12 100 312 73.6 0.403 0.473 0.311 21.52
FLUX.1 Lite 8 78.8 572 82.1 0.623 0.547 0.379 6.09
Chroma1-HD 8.9 82.5 1714 84.0 0.593 0.621 0.339 1.02
PPCL(8B) 8 74.4 535 80.0 0.605 0.615 0.391 4.03
PPCL(6.5B) 6.5 69.2 428 81.2 0.593 0.581 0.398 0.07

主实验:Qwen-Image 上的对比

方法 参数(B) 显存(%) 延迟(ms) DPG↑ GenEval↑ LongText-EN↑ LongText-ZH↑ 平均下降(%)↓
Base model 20 100 2625 88.9 0.870 0.943 0.946 0
TinyFusion(14B) 14 79.4 1789 80.7 0.739 0.859 0.857 8.75
HierarchicalPrune(14B) 14 79.4 1786 83.3 0.766 0.884 0.881 6.49
PPCL(14B) 14 79.4 1792 87.9 0.847 0.929 0.946 0.42
PPCL(12B) 12 71.4 1514 83.6 0.801 0.893 0.917 3.03
PPCL(10B+FT) 10 66.9 1462 86.7 0.828 0.902 0.931 3.29

消融实验 (Qwen-Image, 均剪枝至 ~10-12B)

配置 LongText↑ DPG↑ GenEval↑ 平均 参数(B) 平均下降(%)↓
Original (20B) 0.942 0.885 0.854 0.894 20 0
Baseline (CKA+顺序蒸馏) 0.625 0.763 0.728 0.706 12 18.2
+LP (线性探针) 0.712 0.795 0.776 0.761 12 14.5
+LP-a (CKA 阈值替代差分) 0.664 0.778 0.712 0.718 12 19.7
+LP-b (扩大区间上界) 0.678 0.769 0.731 0.726 12 18.8
+DP (非顺序蒸馏) 0.905 0.836 0.801 0.848 12 5.22
+WP-text (文本流剪枝) 0.915 0.846 0.819 0.860 11 3.79
+WP-ffn (FFN 剪枝) 0.906 0.835 0.809 0.850 10 4.91
+Fine-tuning 0.916 0.867 0.828 0.870 10 2.61

关键发现

  • 连续 vs 非连续移除: 在 Qwen-Image 60 层模型上移除 1–3 层时,连续移除的生成质量一致优于非连续移除,证实了冗余的深度连续性假设
  • 非顺序蒸馏提升巨大: 从+LP 的 14.5% 下降降至 +DP 的 5.22%,仅引入非顺序蒸馏就带来 ~9 个百分点的改善
  • 宽度剪枝"减参数反而升指标": +WP-text 在减少 1B 参数的同时,平均性能反而从 0.848 提升至 0.860,原因是额外的可训练线性层改善了部分层对齐不足的问题
  • 即插即用性: PPCL(14B) 和 PPCL(12B) 无需额外训练,直接由 10B 模型替换部分学生层为教师层即可获得
  • 在已剪枝模型上再剪枝: 对 FLUX.1 Lite (8B) 再剪枝至 6.5B,平均性能下降仅 0.07%

亮点与洞察

  • 连续冗余是 DiT 的固有特性: 与 CNN 中冗余分散分布不同,MMDiT 的相邻层在表示空间中做平滑过渡,形成可整段移除的功能单元。这一发现为 DiT 的剪枝提供了全新范式
  • 线性探针作为可替代性度量: 相比直接移除层评估敏感度,线性探针能更稳定地量化层间的线性可近似程度,且仅需一次训练即可处理所有层
  • 非顺序蒸馏是关键: 消融实验显示,非顺序蒸馏比冗余层检测方法本身贡献更大 (9pp vs 3.7pp),证明切断误差累积链是高压缩比下保持质量的核心
  • 双轴压缩的互补性: 深度剪枝减少层数、宽度剪枝减少保留层内的参数,两者正交且可叠加,联合使用实现 50% 压缩

局限与展望

  • CKA 一阶差分的拐点检测缺乏严格理论基础,作者承认这主要是成功的工程启发式方法
  • INT4 量化在 PPCL 剪枝后效果不佳——剪枝降低了网络冗余度,收窄了量化容错空间,剪枝+量化的联合优化有待研究
  • 训练仍需 8 × H20 GPU,对于更大规模模型 (如 100B 级) 的可扩展性待验证
  • 宽度剪枝中哪些层属于 \(\mathcal{R}_{txt}\)\(\mathcal{R}_{ffn}\) 的选择策略细节在附录中,主文描述有限

相关工作与启发

  • vs TinyFusion: TinyFusion 使用可微分门控参数识别可移除层,但逐层独立评估忽视连续性,且采用标准(顺序)蒸馏导致误差累积。在 FLUX.1 和 Qwen-Image 上 PPCL 平均下降 4.03% / 0.42%,远优于 TinyFusion 的 13.80% / 8.75%
  • vs HierarchicalPrune: HPP 的层重要性评估偏粗糙,生成结果出现视觉 artifacts;PPCL 在相同压缩比下优势明显 (0.42% vs 6.49%)
  • vs Dense2MoE: Dense2MoE 用 MoE 替换 FFN 来降低激活成本但不减参数量,且平均下降达 21.52%,说明单纯替换子结构不如系统性剪枝+蒸馏
  • vs Chroma1-HD: Chroma1-HD 在 FLUX.1 上平均下降最低 (1.02%) 但推理延迟反增至 2.4 倍,不满足加速需求

评分

  • 新颖性: ⭐⭐⭐⭐ 连续冗余假设经实验充分验证,线性探针+CKA 差分的自动区间检测和非顺序蒸馏都是有效创新
  • 实验充分度: ⭐⭐⭐⭐ 在 FLUX.1 和 Qwen-Image 两个主流 MMDiT 上验证,消融实验细致地拆解了每个组件的贡献
  • 写作质量: ⭐⭐⭐⭐ 方法动机阐述清晰,从观察→假设→设计的逻辑链完整
  • 价值: ⭐⭐⭐⭐⭐ 直击 20B 级 DiT 部署瓶颈,50% 压缩+1.8× 加速的工程价值极高,即插即用特性进一步提升实用性