Bigram Subnetworks: Mapping to Next Tokens in Transformer Language Models¶
会议: NeurIPS 2025
arXiv: 2504.15471
代码: https://github.com/tylerachang/bigram-subnetworks
领域: 可解释性
关键词: Bigram子网络, 机制可解释性, 连续稀疏化, 残差流, 最小电路
一句话总结¶
通过连续稀疏化在Transformer语言模型中找到仅包含~10M参数的bigram子网络,它们集中在第一个MLP层,足以复现bigram预测(\(r>0.95\)),且被消融后模型性能大幅下降,证明这些子网络是语言模型中既必要又充分的最小next-token预测电路。
研究背景与动机¶
领域现状:机制可解释性研究已发现了induction head、name mover head等特定电路,但这些电路通常只覆盖特定行为。缺乏一个定义在整个输入空间上的"最小基础电路",作为研究更复杂电路的起点。
现有痛点:电路研究通常只验证必要性(消融后行为消失)但不验证充分性(电路独立运行后行为是否仍存在)。验证充分性需要在某个已理解的最小电路之上叠加目标电路,但这个最小电路是什么,一直不清楚。
核心矛盾:Transformer已知会在预训练早期过拟合bigram分布,但即使模型后来偏离了bigram预测,bigram信息是否仍编码在模型参数中?以什么形式存在?
切入角度:Bigram预测 \(P(w_i|w_{i-1})\) 是最简单的非平凡next-token预测,在整个输入空间上有定义。如果能找到实现bigram预测的子网络,它就是研究更复杂电路的理想基础。
核心 idea:用连续稀疏化在冻结的LLM中搜索mask,找到仅占0.17%参数但能达到r=0.96 bigram相关的子网络,主要集中在第一个MLP层。
方法详解¶
整体框架¶
冻结LLM参数 → 用连续稀疏化学习参数mask \(M\) → 最小化masked模型输出与bigram分布的交叉熵 + L1稀疏惩罚 → 得到二值mask定义的子网络。在Pythia (70M-1B) 和 GPT-2 (small-large) 上实验。
关键设计¶
-
连续稀疏化找子网络:
- 每个模型参数对应一个可学习的mask值 \(m \in (-\infty, +\infty)\),通过sigmoid映射到(0,1)
- 训练过程中逐渐降低sigmoid温度,使mask趋向二值
- 损失:\(\text{CE}(P(x), \text{MaskedModel}_M(x)) + \lambda \|M\|_1/|M|\)
- \(\lambda\) 控制稀疏度,从0到1000变化以观察不同稀疏度下的表现
-
关键发现:~10M参数的普适性:
- 不管模型大小(70M到1B),bigram子网络在~10M活跃参数处达到性能平台
- Pythia 1B中仅0.17%的非embedding参数就能达到 \(r=0.959\) 的bigram相关
- 说明bigram预测所需的"电路容量"与模型规模无关
-
结构分析:第一个MLP层的统治地位:
- 在所有模型和预训练检查点中,bigram子网络的大部分参数集中在第一个Transformer MLP层
- 甚至在随机初始化的模型中也是如此——说明这是架构+损失函数的固有偏置
- 机制解释:第一个MLP层负责将激活从"当前token表征"旋转到"next-token预测空间"
残差流旋转分析¶
- 在完整模型中,第一层后激活的token预测准确率急剧上升——从当前token空间跳转到下一个token空间
- Bigram子网络精确复现了这一跳转,且在后续层基本不变
- 这表明bigram子网络捕获了Transformer做next-token预测的最基础机制
实验关键数据¶
主实验(Bigram相关系数 \(r\))¶
| 模型 | 子网络参数占比 | Bigram \(r\) | 全模型 \(r\) |
|---|---|---|---|
| Pythia 70M | ~15% | 0.961 | 0.737 |
| Pythia 410M | ~2.5% | 0.983 | 0.650 |
| Pythia 1B | 0.17% | 0.959 | 0.632 |
| GPT-2 medium | ~1% | 0.985 | 0.582 |
| GPT-2 large | ~1% | 0.986 | 0.583 |
消融实验¶
| 操作 | 模型困惑度变化 | 说明 |
|---|---|---|
| 消融bigram子网络 | 大幅退化 | 0.17%参数对性能至关重要 |
| 消融等量随机参数 | 轻微退化 | bigram参数比随机参数重要得多 |
| bigram子网络 ∩ 最优剪枝子网络 | 高度重叠 | bigram参数也是剪枝保留的关键参数 |
关键发现¶
- 模型越大,bigram子网络占比越小但绝对参数量恒定(~10M)
- 预训练过程中bigram子网络先收缩再扩展——在~4K步达到最高效表征
- 第一个MLP层承担了从current-token到next-token表征空间旋转的关键角色
- Bigram子网络与最优剪枝子网络高度重叠——说明bigram预测是LM的"核心功能"
亮点与洞察¶
- "最小电路"概念的提出非常有建设性:为未来的电路发现研究提供了一个有明确定义的baseline——先找到bigram电路,再往上叠加更复杂的电路
- 10M参数的普适常数很有趣:暗示bigram预测的信息论复杂度约为10M参数,与模型规模无关
- 第一层MLP的特殊角色得到了精确量化:它不只是任意的初始处理,而是承担了从"我是什么token"到"下一个是什么token"的关键空间旋转
局限与展望¶
- 仅测试到1B规模——7B/70B模型中bigram子网络是否仍为~10M参数?
- 连续稀疏化可能找不到全局最优mask——不同初始化可能得到不同子网络
- 只研究了bigram(1-gram上下文)——能否推广到trigram、4-gram子网络?它们如何叠加?
- 未连接到induction head等已知电路——bigram子网络与induction head的关系是什么?
相关工作与启发¶
- vs Voita et al. (2024):他们发现跨所有层分布的bigram促进神经元。本文发现bigram子网络集中在第一层——说明不同层的bigram相关活动可能有不同功能
- vs 传统电路发现(IOI等):IOI电路针对特定任务,bigram子网络覆盖整个输入空间,是更基础的电路
评分¶
- 新颖性: ⭐⭐⭐⭐ bigram子网络的概念新颖,~10M普适常数是有趣发现
- 实验充分度: ⭐⭐⭐⭐⭐ 8个模型、多检查点、消融、剪枝overlap、残差流分析,非常细致
- 写作质量: ⭐⭐⭐⭐⭐ 层层递进、图表丰富、论点清晰
- 价值: ⭐⭐⭐⭐ 为机制可解释性提供了"最小电路"的基础,启发后续电路叠加研究