Learning to Recall with Transformers Beyond Orthogonal Embeddings¶
会议: ICLR 2026
arXiv: 2603.15923
代码: 无
领域: Transformer 理论 / 优化理论
关键词: Transformer, 记忆与检索, 存储容量, 非正交嵌入, 梯度下降分析
一句话总结¶
在随机(非正交)嵌入条件下分析单层 Transformer 在 token 检索任务上经验梯度下降的"早期阶段",推导出模型存储容量的显式公式,揭示了样本量 N、嵌入维度 d 和序列长度 L 之间的乘法依赖关系,并证明这一缩放关系是信息论下界固有的。
研究背景与动机¶
大型语言模型(LLM)在事实回忆、问答这类需要存储并检索知识的任务上表现出色,而 Transformer 能在训练时把信息编码进参数、推理时再检索出来,正是这一能力的核心。理解它"怎么学会记忆与检索"也因此成为深度学习理论的重要问题。
但现有理论分析大多建立在两条与现实脱节的理想假设上。其一是无限数据:分析在总体梯度(population gradient)下进行,忽略了有限样本带来的统计涨落。其二是正交嵌入:假设 token 嵌入向量两两正交——这只在嵌入维度 \(d\) 远大于词汇表大小 \(V\) 时才近似成立,而实际模型里 \(d<V\),嵌入是随机的非正交向量。更微妙的是,已有工作(NLB25)指出严格正交的嵌入其实并非容量最优,反倒是随机(非正交)嵌入靠"叠加"(superposition)能逼近最优的事实存储容量。
问题在于,一旦放弃正交假设,随机嵌入会在 token 之间引入干扰(interference),让优化轨迹变复杂、并从根本上改变存储容量随各因素的缩放方式。本文的目标,就是在"有限样本 + 非正交随机嵌入 + 经验梯度下降"这一更贴近真实训练的设定下,精确刻画单层 Transformer 学习一个事实检索任务时的优化复杂度与样本复杂度。
方法详解¶
整体框架¶
论文把"Transformer 如何记住并检索知识"剥成一个最小但有代表性的合成任务:给定一条长度为 \(L\) 的序列,里面恰好有一个位置 \(\ell\) 是"信息 token"、其余都是噪声 token,标签由一个固定的置换矩阵 \(\Pi^*\) 作用在信息 token 上得到(\(p=\Pi^* x_\ell\))。模型要做两件事——先靠注意力从 \(L\) 个候选里定位 \(\ell\),再学会从该 token 到标签的一一映射。承担这两步的是一个单层 Transformer:自注意力头(参数 \(W_{KQ}\),借助标记信息 token 的触发向量 \(z_{\text{trig}}\) 和序列末尾向量 \(z_{\text{EOS}}\))负责定位,其后的值矩阵 \(V\)(可选再叠一层宽度为 \(m\)、保持随机初始化不训练的 MLP)按联想记忆的方式输出标签;嵌入维度 \(d<V\),嵌入是随机非正交向量。
为了能把学习过程算清楚,论文不去证明全局收敛,而是只跟踪梯度下降最开始的几步——一个三步训练算法:从零初始化出发,先更新一次值矩阵 \(V\)、再更新一次注意力 \(W_{KQ}\)、最后再精修一次 \(V\),每一步都用 \(N\) 个有限样本的经验梯度。整套分析的落点,是从这几步的演化里读出"成功学习"对词汇量 \(V\)、样本量 \(N\)、嵌入维度 \(d\)、序列长度 \(L\)、MLP 宽度 \(m\) 的联合要求,并用一个统计下界证明这个要求无法被任何只访问初始梯度信息的算法绕过。
关键设计¶
1. token 检索任务与单层架构:把事实回忆抽象成"定位 + 联想"两步
直接分析真实 LLM 的记忆行为无从下手,论文先把它压到只剩最本质的两步。任务里,一条长度 \(L\) 的序列只有位置 \(\ell\) 藏着信息 token,标签是它经过固定置换 \(\Pi^*\) 后的结果,其余 token 纯属干扰。模型对应地分两段:自注意力头用触发向量 \(z_{\text{trig}}\) 标记信息 token、用 \(z_{\text{EOS}}\) 作查询,softmax 注意力把权重集中到 \(\ell\) 上完成上下文定位;随后值矩阵 \(V\)(Attention-only 模型)或"\(V\) + 固定随机 MLP"(Attention-MLP 模型,用宽度 \(m\) 换取在小 \(d\) 下仍有大容量)把选中的 token 映射到正确标签,完成内容输出。这两步正是 LLM 做事实检索的核心计算结构,所以设定虽简单、结论却能对真实模型说话;同时任务被剥到只剩这两步,后续每个统计量的演化才能被逐项精确追踪。关键的现实性来自 \(d<V\) 且嵌入随机非正交——这正是经典正交分析回避、而本文要正面处理的难点。
2. 三步早期阶段训练算法:只算决定成败的前几步
非正交嵌入会让完整训练轨迹出现震荡等复杂行为,难以全局求解。论文沿用 ORST23 的思路,只刻画从初始化出发的"早期阶段",把训练压成三步梯度更新:参数初始化为 \(V^{(0)}=0,\ W_{KQ}^{(0)}=0\),然后
其中 \(\hat{L}\) 是 \(N\) 个样本上的经验交叉熵。之所以盯住前几步,是因为这类任务里信号方向(注意力是否选中正确 token、值矩阵是否指向正确标签)几乎在最初就被定下,微弱信号会在后续被持续放大;技术上则用高维概率的集中不等式,把有限样本经验梯度对总体梯度的偏离控制住,从而把"能否学成"翻译成这几步里关键统计量必须满足的显式条件。先更新 \(V\) 再更新 \(W_{KQ}\) 最后回头精修 \(V\) 的顺序也有讲究:先让值矩阵建立粗略的 token→标签映射,注意力才有可用的梯度信号去对齐到信息 token。
3. 乘法型存储容量公式:\((V,N,d,L,m)\) 的耦合与相图
把早期阶段的成功条件解开,核心结论是一条乘法形式的缩放关系:能否成功检索取决于 \((V,N,d,L,m)\) 五个量如何相乘耦合,而非各自独立的阈值。方向上,\(N\)、\(d\)、\(m\) 越大越容易学(更多数据、更高维因而更接近正交的嵌入、更宽的 MLP 带来更大容量),\(V\)、\(L\) 越大越难(词汇更多、要从更长序列的更多干扰项里挑出信息 token)。论文进一步给出相图(phase diagram),把所需参数规模 \(m\cdot d\) 的条件按主导噪声项分区——均值偏置(mean bias)、梯度噪声(gradient noise)、MLP 噪声各自在不同区域占主导,对应不同的 \(md\gtrsim\dots\) 门槛。三者之所以乘在一起而不能分开看,根子就在非正交:随机嵌入的 token 间干扰把数据、维度、上下文长度的效应纠缠到同一个量上。由此还导出一个权衡——减小 \(d\) 会增强叠加、提升存储容量,却同时让学习问题变难(需要更大的 \(N\)),短序列下容量与样本复杂度可兼得,长序列下则必须在"增大 \(d\) 牺牲容量"和"增大 \(N\) 恶化样本复杂度"之间二选一。
4. 匹配的信息论下界:乘法瓶颈是问题固有而非算法所限
仅有算法侧(Transformer + 三步梯度下降)的上界还不足以下定论,论文又从统计角度给出该问题的固有难度下界:对任何只访问初始化 Transformer 梯度信息的估计量,这条乘法型权衡都成立,且下界与上界同阶。结论是,\((V,N,d,L,m)\) 的乘法缩放不是这套训练算法没设计好,而是非正交嵌入下任务的内在性质——换架构或换优化器都绕不开由干扰本身设下的瓶颈。这也反过来说明,过去在正交假设下推出的容量估计是系统性偏乐观的。
实验关键数据¶
主实验:存储容量缩放验证¶
论文通过数值实验验证理论预测的缩放关系:
| 维度 d | 序列长度 L | 理论预测的临界 N | 实际观测的临界 N | 匹配度 |
|---|---|---|---|---|
| 小 d | 小 L | 较低 | 与理论一致 | ✓ |
| 小 d | 大 L | 较高 | 与理论一致 | ✓ |
| 大 d | 小 L | 较低 | 与理论一致 | ✓ |
| 大 d | 大 L | 中等 | 与理论一致 | ✓ |
消融实验:正交 vs 非正交嵌入¶
| 嵌入类型 | 存储容量缩放 | 说明 |
|---|---|---|
| 正交嵌入 | N 与 d, L 分别独立缩放 | 经典设置,因素可分离 |
| 随机(非正交)嵌入 | N, d, L 乘法耦合 | 更现实设置,三者不可分 |
下界验证¶
| 设置 | 算法上界(Transformer+GD) | 信息论下界 | 间隙 |
|---|---|---|---|
| 非正交嵌入 | \(O(f(N,d,L))\) | \(\Omega(g(N,d,L))\) | 紧致(同阶) |
关键发现¶
- 乘法缩放是固有的:\((V,N,d,L,m)\) 的耦合关系源自非正交嵌入带来的 token 间干扰,不是算法的缺陷,且有匹配的下界佐证
- 正交假设导致过度乐观:在正交假设下推导的容量会高估真实容量
- 早期阶段是关键:三步梯度更新的最初几步就决定了注意力能否锁定正确的信息 token
- \(d\) 是把双刃剑:增大 \(d\) 让嵌入更接近正交、削弱干扰、更易学,但减小 \(d\) 反而增强叠加、提升存储容量——容量与可学性之间存在权衡
- \(V\)、\(L\) 加大任务难度:词汇量越大、序列越长,注意力要从越多干扰项里挑出信息 token,需要更多样本或更大维度来补偿
亮点与洞察¶
- 填补了理论与实践之间的关键鸿沟:放松正交嵌入和无限数据假设后的分析更贴近真实 LLM 的工作方式
- 乘法缩放关系的优雅:一个简洁的公式统一了三个看似独立的因素(数据量、维度、序列长度)
- 信息论下界的重要性:不仅说明了 Transformer 能做到什么,更说明了任何方法都不能做到什么
- 对实际 LLM 设计的暗示:在固定计算预算下,增大嵌入维度 vs 增加训练数据 vs 缩短上下文窗口之间存在最优权衡
- 将 Transformer 的"记忆能力"从经验直觉提升到精确理论
局限与展望¶
- 仅分析单层单头 Transformer:实际 LLM 是多层多头的,层间交互和多头协作可能改变容量缩放
- 早期阶段分析:未覆盖训练的全局收敛行为,后期阶段可能有不同的动力学
- Token 检索任务简化:真实 LLM 的任务远比单一 token 检索复杂,涉及组合和推理
- 随机嵌入假设:实际中嵌入是学习得到的,具有特定结构(如低秩、聚类),非均匀随机
- 未讨论位置编码的影响:位置编码会改变注意力计算中的有效嵌入结构
相关工作与启发¶
- 与 Bietti & Cabannes (2024) 的联系:后者在正交嵌入下分析了类似的检索任务,本文推广到非正交设置
- 与 Ahn et al. (2024) 的关系:后者分析了线性 Transformer 的 in-context learning,侧重不同方面
- 与联想记忆(Hopfield Networks)的类比:经典的存储容量分析(如 \(0.14N\) 模式数上界)在 Transformer 中的对应
- 对 KV Cache 设计的启发:存储容量的缩放关系暗示了 KV cache 压缩的理论极限
- 对 RAG 系统的理论支撑:检索增强生成的核心就是"在上下文中找到相关信息"
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ (非正交嵌入下的分析填补重要理论空白)
- 实验充分度: ⭐⭐⭐⭐ (数值验证充分,但限于理论设定)
- 写作质量: ⭐⭐⭐⭐ (理论严谨,清晰度良好)
- 价值: ⭐⭐⭐⭐ (对理解 Transformer 记忆能力有重要贡献)