跳转至

How Transformers Learn Causal Structures In-Context: Explainable Mechanism Meets Theoretical Guarantee

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=bpF8zgSt41
代码: 待确认
领域: 可解释性 / in-context learning 理论
关键词: in-context learning, 因果结构推断, Bayesian Model Averaging, 注意力机制可解释性, 信息论保证, Markov 链

一句话总结

本文证明并实证一个两层带相对位置编码的 Transformer 能在上下文中显式实现 Bayesian Model Averaging (BMA)——这一统计最优算法——来推断每个 token 的"父节点"因果结构,并用信息论 (DPI / 互信息) 给出可识别性与训练动力学保证。

研究背景与动机

  • 领域现状:in-context learning (ICL) 的理论分析大多假设序列元素之间的依赖结构是固定预设的,例如 [x1, f(x1), x2, f(x2)] 这样的刚性模板,或假设一个固定的 n-gram/bigram 因果模型,再证明 Transformer 能把该结构嵌进注意力权重。
  • 现有痛点:真实序列(自然语言句法、股价资产关联)的依赖图本身是变化的,不同序列、不同文档之间的结构会剧烈漂移。Nichani et al. (2024) 等工作只能处理"训练时固定"的因果结构,无法回答模型能否在推理期临时推断并适配新结构。
  • 核心矛盾:经验上 Transformer 表现出强大的结构适配能力,但理论侧既缺一个能让结构本身作为隐变量、需从上下文样本里被反推的任务框架,也缺一个能把"注意力到底在算什么统计量"讲清楚的机制级解释。
  • 本文目标:构造一个隐因果结构随机采样的任务,回答 (⋆) "Transformer 能否在上下文中推断并适配因果结构?",并同时给出可解释机制(attention 在算什么) 与理论保证(为什么能选对父节点)。
  • 核心 idea因果结构作为隐变量 + BMA 作为最优基准 —— 用随机父依赖的 Markov 链生成序列,把"推断父节点"形式化为 Bayesian Model Averaging 的后验,再证明两层 Transformer 的第二层 attention 恰好就是这个 BMA 后验。

方法详解

整体框架

任务把每条长度为 \(H\) 的序列建模成一棵有向树:token \(x_h\) 依赖于唯一父节点 \(x_{\mathrm{pa}(h)}\),其中 \(\mathrm{pa}(h)\sim\mathrm{Unif}(1,\dots,h-1)\),整套父关系 \(G=\{\mathrm{pa}(h)\}\) 在一个 context(\(L\) 个示例 + 1 个待预测样本)内共享但跨 context 随机变化。模型必须从 \(L\) 个示例里反推 \(G\),再据此预测第 \(L{+}1\) 个样本。统计上的最优解是 BMA,把父节点当参数估计其后验,而本文证明的核心是一个两层结构——第一层 RPE 注意力当"复制器",第二层注意力当"BMA 父节点选择器"——能精确逼近它。

flowchart LR
    A["输入: L+1 个长度 H 的 Markov 链<br/>共享隐因果结构 G"] --> B["第1层 K-head RPE 注意力<br/>每个 head 复制一份历史观测 x_h^l"]
    B --> C["拼接得到过去 L 次观测<br/>z_h = [x_h^{L+1}, v_h]"]
    C --> D["第2层单头注意力<br/>双线性打分 = Σ_l log π(x_h^l | x_h'^l)"]
    D --> E["softmax = BMA 后验<br/>选出父节点 pa(h)"]
    E --> F["WOV ≈ log π 输出预测<br/>π(· | x_pa(h))"]

关键设计

1. 随机父依赖 Markov 链任务:把"结构"提升为待推断的隐变量。 与以往固定结构的设定不同,本文让每条序列的父关系图 \(G\) 从均匀分布里随机采样,同一 context 内 \(L+1\) 个样本共享同一 \(G\)。给定 BMA 的视角,父节点 \(\mathrm{pa}(h)=h'\) 的后验可写成对数似然累加后的 softmax:\(P(\mathrm{pa}(h){=}h'\mid x^{1:L}_{1:H})=\sigma\big(\hat p_{h,L}\big)_{h'}\),其中 \(\hat p_{h,L}^{h'}=\sum_{l\in[L]}\log\pi(x^l_h\mid x^l_{h'})\)。这把"选父节点"变成一个纯粹由转移核 \(\pi\) 决定的可计算量,从而给后面的机制对齐提供了精确靶子。论文同时给出离散 Markov 链与连续线性动力系统 \(x_h=\rho A^\top x_{\mathrm{pa}(h)}+\sqrt{1-\rho^2}\,\eta_h\) 两个版本。

2. 两层构造 = 复制器 + BMA 选择器。 第一层用相对位置编码 (RPE),把注意力打分拆成跨示例项 \(w_L[l-l']\) 与示例内项 \(w_H[h-h']\) 两部分;理论构造让第 \(k\) 个 head 在 \(w_H[0]\) 处占主导、\(w_L[k]\) 处占主导,于是每个 head 精确地"复制"出历史上同一位置 \(h\) 的第 \(k\) 次观测,\(K=L\) 个 head 拼起来就恢复了过去 \(L\) 次观测 \(x^{1:L}_h\)。第二层在 \(W_{KQ}\) 取分块对角 \(W\)\(\sigma(W_{OV})=\pi\) 的约束下,注意力打分塌缩成双线性形式 \(\hat p^h_{h'}(W)=\sum_l x^{l\top}_{h'}W x^l_h\);只要取 \(W=\log\pi\),它就与 BMA 打分逐项相等。Theorem 1 给出极限结果:\(\lim_{\beta\to\infty}A^{(2)}_{h\to\cdot}=\sigma(\hat p^{h,L}_{\mathrm{BMA}})\),且 \(\lim_{\beta,L\to\infty}f_\theta(\cdot\mid H^L_h)=\pi(\cdot\mid x_{\mathrm{pa}(h)})\),即预测收敛到真实条件分布。

3. 训练得到的参数确实"长成" BMA。 仅有构造还不够,论文进一步验证梯度下降学到的 \(W_{tf}\) 是否等价于 \(\log\pi\)。关键障碍是 softmax 注意力 \(\sigma(v^\top_{1:h-1}Wv_h)\)\(\log\pi\) 存在列方向的平移自由度Proposition 1 (Invariance) 证明若 \(W_{tf}=\log\pi+\mathbf 1 a^\top\) 则注意力输出与 BMA 完全一致。因此正确的对齐判据不是 \(\sigma(W_{tf})=\pi\)(行 softmax,会有明显偏差),而是检查列 softmax \(\sigma_{\mathrm{col}}(W_{tf})=\sigma_{\mathrm{col}}(\log\pi)\),实测列向误差 \(\frac1d\|\sigma_{\mathrm{col}}(W_{tf})-\sigma_{\mathrm{col}}(\log\pi)\|_1<0.05\),确认训练参数确实实现了 BMA。

4. 信息论可识别性 + 训练动力学保证。 为解释"为什么能选对父节点",论文用数据处理不等式 (DPI) 与互信息给出结构可识别性。Lemma 3 在转移核下界条件下证明 \(I(x_h;x_{h'})\le\alpha\,I(x_h;x_{\mathrm{pa}(h)})\)\(\alpha<1\)\(h'\ne\mathrm{pa}(h)\)),即真父节点的互信息严格占优;Lemma 4 把它转成期望对数似然形式 \(\mathbb E[\log\pi(x_h\mid x_{\mathrm{pa}(h)})]>\mathbb E[\log\pi(x_h\mid x_{h'})]\)Theorem 2 据此证明 \(\lim_{L\to\infty}A^L_{h\cdot}=e_{\mathrm{pa}(h)}\),注意力渐近塌缩到 one-hot 真父节点。Theorem 3 再分析训练动力学:在初始化处 \(\partial\ell/\partial\hat p\) 对真父节点的梯度分量最大,且这些梯度项与 \(\chi^2\)-互信息直接相关,说明隐因果结构在训练早期就被梯度恢复出来。

实验关键数据

本文以理论构造为主,实验侧主要通过注意力/参数可视化与父节点选择损失 \(L_{pa}\) 来验证机制对齐,未给出大规模数值跑分表。核心证据如下:

主要验证结果

验证维度 设置 观察
第二层注意力 \(A^{(2)}\) \(L{=}10,H{=}10,d{=}5\), 1024 步 注意力高亮恰好落在真因果父节点上 (Fig. 2 红框)
参数结构 \(w_H^k, W_{KQ}, W_{OV}\) 同上 \(w_H\)\(h{=}0\) 处最大、\(W_{KQ}\) 呈对角块、\(\sigma(W_{OV})\approx\pi\),与构造 Eq. (7) 一致 (Fig. 3)
父节点选择损失 \(L_{pa}\) 训练全程 随训练下降并逼近 BMA 的损失,但始终略高于 BMA
泛化 \(L^{L'}_{pa}\) \(L\in\{1,..,20\}, d{=}10, H{=}15\) 不同测试规模 \(L'\) 下接近 BMA;训练 \(L\) 越小泛化越好;固定 \(L\)\(L'\) 增大损失迅速趋零 (Fig. 5)
列 softmax 对齐 \(d{=}20,H{=}50,L{=}3\), 2048 步 列误差 0.0278(行误差 0.350),证明 \(W_{tf}=\log\pi+\mathbf 1a^\top\) (Fig. 6)

架构鲁棒性 / 连续扩展

设定 结论
标准 disentangled Transformer + 绝对位置编码 / 带 FFN 的标准 Transformer 收敛到相同注意力模式、父选择性能与 BMA 相当 (Appendix G–I, Fig. 19),说明机制不依赖 RPE 简化
连续线性动力系统 (DS) 注意力仍能选对父节点;但 \(L'\) 较小时与 BMA 有明显 gap,\(L'\to20\) 才追平
DS 表示局限 (Prop. 2) BMA 打分含二次项 \(d\sum_l\|x^l_{h'}\|^2\),而 Transformer 双线性打分无法表示该项,故 Eq. (7) 下不存在 \(W_{tf}\) 精确实现 DS 的 BMA

关键发现

  • 离散 Markov 链下 \(W_{tf}=\log\pi\) 能精确匹配 BMA,但连续动力系统下因二次项缺失只能近似——揭示离散 vs 连续因果推断在表示需求上的本质差异。
  • 父结构在训练初始梯度里就已被 \(\chi^2\)-互信息驱动地恢复,结构发现发生在训练极早期。

亮点与洞察

  • 机制可解释 + 理论保证双管齐下:不仅构造性地证明"能算 BMA",还实证训练后的参数"确实是 BMA",并用 DPI 解释"为什么选得对",三层论证闭环少见。
  • 列平移不变性 (Prop. 1) 是关键巧思:纠正了"直接比 \(\sigma(W_{tf})\)\(\pi\)"的天然误判,给出正确的对齐验证判据,否则会错误地得出"参数没对齐"的结论。
  • 把结构本身变成隐变量,比以往"固定结构"设定更贴近真实序列的 context-dependent 依赖,框架本身有方法论价值。

局限与展望

  • 结构假设强:每个 token 恰好一个父节点的有向树 + Markov 性,离自然语言真实的多父、长程、非 Markov 依赖仍远。
  • 模型极小:两层、个位数维度、attention-only 的玩具规模,结论是否外推到真实 LLM 未验证。
  • 连续设定有硬限制:Prop. 2 证明 Transformer 双线性结构无法表达 DS 的 BMA 二次项,提示需要更高阶的注意力或额外非线性才能闭合连续因果推断。
  • BMA 仍是上界:训练得到的 \(L_{pa}\) 始终略高于 BMA,未完全达到统计最优。

相关工作与启发

  • 与 Nichani et al. (2024) 同源(共用转移核下界条件、DPI 工具),但关键区别是把"训练时固定结构"推进到"上下文中推断变化结构",并把 \(\chi^2\)-互信息框架推广到可归约为经典互信息的情形、且去掉了平稳分布假设。
  • 与 D'Angelo et al. (2025) 的 selective induction heads 任务相关,本文构造稍作修改即可覆盖其更简单的设定。
  • 接续 induction head(Olsson et al. 2022)、统计 induction head(Edelman et al. 2024)一脉,把"copier+selector"两层电路从 n-gram 推广到随机因果结构推断,对理解 ICL 的机制可解释性有直接启发。

评分

  • 新颖性: ⭐⭐⭐⭐ —— 把因果结构作为隐变量在上下文中推断,并证明两层 Transformer 显式实现 BMA,框架与机制解释都较新。
  • 实验充分度: ⭐⭐⭐ —— 机制级可视化与对齐验证扎实,但停留在玩具规模、以定性证据为主,缺真实模型/数据上的验证。
  • 写作质量: ⭐⭐⭐⭐ —— 构造、实证、信息论保证三段逻辑清晰,takeaway 标注到位;公式密度高,阅读门槛偏高。
  • 价值: ⭐⭐⭐⭐ —— 为 ICL 的"结构适配"提供了可解释 + 可证明的统一解释,对机制可解释性研究有参考价值。