Amortising Inference and Meta-Learning Priors in Neural Networks (BNNP)¶
会议: ICLR 2026
arXiv: 2602.08782
代码: 有
领域: 贝叶斯深度学习
关键词: Bayesian neural network, neural process, meta-learning, amortised inference, prior learning
一句话总结¶
提出 BNNP(Bayesian Neural Network Process),一种将 BNN 权重作为隐变量、BNN 本身作为解码器的 neural process,通过逐层 amortised variational inference 在多数据集上联合学习 BNN 先验和推断网络,首次回答了"在良好先验下,近似推断方法还重要吗?"——答案是肯定的,没有免费午餐。
研究背景与动机¶
领域现状:BNN 理论优雅但先验选择是核心难题——权重缺乏可解释性,方便先验(各向同性高斯)将 BNN 退化为高斯过程,失去了层次表示学习能力。Neural process 通过元学习隐式先验但无法评估或采样。
现有痛点:(1) 不知道什么是 BNN 的"好先验";(2) 即使有好先验,现有的近似推断方法(MFVI, HMC 等)是否足够好不清楚;(3) Neural process 不能做 within-task minibatching(大数据集内存爆炸)。
核心矛盾:在合理先验下研究 BNN 行为 → 需要先有好先验 → 好先验需要从数据中学 → 学先验需要好的推断方法 → 形成循环。
本文目标 同时解决先验学习和 amortised inference:从多个相关数据集中元学习 BNN 先验,同时获得高质量的逐数据集后验。
切入角度:将 BNN 推断重新表述为 neural process——隐变量是权重,解码器就是这些权重参数化的网络本身。逐层条件后验通过贝叶斯线性回归闭式求解。
核心 idea:BNN 权重 = neural process 的隐变量,BNN 本身 = 解码器,从多任务数据中联合学习先验和 amortised 推断。
方法详解¶
整体框架¶
BNNP 对 BNN 的每一层,用推断网络将数据点映射为伪似然参数(伪观测值+噪声水平),然后与高斯先验做闭式贝叶斯线性回归得到层条件后验。从第一层采样,传到下一层,逐层进行。先验参数和推断网络参数在多任务数据集上联合优化。
关键设计¶
-
Amortised Linear Layer(核心组件):
- 功能:对任意位于网络中间的线性层做 amortised 推断
- 核心思路:推断网络 \(g_{\theta_l}\) 将每个数据点 \((x_n, y_n)\) 映射为伪观测值 \(y_n^l\) 和噪声水平 \(\sigma_{n,d}^l\)。与高斯先验 \(p_{\psi_l}(W^l)\) 做闭式贝叶斯线性回归得到条件后验 \(q(W^l | W^{1:l-1}, \mathcal{D})\)
- 设计动机:闭式后验 = 精确推断在层内,接近 BNN 的全后验。推断网络 amortise 了推断过程——一次前向传播即可
-
PP-AVI 训练目标:
- 功能:联合优化先验参数 \(\Psi\) 和推断网络参数 \(\Theta\)
- 公式:\(\mathcal{L}_{PP-AVI} = \log q(Y_t | \mathcal{D}_c, X_t) + \mathcal{L}_{ELBO}(\mathcal{D}_c)\)
- 第一项(后验预测密度)驱动预测质量和先验学习;第二项(ELBO)驱动近似推断质量
- Proposition 1 证明在无限数据集下同时满足三个 desiderata
-
Within-Task Minibatching(独特能力):
- 功能:通过 sequential Bayesian inference 实现大数据集上的 minibatch 推断
- 核心思路:将数据集分成小批,逐批更新各层后验(后验即先验的下一步更新)。预测时结果与全批处理完全相同
- 设计动机:这是现有 neural process 几乎都做不到的——它们需要一次性处理整个 context set
-
可调先验灵活度:
- 功能:固定部分权重的先验、只学习其余权重的先验
- 设计动机:在元数据集很小时防止先验过拟合——推断网络和先验参数可以独立控制
损失函数 / 训练策略¶
PP-AVI 目标(后验预测 + ELBO)。LoRA 风格的推断网络。元学习范式——多任务训练。
实验关键数据¶
近似后验质量(KL 散度 ↓)¶
| 方法 | KL(q || p(W|D)) | |------|----------------| | MFVI | 高(尤其低噪声) | | FCVI | 很高 | | GIVI | 中 | | BNNP (amortised) | 最低 | | BNNP (per-task) | 最低 |
先验学习能力¶
| 数据生成过程 | 先验样本质量 |
|---|---|
| 锯齿函数 | 几乎无法区分真实 vs 学到的先验 |
| Heaviside 函数 | 多模态先验被成功学到 |
| MNIST 像素回归 | 可辨认数字,支持超分辨率 |
| ERA5 降水 | 真实世界先验有效 |
关键发现¶
- 好先验 ≠ 免费午餐:即使在学到的先验下,不同近似推断方法的性能差异仍然很大。HMC 在好先验下仍然最好,说明近似推断质量始终重要
- Gaussian 先验足以产生高度复杂的函数先验(包括多模态的 Heaviside)——先验设计没有想象中那么难
- BNNP 的部分可训练先验在小元数据集设置中优于全可训练 neural process——防止先验过拟合
- Within-task minibatching 使 BNNP 可以处理大数据集——这是其他 neural process 的新能力
亮点与洞察¶
- "BNN 权重即 neural process 隐变量"的统一视角:将两个不同领域(BDL + probabilistic meta-learning)优雅地连接起来
- 首次在实证上回答了 BDL 的基本问题:好先验是否消除了对好推断方法的需求?答案是否——这对 BDL 社区有重要指导意义
- Gaussian 先验的出人意料的灵活性:简单的 Gaussian 权重先验在适当学习后可以产生极其丰富的函数先验——推翻了需要复杂先验结构的假设
局限与展望¶
- 推断复杂度随网络宽度不利地增长——目前主要验证在小型 BNN 上
- 需要多个相关数据集来学习先验——单数据集场景仍未解决
- BNAM(注意力版本)破坏了一致性(不再是合法的随机过程)——理论不完备
- 实验主要在 1D/2D 回归任务上,更高维的复杂任务需要验证
相关工作与启发¶
- vs Neural Process 家族: BNNP 的隐变量是解码器权重本身(而非抽象表示),这使得先验可以显式评估和采样
- vs MFVI / HMC: BNNP 提供的逐层 amortised 推断质量接近 HMC,远优于 MFVI
- vs 函数空间先验 (Cinquin et al.): 函数空间先验通常退化为 GP,BNNP 的权重空间先验保持了 BNN 的层次表示能力
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 将 BNN 重塑为 neural process 的统一框架非常新颖和深刻
- 实验充分度: ⭐⭐⭐⭐⭐ 4 个研究问题、合成+真实数据、与多种 VI 方法对比
- 写作质量: ⭐⭐⭐⭐⭐ 理论推导严谨,叙事引人入胜("没有免费午餐"的核心信息)
- 价值: ⭐⭐⭐⭐⭐ 对 BDL 领域的基础性贡献——提供了研究 BNN 行为的新工具