Beyond Linearity in Attention Projections: The Case for Nonlinear Queries¶
会议: ICLR 2026 Workshop (GRaM)
arXiv: 2603.13381
代码: GitHub
领域: 其他
关键词: nonlinear query, attention projection, identity prior, bottleneck MLP, transformer architecture
一句话总结¶
基于 \(W_Q\) 代数冗余性的理论发现,将线性 Query 投影替换为非线性残差形式 \(Q(X)=(X+f_\theta(X))/2\),在不增加参数的情况下超越 +12.5% 参数的基线模型。
研究背景与动机¶
代数冗余性发现:Karbevski & Mijoski (2024) 证明 Transformer 中存在重参数化不变性——对任意可逆矩阵 \(\Theta\),将 \((X, W_Q, W_K, W_V, W_O)\) 映射为 \((X\Theta, \Theta^{-1}W_Q, ...)\) 后 MHA 输出不变。取 \(\Theta = W_Q\) 可使 \(W_Q \to I\),表明 \(W_Q\) 的线性参数完全可被相邻层吸收——代数上冗余
实验验证:设置 \(W_Q = I\) 的模型与标准基线性能相当,且在 3× 更低 weight decay 下仍稳定,证明恒等映射是 query 路径的良好先验
核心推理:既然线性参数冗余(可被吸收),若要在 query 路径有效分配参数,就必须引入非线性——非线性是不可被吸收的
信息瓶颈视角:从单个 token \(x\) 生成 q/k/v/残差四个向量全为 \(x\) 的线性函数,构成信息瓶颈。非线性 query 部分解耦了这个瓶颈
为何选择 Query:GQA 共享 \(W_K/W_V\),只有 \(W_Q\) 可安全替换而不破坏共享结构
方法详解¶
核心公式¶
\(f_\theta\) 结构¶
瓶颈 MLP:\(f_\theta(X) = \text{LN}(\text{GELU}(\text{RMSNorm}(X) \cdot W_1) \cdot W_2)\) - \(W_1 \in \mathbb{R}^{d \times r}\), \(W_2 \in \mathbb{R}^{r \times d}\), \(r = d/2\) - 矩阵参数总量 \(2dr = d^2\),与原始 \(W_Q\) 同阶 - 归一化层仅增加 \(O(d)\) 参数(<0.1%)
设计要点¶
- Identity Anchor:\(X\) 项锚定到已知良好先验(\(W_Q=I\)),保证梯度直通和训练稳定性
- 1/2 缩放因子:遵循 Karbevski & Mijoski 的建议,防止幅度膨胀
- K/V 不变:Key 和 Value 保持标准线性投影
- 兼容性:兼容 GQA/MQA、RoPE、MoE 等现代架构
实验关键数据¶
主实验(GPT-3 Small, ~124M, OpenWebText, 60k steps ≈ 29B tokens)¶
| 模型 | 非嵌入参数 | Val Loss (59k) | 相对提升 |
|---|---|---|---|
| Baseline | 85M | 2.956 | 0 |
| MLP 4.75(宽MLP, +12.5%参数) | 96M | 2.927 | 0.98% |
| MLP 4.75 (scaled LR) | 96M | 2.928 | 0.94% |
| Res. GELU (本文) | 85M | 2.919 | 1.24% |
| Res. GELU (最优超参) | 85M | 2.915 | 1.40% |
训练稳定性¶
| 配置 | 结果 | 说明 |
|---|---|---|
| Baseline, WD=0.05 | 20k步前发散 | 标准模型不稳定 |
| Res. GELU, WD=0.03, LR=3e-3 | 稳定到60k | 可承受5×更高LR |
关键发现¶
- 训练远超 Chinchilla 最优(29B tokens vs 2.5B optimal),确保改进不是因 token 不足产生的"水分"
- 所有模型在固定随机种子下看到完全相同的训练和验证数据,控制变量严格
- 非线性变体在 warmup 阶段增益最大,中期减小,末期最佳变体增益回升
- 作者明确表示 1.40% 可能是下界而非上界——超参搜索非常有限
亮点与洞察¶
- 理论驱动架构修改:从代数冗余性出发,逻辑链完整(\(W_Q\) 冗余→线性无用→引入非线性)
- 参数中性改进:不增加参数即超越 +12.5% 参数的模型——说明参数效率而非容量才是瓶颈
- 训练稳定性双赢:非线性版本不仅性能更好,还能在更激进的超参数(低 WD、高 LR)下保持稳定
- 代码和 checkpoint 完全开源
局限与展望¶
- 仅在 ~124M 单一规模验证,未测试大模型(是否在大模型上冗余性依然成立?)
- 未进行多种子实验(通过固定数据顺序和长时间训练缓解)
- 未测量推理速度——非线性引入串行依赖(瓶颈 MLP 必须先于注意力完成)
- 超参数搜索极不充分,各种归一化变体和激活函数未系统探索
- 未评估下游任务表现(仅报告预训练 val loss)
相关工作与启发¶
- Kernel Attention (Performer 等):在 \(Q=XW_Q\) 之后加非线性特征映射;本文直接替换 \(W_Q\)
- MLP-Attention (Zhang'24):用 MLP 替代 Q/K/V 所有投影,但增加 ~10% 参数且缺乏理论动机
- Nonlinear LoRA:面向微调场景;本文面向预训练架构设计
- Always Skip Attention (Ji et al.):揭示自注意力对 skip connection 的独特依赖,与本文的 identity anchor 呼应
评分¶
- 新颖性: ⭐⭐⭐⭐ 理论驱动的架构修改,方向新颖,但改动量较小
- 实验充分度: ⭐⭐⭐ 单一规模、单一数据集,但控制变量极为严谨
- 写作质量: ⭐⭐⭐⭐ Workshop 篇幅内逻辑清晰,数学简洁
- 价值: ⭐⭐⭐⭐ 如能在大规模模型上验证将有重要意义