Distribution Transformers: Fast Approximate Bayesian Inference With On-The-Fly Prior Adaptation¶
会议: ICML 2026
arXiv: 2502.02463
代码: https://github.com/GWhittle110/distribution-transformers
领域: 科学计算 / 贝叶斯推断 / Transformer 摊销推断
关键词: 摊销贝叶斯推断, 先验自适应, 高斯混合模型, 序贯滤波, Transformer
一句话总结¶
Distribution Transformer (DT) 把"先验分布"显式 token 化为一组高斯混合分量、把"观测"通过交叉注意力注入解码器,端到端学一个"先验+数据 → 后验"的映射,在保持与先验同族(GMM→GMM)以支持序贯滤波的同时,把推断时间从分钟级压到毫秒级,并允许测试时任意更换先验而无需重训。
研究背景与动机¶
领域现状:摊销贝叶斯推断(Amortized Bayesian Inference, ABI)把"为每个新数据集解一次后验"这件昂贵的事预先训练好——离线训练阶段学一个 \(z \mapsto q(x|z)\),在线只跑一次前向。基于 Transformer 的代表性方法 PFN/TabPFN/ACE 已经能在小样本场景下做到单次前向出后验,效果与 SVI/MCMC 接近。
现有痛点:(1) 这些 ABI 模型在训练时把先验"焊死"了——一旦想换先验,模型要重训甚至重新生成训练数据;(2) 即便少数方法支持"先验灵活性",输出分布族(如 PFN 的 Riemann 桶状分布)与先验不一致,输出后验无法再喂回去当下一轮先验,因此根本不能做序贯滤波(Kalman/粒子滤波场景);(3) 经典序贯方法(EKF/PF)灵活但要么强 Gaussian 假设要么算力随粒子数爆炸,且不支持跨任务摊销。
核心矛盾:摊销 + 先验灵活 + 共轭性(先验和后验同族)这三者必须同时满足,序贯贝叶斯滤波才能跑——以往的工作总是顾此失彼。
本文目标:(i) 单次前向出后验(摊销);(ii) 测试时任意换先验,不重训(先验摊销);(iii) 先验和后验同属 GMM 族,可递归级联做滤波;(iv) 在静态推断基准上不输 PFN/TabPFN/ACE,序贯任务上能追上粒子滤波但快几十至上千倍。
切入角度:找一个"通用万能近似器分布族"并用 Transformer 在该族上操作。作者选择高斯混合模型——任意紧支撑光滑密度都能被 \(k\) 分量 GMM 任意精度逼近,且 GMM 的参数 \(\{(w_i,\boldsymbol{\mu}_i,\boldsymbol{\Sigma}_i)\}_{i=1}^{k}\) 是天然的"无序 token 序列",正好契合 Transformer 的置换不变假设。
核心 idea:把贝叶斯推断重写为 GMM-序列到 GMM-序列的映射,由 transformer decoder 实现,先验和观测都被嵌成 token,输出又回到 GMM 同族——同族性是序贯滤波的钥匙。
方法详解¶
整体框架¶
四个模块串起来:先验嵌入 → 观测嵌入 → transformer decoder → GMM 反嵌入。给定先验参数 \(\phi\),可学嵌入网络把它映射成长度为 \(k\) 的无序 token 序列(隐空间里的 GMM 表示);给定观测 \(z\)(数据集/传感器读数/查询点),用按数据源定制的可学嵌入嵌成另一组 token;transformer decoder(无位置编码,保置换等变)让先验 token 之间自注意力、并与观测 token 全局交叉注意力,输出隐空间里的后验 token 序列;最后一个 component-wise 的可学反嵌入把每个 token 解成 (logit, \(\boldsymbol{\mu}_i\), \(\boldsymbol{\Sigma}_i\)),跨 token softmax 给出权重 \(w_i\),组装成 GMM 后验 \(q_\theta(x|z,\phi) = \sum_i w_i \mathcal{N}(x;\boldsymbol{\mu}_i,\boldsymbol{\Sigma}_i)\)。可选地,作者还引入 sample-space 变换 \(f(\cdot)\) 把支撑做改变测度(如对带正支撑的逆 Gamma 先验做 log-warp),让 GMM 在 \(\mathbb{R}^n\) 上展开。
关键设计¶
-
GMM-as-token 表示 + Transformer 解码器:
- 功能:把"分布"做成 transformer 能吃的 token 序列,并保证输入输出都在同一个 GMM 族里。
- 核心思路:先验参数 \(\phi\) 经可学网络嵌成 \(k\) 个 token;观测按数据源各自嵌成 token,统一拼接成 context 序列;transformer decoder 让先验 token 间自注意力、并与 context cross-attention,得到后验 token;不加 positional encoding 以匹配 GMM 分量的置换不变性。反嵌入对每个 token 输出 \((\text{logit}, \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)\),跨 token softmax 归一化权重 \(\sum_i w_i = 1\)。
- 设计动机:GMM 是普适近似器,又是"无序参数集合",与 transformer 的 set-to-set 性质完美匹配;同族输入输出意味着上一时刻的后验 token 序列可直接作为下一时刻的先验 token 序列,这是序贯滤波能跑的代数前提——PFN 的 Riemann 桶状分布做不到这一点。
-
元先验 (meta-prior) + KL 形式的对偶训练目标:
- 功能:让模型一次性见到一族先验而非单个先验,从而支持测试时换先验。
- 核心思路:引入"先验之上的分布"——元先验 \(p(\phi)\),联合分布写成 \(p(\phi,x,z) = p(\phi)p(x|\phi)p(z|x)\)。训练时每个 batch 先采 \(\phi_i \sim p(\phi)\),再采 \(x_i \sim p(x|\phi_i)\)、\(z_i \sim p(z|x_i)\)。主损失 \(\ell_\theta = \mathbb{E}_{p(\phi,x,z)}[-\log q_\theta(f(x)|z,\phi)]\)。Prop 3.1 证明该损失等价于 \(\mathbb{E}_{p(\phi,z)}[\mathrm{KL}(p(\cdot|z,\phi) \,\|\, q_\theta(\cdot|z,\phi))]\) 加常数,因此直接最小化平均后验 KL,且只需从 \(p(\phi,x,z)\) 采样、无需求真后验密度。
- 设计动机:把"先验"从训练时常量提升为联合分布里的随机变量,等价于在 \(\Phi \times \mathcal{Z} \to \mathcal{Q}\) 的更大映射空间里摊销;KL 形式说明这不是临时的最大似然 hack,而是直接逼近真后验。
-
先验一致性正则 (prior loss) 锁定隐空间共轭性:
- 功能:把反嵌入也作用在"先验 token 序列"上,得到先验的 GMM 近似 \(q_\theta(x|\phi)\),强迫先验与后验共享同一个隐空间表示。
- 核心思路:定义 \(\ell_\theta^{\mathrm{prior}} = \mathbb{E}_{p(\phi,x)}[-\log q_\theta(x|\phi)]\),组合损失 \(\ell_\theta' = \ell_\theta^{\mathrm{prior}} + \ell_\theta\)。即先验 token 经 transformer 之前直接解码,损失要求它本身就是先验的 GMM 近似;后验 token 经 transformer 之后解码,损失要求它是后验的 GMM 近似。
- 设计动机:先验和后验都必须用"同一个反嵌入"解出 GMM 才算共轭——没有这条正则,先验 token 与后验 token 可能落在不同的隐空间区域,使得"上一步后验作为下一步先验"在数值上失效。该项在性能上只是轻微提升,但是序贯级联的必要条件。
实验关键数据¶
主实验¶
实验 4.1:逆 Gamma 先验 + 正态方差似然的解析共轭对照,分窄/宽两个元先验设定,1000 未见问题。
| 方法 | 窄元先验 KL | 宽元先验 KL | 1000 问题推断时间 (s) |
|---|---|---|---|
| SVI | 0.0425 ± 0.0003 | 0.0558 ± 0.0016 | 148 |
| PFN-15 | 0.517 ± 1.009* | 331.5 ± 646.6* | 0.003 |
| PFN-5000 | 0.0038 ± 0.0789 | 0.2935 ± 0.0237 | 0.003 |
| TabPFNv2 | 0.0112 ± 0.0013 | 0.1513 ± 0.0168 | 1.52 |
| ACE-5 | 0.0094 ± 0.0000 | 0.0048 ± 0.0014 | 0.037 |
| DT-2 | 0.0044 ± 0.0001 | 0.0058 ± 0.0002 | 0.014 |
| DT-5 | 0.0004 ± 0.0000 | 0.0003 ± 0.0000 | 0.016 |
DT-5 比 PFN-5000 后验 KL 低近一个数量级(窄元先验)、宽元先验下差 3 个数量级;推断时间 16 ms / 1000 问题,比 SVI 快约 \(10^4\) 倍。
实验 4.2.1(5 维 GP 预测后验 + 超后验):DT 同时在 PPD NLL(0.81)与超后验 NLL(0.31)上击败 PFN/TabPFNv2/ACE,且 9.5 s 是最快。
实验 4.3.1(4 维状态空间贝叶斯传感器融合):
| 方法 | 期望 NLL | 100 序列批次单步时间 (s) |
|---|---|---|
| EKF | 95.9 ± 4.40 | 0.010 |
| Particle Filter | -0.244 ± 0.047 | 0.818 |
| DT-4 | -0.197 ± 0.040 | 0.017 |
DT 几乎追上"准 ground truth"的 PF,单步快约 50×;EKF 因线性化假设彻底失败。
消融与机制对比¶
| 维度 / 方法 | 关键观察 | 含义 |
|---|---|---|
| GMM 分量数 \(k = 2\) vs \(5\)(4.1 节) | KL 从 0.0044 降到 0.0004 | 分量数提供"逼近能力的旋钮",且和参数量解耦 |
| Riemann 输出(PFN)vs GMM(DT/ACE) | Riemann 在宽元先验下 KL 飙到 331 | 桶状分布表达力差是 PFN 的瓶颈 |
| 有/无 prior loss | 性能提升微小,但序贯级联必需 | 隐空间共轭性是序贯能力的代数前提 |
| 序贯任务可否套用 PFN(拼观测) | 推断时间随 \(T\) 线性甚至 \(\mathcal{O}(T^2)\) 增长 | DT 的常时间递推是关键工程优势 |
| 实验 4.3.2(10 维随机波动率) | PF 需 3 个数量级更多算力才能匹敌 DT | 高维稀疏信息场景 DT 拉开差距 |
关键发现¶
- 同族性是序贯能力的钥匙:GMM→GMM 同族意味着上一步后验可直接当下一步先验,单步推断时间与序列长度 \(T\) 解耦;而 PFN/TabPFN/ACE 即便强行拼接观测,时间随 \(T\) 线性甚至平方增长。
- GMM 表达力天花板高:与桶状 Riemann 分布相比,5 分量 GMM 在共轭对照实验里已逼近真后验到肉眼不可分;这是 DT/ACE 共同碾压 PFN/TabPFNv2 的根本原因。
- 元先验越宽,先验灵活性越值钱:窄元先验下 PFN-5000 还能凑合(因边缘分布接近真先验),宽元先验下完全垮掉,DT 几乎不变。
- prior loss 性能收益小但功能必需:去掉它静态 KL 几乎不变,但隐空间共轭性丢失,4.3 节序贯滤波直接失效。
亮点与洞察¶
- "分布作为输入"是被低估的设计自由度:以往摊销推断把先验当超参注入或干脆当成训练时常量;本文把先验参数 \(\phi\) 显式作为 transformer 的另一组 token,模型架构自然支持"换先验",这一思路可被迁移到任何需要在测试时调先验的概率建模任务(贝叶斯优化、ABC、传感器融合)。
- 架构对称性 ↔ 概率对称性:transformer 无位置编码 ↔ GMM 分量无序、cross-attention ↔ 观测条件独立,整套架构和贝叶斯图模型的不变性结构是同构的——这种"先选概率结构、再挑神经架构"的设计范式值得在科学计算 ML 里推广。
- 从"学后验"到"学算子":DT 真正学到的是"先验+数据 → 后验"这个算子,而不是某个具体后验。这把"摊销"从单层(跨任务)推到双层(跨任务 + 跨先验家族),抽象层级显著提升。
- 可堆叠的实时贝叶斯滤波:在毫秒级吞吐里跑非高斯、非线性 SSM,且能复现 PF 的精度,对自动驾驶感知、量子参数实时追踪、工业控制等需要快速贝叶斯更新的场景有直接工程价值。
局限与展望¶
- 训练成本随先验空间维度上升:要在 \(\Phi\) 上覆盖更广,离线训练样本量和时长显著增加(附录 Table 8)。
- 元先验需要"还算合理":若实际部署时遇到的先验完全在元先验之外,性能会衰减;附录 C.2 给出了一些鲁棒性证据,但远非全面。
- 高维 GMM 是已知瓶颈:分量数自注意力是平方、分量内 full-covariance 解码是隐维度平方,10 维以下 ok,几十维以上需要稀疏/低秩协方差。
- 序贯长链上的误差累积:每步都做近似,长链上误差会缓慢漂移;附录 C.5 验证中等深度上仍可控,但极长序列下未严格验证。
- 超参选择经验主义:分量数 \(k\)、嵌入维度、注意力头数等都是手调,缺乏系统化的自动选择策略。
相关工作与启发¶
- vs PFN / TabPFN / TabPFNv2 (Müller 2021, Hollmann 2022/2025):PFN 系列固定先验,输出 Riemann 桶状分布;DT 把先验 token 化、输出 GMM、且能换先验做序贯滤波,是定性增量。
- vs ACE (Chang 2024):ACE 已支持先验灵活性且也用 GMM 输出,性能最接近;DT 的关键差异是更灵活的嵌入设计和显式的同族共轭性保证(prior loss),后者使序贯应用成为可能。
- vs 经典 Kalman / 粒子滤波 (Kalman 1960; Doucet 2001):EKF 假设线性高斯,垮在非线性观测上;PF 渐近精确但维数灾难。DT 把两者的弱点都补上——非线性表达力 + 摊销带来的恒定吞吐。
- vs 变分推断 / 神经过程 (Kingma & Welling 2013; Garnelo 2018):经典 VI 每次问题都要重新优化;神经过程摊销但通常预测数据空间分布而非潜变量后验。DT 同时摊销 + 输出潜变量后验 + 允许换先验。
- vs 模拟基推断 (Cranmer 2020; Wildberger 2023):SBI 表达力强但通常假设固定先验且无法做序贯递归;DT 是"先验灵活 + 同族 + 摊销"的另一条路。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 把分布本身 token 化、显式追求先验-后验同族以支持序贯滤波,是摊销贝叶斯里少见的真定性突破。
- 实验充分度: ⭐⭐⭐⭐ 解析共轭、GP 超后验、量子参数、传感器融合、随机波动率覆盖广;可惜缺真实机器人/自动驾驶上的端到端 demo。
- 写作质量: ⭐⭐⭐⭐ 动机—架构—训练—理论命题—实验链条清晰,但 prior loss 的"非性能但必要"角色对读者略反直觉,可再展开。
- 价值: ⭐⭐⭐⭐⭐ 在毫秒级吞吐 + 任意换先验 + 序贯可级联三件事同时拿到,对工业实时贝叶斯应用是真实推动。