PPCL: Pluggable Pruning with Contiguous Layer Distillation for Diffusion Transformers¶
会议: CVPR 2026
arXiv: 2511.16156
代码: GitHub
领域: 模型压缩 / 扩散模型
关键词: diffusion transformer, structured pruning, contiguous layer redundancy, knowledge distillation, MMDiT
一句话总结¶
提出 PPCL 框架,针对超大规模 Multi-Modal Diffusion Transformer (MMDiT, 8–20B 参数) 设计结构化剪枝方案:通过线性探针 (Linear Probe) 学习每层的可替代性,结合 CKA 一阶差分自动定位连续冗余层区间,再以非顺序交替蒸馏实现深度+宽度双轴剪枝,最终在 Qwen-Image 20B 上实现 50% 参数缩减、1.8× 推理加速,平均性能仅下降 2.61%。
研究背景与动机¶
领域现状: 最新的文本到图像 (T2I) 扩散模型已从 UNet 架构全面转向 Multi-Modal Diffusion Transformer (MMDiT)。SDXL 为 2.6B 参数,而 FLUX.1 达 12B、Qwen-Image 达 20B (60 层 MMDiT block),生成质量大幅提升,但推理成本急剧增加。
现有痛点: (a) 已有的结构化剪枝方法 (如 TinyFusion、SnapFusion) 主要面向 UNet 架构,难以直接迁移到 MMDiT 的双流结构;(b) 现有方法逐层独立评估冗余度 (如敏感度分析),忽略了 DiT 中相邻层之间的功能耦合关系;(c) 传统顺序蒸馏中,早期层的压缩误差沿网络传播并累积,导致学生模型的表示严重偏离教师模型。
核心矛盾: 作者通过实验发现 DiT 的冗余呈现深度连续性——移除连续的层比移除等量的非连续层对性能的影响更小。现有剪枝方法未利用这一特性。
本文目标: 如何系统性识别 MMDiT 中的连续冗余层区间,并设计不累积误差的蒸馏方案实现高压缩比下的质量保持。
切入角度: 以"层的可替代性"(substitutability) 替代传统的层重要性评估——若一层的输入-输出映射可被线性变换近似,则该层对其相邻层是功能冗余的。
核心 idea: 在 MMDiT 中,冗余层沿深度方向连续分布,可通过线性探针+CKA 差分自动定位并整段移除,配合非顺序蒸馏消除误差累积。
方法详解¶
整体框架¶
PPCL 分两个阶段、三个步骤:
- 阶段一 (深度剪枝):(1) 为每个 MMDiT block 训练线性探针,评估层的可替代性;(2) 通过 CKA 一阶差分的拐点检测,自动划分连续冗余层区间集合 \(\mathcal{I}\);(3) 以非顺序蒸馏方式训练学生模型,每个区间独立优化。
- 阶段二 (宽度剪枝):识别文本流 (text stream) 和 FFN 中的冗余,替换为轻量线性投影器,进一步压缩参数。
- 最终执行短时全参数 fine-tuning。
关键设计¶
1. 基于线性探针的连续冗余层检测
- 功能: 自动识别一组不重叠的连续冗余层区间 \(\mathcal{I} = \{[u_i, v_i]\}\)
- 核心思路: 为教师模型的每层 \(T_i\) 构建带残差结构的线性探针 \(l_i\),通过最小二乘初始化后用对齐损失 \(\mathcal{L}_{fit}(i) = \|l_i(T_{i-1}^D) + T_{i-1}^D - T_i(T_{i-1}^D)\|_2^2\) 训练。训练完成后,在校准集上计算代理模型 (用线性探针替换连续层) 输出与教师层输出的 CKA 相似度,定义一阶差分 \(\Delta(u,k) = -(\text{cka}(u,k) - \text{cka}(u,k-1))\)。当 \(\Delta\) 先减后增时,拐点 \(v\) 标志冗余区间的右端点
- 设计动机: (a) 线性探针的训练输入与对应层的实际输入一致,保证每层评估独立不干扰;(b) 有限个线性变换的叠加仍满足线性,从而保证连续层替代性的传递;(c) 用 CKA 差分的拐点而非固定阈值,能自适应地定位冗余区间长度
- 关键公式: 线性探针初始化闭合解 \(W_i^* = (T_i(T_{i-1}^D) - T_{i-1}^D)(T_{i-1}^D)^\top(T_{i-1}^D(T_{i-1}^D)^\top)^{-1}\)
2. 非顺序深度剪枝蒸馏
- 功能: 对检测到的每个冗余区间 \([u,v]\),用单层学生模型替代,避免误差累积
- 核心思路: 每个区间独立优化——学生层 \(S^u\) 接收教师模型第 \(u-1\) 层的输出作为输入,对齐教师第 \(v\) 层的输出。损失函数 \(\mathcal{L}_{depth}^{[u,v]} = \|\text{Norm}(S^u(T_{u-1}^D)) - \text{Norm}(T_v^D)\|_2^2\),其中 Norm 为 L2 归一化,强调方向对齐
- 设计动机: 传统顺序蒸馏中学生模型第 \(k\) 层接收自身第 \(k-1\) 层的输出,早期误差会逐层累积。非顺序方案让每个区间直接从教师模型获取输入,切断误差传播链
- 即插即用特性: 由于每个区间独立训练,推理时可灵活选择启用或跳过某些区间——例如训练一个 10B 模型后,将部分学生层替换回教师层即可得到 12B 或 14B 变体,无需重训
3. 宽度剪枝:文本流与 FFN 压缩
- 功能: 在保留层内部进一步压缩参数——替换冗余的文本流整体结构和 FFN
- 核心思路: (a) 文本流剪枝:CKA 热力图揭示文本流跨层表示高度相似,将冗余层的文本流 (除 QKV 投影外) 替换为两个轻量线性投影 \(l_p^z\) 和 \(l_p^h\);(b) FFN 剪枝:测量用线性投影替代 FFN 的 MSE 极小的层,将其 FFN 替换为线性投影 \(l_q^{img}\) 和 \(l_q^{txt}\)
- 设计动机: MMDiT 的文本流 token 相似度高、层间变化小,可大幅压缩;FFN 显著过参数化,许多层的 FFN 功能接近线性变换
- 损失函数: 宽度蒸馏损失由两部分组成——层级对齐损失 \(\mathcal{L}_{width}^j\) (与深度蒸馏格式一致) 加上线性投影对齐损失 \(\mathcal{L}_{linear}^j\) (约束线性投影输出逼近教师对应中间表示)
训练策略¶
- 数据: 从 LAION-2B-en 采样 10 万张图,用 Qwen2.5-VL 生成精细描述
- 训练三阶段: 深度剪枝 6k steps → 宽度剪枝 2k steps → 全参数 fine-tuning 1k steps (8 × H20 GPU)
- 优化器: AdamW (\(\beta_1\)=0.9, \(\beta_2\)=0.95, weight decay=0.02),BF16 混合精度 + 梯度检查点
实验关键数据¶
主实验:FLUX.1-dev 上的对比¶
| 方法 | 参数(B) | 显存(%) | 延迟(ms) | DPG↑ | GenEval↑ | B-VQA↑ | UniDet↑ | 平均下降(%)↓ |
|---|---|---|---|---|---|---|---|---|
| Base model | 12 | 100 | 715 | 83.8 | 0.665 | 0.640 | 0.426 | 0 |
| TinyFusion | 8 | 74.4 | 534 | 77.2 | 0.511 | 0.584 | 0.369 | 13.80 |
| HierarchicalPrune | 8 | 74.4 | 543 | 75.7 | 0.503 | 0.579 | 0.371 | 13.38 |
| Dense2MoE | 12 | 100 | 312 | 73.6 | 0.403 | 0.473 | 0.311 | 21.52 |
| FLUX.1 Lite | 8 | 78.8 | 572 | 82.1 | 0.623 | 0.547 | 0.379 | 6.09 |
| Chroma1-HD | 8.9 | 82.5 | 1714 | 84.0 | 0.593 | 0.621 | 0.339 | 1.02 |
| PPCL(8B) | 8 | 74.4 | 535 | 80.0 | 0.605 | 0.615 | 0.391 | 4.03 |
| PPCL(6.5B) | 6.5 | 69.2 | 428 | 81.2 | 0.593 | 0.581 | 0.398 | 0.07 |
主实验:Qwen-Image 上的对比¶
| 方法 | 参数(B) | 显存(%) | 延迟(ms) | DPG↑ | GenEval↑ | LongText-EN↑ | LongText-ZH↑ | 平均下降(%)↓ |
|---|---|---|---|---|---|---|---|---|
| Base model | 20 | 100 | 2625 | 88.9 | 0.870 | 0.943 | 0.946 | 0 |
| TinyFusion(14B) | 14 | 79.4 | 1789 | 80.7 | 0.739 | 0.859 | 0.857 | 8.75 |
| HierarchicalPrune(14B) | 14 | 79.4 | 1786 | 83.3 | 0.766 | 0.884 | 0.881 | 6.49 |
| PPCL(14B) | 14 | 79.4 | 1792 | 87.9 | 0.847 | 0.929 | 0.946 | 0.42 |
| PPCL(12B) | 12 | 71.4 | 1514 | 83.6 | 0.801 | 0.893 | 0.917 | 3.03 |
| PPCL(10B+FT) | 10 | 66.9 | 1462 | 86.7 | 0.828 | 0.902 | 0.931 | 3.29 |
消融实验 (Qwen-Image, 均剪枝至 ~10-12B)¶
| 配置 | LongText↑ | DPG↑ | GenEval↑ | 平均 | 参数(B) | 平均下降(%)↓ |
|---|---|---|---|---|---|---|
| Original (20B) | 0.942 | 0.885 | 0.854 | 0.894 | 20 | 0 |
| Baseline (CKA+顺序蒸馏) | 0.625 | 0.763 | 0.728 | 0.706 | 12 | 18.2 |
| +LP (线性探针) | 0.712 | 0.795 | 0.776 | 0.761 | 12 | 14.5 |
| +LP-a (CKA 阈值替代差分) | 0.664 | 0.778 | 0.712 | 0.718 | 12 | 19.7 |
| +LP-b (扩大区间上界) | 0.678 | 0.769 | 0.731 | 0.726 | 12 | 18.8 |
| +DP (非顺序蒸馏) | 0.905 | 0.836 | 0.801 | 0.848 | 12 | 5.22 |
| +WP-text (文本流剪枝) | 0.915 | 0.846 | 0.819 | 0.860 | 11 | 3.79 |
| +WP-ffn (FFN 剪枝) | 0.906 | 0.835 | 0.809 | 0.850 | 10 | 4.91 |
| +Fine-tuning | 0.916 | 0.867 | 0.828 | 0.870 | 10 | 2.61 |
关键发现¶
- 连续 vs 非连续移除: 在 Qwen-Image 60 层模型上移除 1–3 层时,连续移除的生成质量一致优于非连续移除,证实了冗余的深度连续性假设
- 非顺序蒸馏提升巨大: 从+LP 的 14.5% 下降降至 +DP 的 5.22%,仅引入非顺序蒸馏就带来 ~9 个百分点的改善
- 宽度剪枝"减参数反而升指标": +WP-text 在减少 1B 参数的同时,平均性能反而从 0.848 提升至 0.860,原因是额外的可训练线性层改善了部分层对齐不足的问题
- 即插即用性: PPCL(14B) 和 PPCL(12B) 无需额外训练,直接由 10B 模型替换部分学生层为教师层即可获得
- 在已剪枝模型上再剪枝: 对 FLUX.1 Lite (8B) 再剪枝至 6.5B,平均性能下降仅 0.07%
亮点与洞察¶
- 连续冗余是 DiT 的固有特性: 与 CNN 中冗余分散分布不同,MMDiT 的相邻层在表示空间中做平滑过渡,形成可整段移除的功能单元。这一发现为 DiT 的剪枝提供了全新范式
- 线性探针作为可替代性度量: 相比直接移除层评估敏感度,线性探针能更稳定地量化层间的线性可近似程度,且仅需一次训练即可处理所有层
- 非顺序蒸馏是关键: 消融实验显示,非顺序蒸馏比冗余层检测方法本身贡献更大 (9pp vs 3.7pp),证明切断误差累积链是高压缩比下保持质量的核心
- 双轴压缩的互补性: 深度剪枝减少层数、宽度剪枝减少保留层内的参数,两者正交且可叠加,联合使用实现 50% 压缩
局限与展望¶
- CKA 一阶差分的拐点检测缺乏严格理论基础,作者承认这主要是成功的工程启发式方法
- INT4 量化在 PPCL 剪枝后效果不佳——剪枝降低了网络冗余度,收窄了量化容错空间,剪枝+量化的联合优化有待研究
- 训练仍需 8 × H20 GPU,对于更大规模模型 (如 100B 级) 的可扩展性待验证
- 宽度剪枝中哪些层属于 \(\mathcal{R}_{txt}\) 和 \(\mathcal{R}_{ffn}\) 的选择策略细节在附录中,主文描述有限
相关工作与启发¶
- vs TinyFusion: TinyFusion 使用可微分门控参数识别可移除层,但逐层独立评估忽视连续性,且采用标准(顺序)蒸馏导致误差累积。在 FLUX.1 和 Qwen-Image 上 PPCL 平均下降 4.03% / 0.42%,远优于 TinyFusion 的 13.80% / 8.75%
- vs HierarchicalPrune: HPP 的层重要性评估偏粗糙,生成结果出现视觉 artifacts;PPCL 在相同压缩比下优势明显 (0.42% vs 6.49%)
- vs Dense2MoE: Dense2MoE 用 MoE 替换 FFN 来降低激活成本但不减参数量,且平均下降达 21.52%,说明单纯替换子结构不如系统性剪枝+蒸馏
- vs Chroma1-HD: Chroma1-HD 在 FLUX.1 上平均下降最低 (1.02%) 但推理延迟反增至 2.4 倍,不满足加速需求
评分¶
- 新颖性: ⭐⭐⭐⭐ 连续冗余假设经实验充分验证,线性探针+CKA 差分的自动区间检测和非顺序蒸馏都是有效创新
- 实验充分度: ⭐⭐⭐⭐ 在 FLUX.1 和 Qwen-Image 两个主流 MMDiT 上验证,消融实验细致地拆解了每个组件的贡献
- 写作质量: ⭐⭐⭐⭐ 方法动机阐述清晰,从观察→假设→设计的逻辑链完整
- 价值: ⭐⭐⭐⭐⭐ 直击 20B 级 DiT 部署瓶颈,50% 压缩+1.8× 加速的工程价值极高,即插即用特性进一步提升实用性