Joint Model and Data Sparsification via the Marginal Likelihood¶
会议: ICML 2026
arXiv: 2605.29107
代码: 待确认
领域: 模型压缩 / 数据稀疏化 / 贝叶斯学习
关键词: 联合稀疏化, 边缘似然, 拉普拉斯近似, 神经切线核
一句话总结¶
JMDS 通过最大化边缘似然的统一目标同时实现模型和数据稀疏化——避免分阶段优化的次优性,在 CIFAR / ImageNet / WikiText 上以 5-10× 联合压缩比下保持优于独立稀疏化的性能。
研究背景与动机¶
领域现状:神经网络稀疏化已被广泛研究,但模型剪枝(去除权重)和数据稀疏化(去除训练样本)通常独立处理——分阶段方法忽略二者的耦合。
现有痛点:(1)训练→模型稀疏→数据稀疏的管道易陷入局部最优;(2)已有联合方法多基于启发式,缺乏严格理论;(3)大型模型联合稀疏化中关键问题——模型稀疏与数据稀疏的关联未被原则性回答。
核心矛盾:模型和数据应同时优化以最大化联合压缩效益,但缺乏统一目标函数。
本文目标:提出原则性的联合稀疏化框架,理论分析其复杂度,验证实际有效性。
切入角度:贝叶斯框架中边缘似然自然地集成了模型复杂度(如先验体积)与数据似然——这是评估模型 + 数据组合质量的天然指标。
核心 idea:将模型稀疏(权重二元掩码 \(\mathbf{m}\))和数据稀疏(样本二元权重 \(\mathbf{s}\))同时纳入边缘似然目标 \(\log p(\mathcal{D}_s | \mathbf{m}, \mathbf{s}) = \int p(\mathcal{D}_s | \theta, \mathbf{m}) p(\theta) d\theta\) 并通过拉普拉斯近似可处理。
方法详解¶
整体框架¶
(1)联合参数化:模型 \(\theta\) + 模型掩码 \(\mathbf{m}\) + 数据权重 \(\mathbf{s}\);(2)目标:联合最大化边缘似然 \(\log p(\mathcal{D}^{(s)} | \mathbf{m}) - \lambda_1 \|\mathbf{m}\|_0 - \lambda_2 \|\mathbf{s}\|_0\);(3)优化:用拉普拉斯近似简化边缘似然;(4)算法:交替最大化 \(\theta, \mathbf{m}, \mathbf{s}\)。
关键设计¶
-
统一边缘似然目标:
- 功能:联合优化模型稀疏与数据稀疏的统一目标。
- 核心思路:最大化 \(\log p(\mathcal{D}^{(s)} | \mathbf{m}) - \lambda_1 \|\mathbf{m}\|_0 - \lambda_2 \|\mathbf{s}\|_0\),其中 \(\mathcal{D}^{(s)} = \{(\mathbf{x}_i, y_i, s_i)\}\) 为带样本权重的数据集;\(\mathbf{m}\) 为模型权重掩码。边缘似然通过 Occam 剃刀自然惩罚冗余权重。
- 设计动机:相比分阶段方法,统一目标确保 \((\mathbf{m}, \mathbf{s})\) 联合最优;相比启发式联合方法,理论性更强。
-
拉普拉斯近似 + Hessian 近似:
- 功能:将不可解的边缘似然积分转化为可计算的解析形式。
- 核心思路:对 \(p(\mathcal{D}^{(s)} | \theta) p(\theta)\) 在 \(\theta^* = \arg\max\) 处做拉普拉斯近似得 \(\log p(\mathcal{D}^{(s)}) \approx -\mathcal{L}(\theta^*) + \frac{1}{2} \log \det H^{-1}\);用 K-FAC 块对角近似 Hessian \(H \approx H_{\text{kfac}}\) 使复杂度 \(O(N + d^2)\) → \(O(N + d \cdot l)\);用 NTK 近似 + 子采样进一步加速。
- 设计动机:精确边缘似然需 \(O(d^3)\) Hessian 分解,大模型不可行;K-FAC + NTK 近似既保精度又获得可扩展性。
-
三相交替优化算法:
- 功能:通过迭代最大化 \(\theta, \mathbf{m}, \mathbf{s}\) 收敛到联合最优。
- 核心思路:阶段 A(参数训练):固定 \(\mathbf{m}, \mathbf{s}\) 用 SGD 训 \(\theta\);阶段 B(模型稀疏化):固定 \(\theta, \mathbf{s}\) 优化 \(\mathbf{m}\)——边缘似然梯度 \(\partial \log p / \partial m_j \approx -|\theta_j| \cdot \mathbb{E}[H_{jj}]\) 给出每个权重的"边际贡献得分";阶段 C(数据稀疏化):固定 \(\theta, \mathbf{m}\) 优化 \(\mathbf{s}\)——样本对边缘似然的边际贡献 \(\partial \log p / \partial s_i \approx \log p(y_i | \mathbf{x}_i, \theta, \mathbf{m}) + \text{Hessian 项}\) 给出每样本得分。
- 设计动机:联合优化的非凸问题难一次求解;交替最大化收敛快速、内存友好。
实验关键数据¶
主实验:联合稀疏化效果(CIFAR-100 + ResNet-50)¶
| 方法 | 模型稀疏度 | 数据稀疏度 | 测试 ACC | 训练时间 | 推理 FLOPs |
|---|---|---|---|---|---|
| 密集基线 | 0% | 0% | 78.3% | 1.0× | 1.0× |
| 仅模型剪枝(IMP) | 80% | 0% | 76.1% | 1.0× | 0.21× |
| 仅数据修剪(forget) | 0% | 50% | 75.8% | 0.5× | 1.0× |
| 分阶段(IMP→forget) | 80% | 50% | 74.2% | 0.5× | 0.21× |
| JMDS(本工作) | 80% | 50% | 77.5% | 0.4× | 0.21× |
| JMDS(极端) | 90% | 70% | 76.3% | 0.3× | 0.13× |
跨数据集 / 模型泛化¶
| 数据集 | 模型 | 分阶段 ACC | JMDS ACC | 提升 |
|---|---|---|---|---|
| CIFAR-10 | ResNet-18 | 91.2 | 93.4 | +2.2 |
| CIFAR-100 | ResNet-50 | 74.2 | 77.5 | +3.3 |
| ImageNet | ResNet-50 | 72.1 | 74.8 | +2.7 |
| WikiText-2 | GPT-2 (Small) | 27.3 PPL | 24.9 PPL | -2.4 PPL |
| WikiText-103 | GPT-2 (Medium) | 24.5 PPL | 22.1 PPL | -2.4 PPL |
计算开销分析¶
| 方法 | Hessian 近似成本 | 算法收敛步数 | 总时间 vs 密集 |
|---|---|---|---|
| 精确 Hessian | \(O(d^3)\) → 不可行 | — | — |
| K-FAC + NTK 子采样 | \(O(d \cdot l + s d)\) | 50-100 步 | 0.4-1.5× |
| 完全启发式 | \(O(1)\) | 100+ | 0.5× |
关键发现¶
- 联合优势在高稀疏度下尤其显著:80% 模型 + 50% 数据稀疏度下,JMDS 比分阶段高 3.3%。
- 理论与实验一致:边缘似然下降量与精度损失高度相关。
- 跨任务一致性:CV 和 NLP 上均稳定提升,说明框架的通用性。
亮点与洞察¶
- 首次原则性联合稀疏化:突破独立稀疏化的传统观念,揭示模型与数据的耦合关系。
- 理论与实践的精彩结合:拉普拉斯近似 + K-FAC + NTK 子采样使理论目标可行。
- 统一视角:边缘似然作为统一指标既衡量模型复杂度又评估数据贡献。
局限与展望¶
- 大模型扩展性:当前 K-FAC 近似在 GPT-2 Medium 上仍有局限,对 GPT-2 Large+ 需要进一步近似。
- 非梯度方法的边际贡献得分:当前公式基于梯度信息,对一些非梯度任务(如检索)不直接适用。
- 收敛性:交替优化的全局收敛保证未给出。
- 改进:开发更高效的 Hessian 近似(如二阶 NTK);扩展到多模态、强化学习场景。
相关工作与启发¶
- vs 独立稀疏化(IMP, Forget-score):本工作首次给出耦合优化框架。
- vs Bayesian Pruning:贝叶斯剪枝主要针对模型稀疏;JMDS 扩展到数据稀疏化。
- 启发:边缘似然作为"组合复杂度"的统一指标,可扩展到架构搜索 + 数据选择联合问题。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 首次提供原则性联合稀疏化框架,超越启发式联合方法。
- 实验充分度: ⭐⭐⭐⭐⭐ CV + NLP + 5 个数据集 + 详细消融 + 理论与实验印证。
- 写作质量: ⭐⭐⭐⭐ 数学严谨,算法清晰,但部分近似的推导需要补充。
- 价值: ⭐⭐⭐⭐⭐ 大模型时代联合压缩具有重大实用价值;理论框架可启发更多联合优化问题。