跳转至

Bigram Subnetworks: Mapping to Next Tokens in Transformer Language Models

会议: NeurIPS 2025
arXiv: 2504.15471
代码: https://github.com/tylerachang/bigram-subnetworks
领域: 可解释性
关键词: Bigram子网络, 机制可解释性, 连续稀疏化, 残差流, 最小电路

一句话总结

通过连续稀疏化在Transformer语言模型中找到仅包含~10M参数的bigram子网络,它们集中在第一个MLP层,足以复现bigram预测(\(r>0.95\)),且被消融后模型性能大幅下降,证明这些子网络是语言模型中既必要又充分的最小next-token预测电路。

研究背景与动机

领域现状:机制可解释性研究已发现了induction head、name mover head等特定电路,但这些电路通常只覆盖特定行为。缺乏一个定义在整个输入空间上的"最小基础电路",作为研究更复杂电路的起点。

现有痛点:电路研究通常只验证必要性(消融后行为消失)但不验证充分性(电路独立运行后行为是否仍存在)。验证充分性需要在某个已理解的最小电路之上叠加目标电路,但这个最小电路是什么,一直不清楚。

核心矛盾:Transformer已知会在预训练早期过拟合bigram分布,但即使模型后来偏离了bigram预测,bigram信息是否仍编码在模型参数中?以什么形式存在?

切入角度:Bigram预测 \(P(w_i|w_{i-1})\) 是最简单的非平凡next-token预测,在整个输入空间上有定义。如果能找到实现bigram预测的子网络,它就是研究更复杂电路的理想基础。

核心 idea:用连续稀疏化在冻结的LLM中搜索mask,找到仅占0.17%参数但能达到r=0.96 bigram相关的子网络,主要集中在第一个MLP层。

方法详解

整体框架

冻结LLM参数 → 用连续稀疏化学习参数mask \(M\) → 最小化masked模型输出与bigram分布的交叉熵 + L1稀疏惩罚 → 得到二值mask定义的子网络。在Pythia (70M-1B) 和 GPT-2 (small-large) 上实验。

关键设计

  1. 连续稀疏化找子网络

    • 每个模型参数对应一个可学习的mask值 \(m \in (-\infty, +\infty)\),通过sigmoid映射到(0,1)
    • 训练过程中逐渐降低sigmoid温度,使mask趋向二值
    • 损失:\(\text{CE}(P(x), \text{MaskedModel}_M(x)) + \lambda \|M\|_1/|M|\)
    • \(\lambda\) 控制稀疏度,从0到1000变化以观察不同稀疏度下的表现
  2. 关键发现:~10M参数的普适性

    • 不管模型大小(70M到1B),bigram子网络在~10M活跃参数处达到性能平台
    • Pythia 1B中仅0.17%的非embedding参数就能达到 \(r=0.959\) 的bigram相关
    • 说明bigram预测所需的"电路容量"与模型规模无关
  3. 结构分析:第一个MLP层的统治地位

    • 在所有模型和预训练检查点中,bigram子网络的大部分参数集中在第一个Transformer MLP层
    • 甚至在随机初始化的模型中也是如此——说明这是架构+损失函数的固有偏置
    • 机制解释:第一个MLP层负责将激活从"当前token表征"旋转到"next-token预测空间"

残差流旋转分析

  • 在完整模型中,第一层后激活的token预测准确率急剧上升——从当前token空间跳转到下一个token空间
  • Bigram子网络精确复现了这一跳转,且在后续层基本不变
  • 这表明bigram子网络捕获了Transformer做next-token预测的最基础机制

实验关键数据

主实验(Bigram相关系数 \(r\)

模型 子网络参数占比 Bigram \(r\) 全模型 \(r\)
Pythia 70M ~15% 0.961 0.737
Pythia 410M ~2.5% 0.983 0.650
Pythia 1B 0.17% 0.959 0.632
GPT-2 medium ~1% 0.985 0.582
GPT-2 large ~1% 0.986 0.583

消融实验

操作 模型困惑度变化 说明
消融bigram子网络 大幅退化 0.17%参数对性能至关重要
消融等量随机参数 轻微退化 bigram参数比随机参数重要得多
bigram子网络 ∩ 最优剪枝子网络 高度重叠 bigram参数也是剪枝保留的关键参数

关键发现

  • 模型越大,bigram子网络占比越小但绝对参数量恒定(~10M)
  • 预训练过程中bigram子网络先收缩再扩展——在~4K步达到最高效表征
  • 第一个MLP层承担了从current-token到next-token表征空间旋转的关键角色
  • Bigram子网络与最优剪枝子网络高度重叠——说明bigram预测是LM的"核心功能"

亮点与洞察

  • "最小电路"概念的提出非常有建设性:为未来的电路发现研究提供了一个有明确定义的baseline——先找到bigram电路,再往上叠加更复杂的电路
  • 10M参数的普适常数很有趣:暗示bigram预测的信息论复杂度约为10M参数,与模型规模无关
  • 第一层MLP的特殊角色得到了精确量化:它不只是任意的初始处理,而是承担了从"我是什么token"到"下一个是什么token"的关键空间旋转

局限与展望

  • 仅测试到1B规模——7B/70B模型中bigram子网络是否仍为~10M参数?
  • 连续稀疏化可能找不到全局最优mask——不同初始化可能得到不同子网络
  • 只研究了bigram(1-gram上下文)——能否推广到trigram、4-gram子网络?它们如何叠加?
  • 未连接到induction head等已知电路——bigram子网络与induction head的关系是什么?

相关工作与启发

  • vs Voita et al. (2024):他们发现跨所有层分布的bigram促进神经元。本文发现bigram子网络集中在第一层——说明不同层的bigram相关活动可能有不同功能
  • vs 传统电路发现(IOI等):IOI电路针对特定任务,bigram子网络覆盖整个输入空间,是更基础的电路

评分

  • 新颖性: ⭐⭐⭐⭐ bigram子网络的概念新颖,~10M普适常数是有趣发现
  • 实验充分度: ⭐⭐⭐⭐⭐ 8个模型、多检查点、消融、剪枝overlap、残差流分析,非常细致
  • 写作质量: ⭐⭐⭐⭐⭐ 层层递进、图表丰富、论点清晰
  • 价值: ⭐⭐⭐⭐ 为机制可解释性提供了"最小电路"的基础,启发后续电路叠加研究