跳转至

In-Place Test-Time Training

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=dTWfCLSoyl
代码: 无
领域: LLM效率 / 长上下文 / 测试时训练
关键词: 测试时训练, 快权重, 长上下文, MLP 复用, Next-Token 对齐目标

一句话总结

本文把 Transformer 里 MLP 块的下投影矩阵 \(W_{down}\) 当作可在推理时更新的「快权重」,配上一个对齐 Next-Token Prediction 的训练目标和分块更新机制,让现成的预训练 LLM 不改架构、不从头训就「即插即用」获得测试时训练(TTT)能力,在 128k 乃至 256k 长上下文上稳定超过原模型与 GLA / DeltaNet / LaCT 等竞品。

研究背景与动机

领域现状:当下 LLM 都是「先训练后部署」的静态范式——权重在海量语料上训好后就冻结,推理时一字不改。为了让模型处理超长、不断演进的任务,主流有两条路:一是靠 in-context learning 把历史 token 全塞进上下文窗口,但受限于注意力的二次复杂度;二是测试时训练(Test-Time Training, TTT),它引入一小撮「快权重」(fast weights),在推理时对每个新输入做一步梯度下降更新,把上下文信息在线压缩进这个动态记忆里。

现有痛点:TTT 概念上很美,但在 LLM 生态里落地有三道坎。其一,架构不兼容:现有 TTT 通常是替换注意力的专用循环层,随机初始化的新层和数十亿训练好的参数冲突,几乎必须从头重训,对大模型而言代价高到不现实。其二,计算低效:经典 TTT 是逐 token 串行更新,严重浪费 GPU/TPU 并行能力;即便有分块加速,TTT 作为主 token mixer 又被迫用小 chunk 来保性能,依然喂不饱现代加速器。其三,目标错位:TTT 普遍用通用的重建(reconstruction)目标,让快权重去关联同一个 token 的 \((k,v)\),本质是「记住当前 token」,和语言模型真正关心的「预测下一个 token」并不对齐。

核心矛盾:TTT 想替注意力的「野心」恰恰是它落地难的根源——一旦定位成「取代注意力的核心 token mixer」,就同时背上了从头重训、严格逐 token 因果、小 chunk 的三重包袱。

本文目标:在不动注意力、不从头训的前提下,给现成 LLM 装上 TTT 能力,同时解决效率与目标对齐。

切入角度:作者的关键观察是——快权重的选择没有任何约束,任意参数都能当快权重;而 Transformer 里的 MLP 块本身就可以看作一种 key-value 记忆,存的是预训练学到的「慢权重」通用知识。那么自然可以让同一个 MLP 兼职快权重,在推理时动态吸收上下文。

核心 idea:把 MLP 的下投影矩阵原地(in-place)当快权重更新,用一个对齐 NTP 的目标 + 大 chunk 并行更新,让 TTT 从「破坏性重构」变成「即插即用」的轻量增强。

方法详解

整体框架

In-Place TTT 的整体思路是:不新增层、不换注意力,而是把每个 Transformer 块里那个无处不在的门控 MLP 拿来「一物两用」——它的输入投影 \(W_{up}, W_{gate}\) 保持冻结、继续当存通用知识的慢权重;它的下投影 \(W_{down}\) 则被释放成快权重,在推理时随上下文分块原地更新。整条数据流是一个严格因果的「apply-then-update」循环:序列先切成若干 chunk,对每个 chunk 先用当前快权重把中间激活投影出输出,再用这个 chunk 的激活做 key、用一个含未来 token 信息的 target 做 value,做一步梯度下降把快权重推进到下一个状态,交给下一个 chunk。注意力层完全不变,TTT 模块与注意力互补而非替代——这正是它能用大 chunk、能即插即用的根本原因。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["输入序列<br/>切成 chunk"] --> B["MLP 复用为快权重<br/>冻结 W_up/W_gate,更新 W_down"]
    B --> C["Apply:用当前 W_down<br/>投影激活 Z 得输出 O"]
    C --> D["LM 对齐目标<br/>Conv1D + 投影造 future-token value"]
    D --> E["Update:一步梯度<br/>W_down 推进到下一 chunk 状态"]
    E -->|前缀和并行 / Context Parallel| F["逐 chunk 因果输出<br/>长上下文增强"]
    E -.下一 chunk.-> C

关键设计

1. 原地复用 MLP 下投影矩阵作快权重:不改架构就能 TTT

针对「架构不兼容、必须从头重训」这道坎,作者拒绝再造一个随机初始化的专用 TTT 层,而是直接复用现成的门控 MLP。门控 MLP 的输出是 \(O = \big(\phi(HW_{gate}^\top) \odot (HW_{up}^\top)\big)W_{down}^\top\),本文把 \(W_{up}, W_{gate}\) 当作冻结的慢权重(保留预训练知识),唯独把最后那个下投影 \(W_{down}\) 当作可在推理时原地更新的快权重。这样做之所以成立,是因为 TTT 形式上对「谁来当快权重」没有任何限制,而 MLP 本身就被论证为一种 key-value 记忆,让它额外兼任动态记忆是自然延伸。好处是彻底「drop-in」:模型结构、预训练权重的完整性都不动,只需相对便宜的继续训练就能给现成 LLM 装上 TTT,避开了从头预训练的天价成本。

2. 大 chunk 分块更新:摆脱逐 token 串行,喂饱加速器

针对「逐 token 串行、效率低」这道坎,本文用分块更新替代逐 token 更新。把中间激活 \(Z = \phi(HW_{gate}^\top)\odot(HW_{up}^\top)\) 和对应的 value/output 切成大小为 \(C\) 的不重叠 chunk,对每个 chunk 先 Apply(\(O_{[i]} = Z_{[i]}(W_{down}^{(i)})^\top\))再 Update(一步梯度把 \(W_{down}^{(i)}\) 更新到 \(W_{down}^{(i+1)}\))。关键在于:正因为只更新 MLP、注意力层原封不动,TTT 不再背负「严格逐 token 因果」和「小 chunk 才保性能」的包袱——注意力已经负责细粒度 token mixing,TTT 只做互补的上下文压缩,于是可以放心用 512~1024 这样的大 chunk 成块处理,把并行度拉满。消融也证实它天然适配大 chunk,最优 chunk 在 512~1024。

3. LM 对齐目标:让快权重存「对预测下一个 token 有用」的信息

针对「重建目标与语言建模错位」这道坎,本文把 value target 从「当前 token 自身」换成「含未来 token 信息」的目标。具体地,target 取 \(\hat{V} = \mathrm{Conv1D}(X_0)W_{target}\),其中 \(X_0\) 是 token embedding,1D 卷积负责把邻近未来 token 的信息按可学习权重聚合进来,\(W_{target}\) 是可训练投影;把卷积核设成只取下一个 token、\(W_{target}\) 设成单位阵,就退化成标准的 Next-Token target,而一般情况下它学到的是一个局部未来 token 的组合,与先进 LLM 里的 Multi-Token Prediction 思路一致。损失用最简单的相似度 \(L(\cdot,\cdot) = -\langle\cdot,\cdot\rangle_F\),于是分块下的快权重更新有干净的闭式:\(W_{down}^{(i)} = W_{down}^{(i-1)} + \eta\,\hat{V}_{[i]}^\top Z_{[i]}\)。作者还在 induction head 设定下给了理论保证(Theorem 1):用 LM 对齐 target 更新一步后,正确下一个 token \(v^*\) 的 logit 期望增加(下界 \(\lambda_{lr}c_{norm}^2 c_{align}\)),其它 token 几乎不变;而重建 target 对正确 token 的 logit 提升可忽略——直观说明对齐目标真的把「预测性有用」的信息压进了快权重。

4. Context-Parallel 因果实现:并行扫描下严格等价于串行

为了在长序列上既快又不破坏因果性,本文让更新规则适配 Context Parallelism。由于更新 \(\Delta W_{down}^{(i)} = \hat{V}_{[i]}^\top Z_{[i]}\) 满足结合律,可以分三步并行:(i) 所有 chunk 并行算各自的激活和增量;(ii) 对增量序列做一次前缀和(prefix sum / parallel scan)得到每个 chunk 的累计更新 \(\Delta S_i\);(iii) 用 \(W_{down}^{(i-1)} = W_{down}^{(0)} + \eta\Delta S_i\) 并行算各 chunk 输出。再配合对 1D 卷积做因果 padding(保证某 chunk 的增量不含自身的未来信息)、以及在文档边界把快权重重置回预训练状态(防跨序列泄漏),这套并行扫描在数学上严格等价于逐步串行更新。结果是一个 CP-native、完全因果、可直接替换标准 MLP 块的模块。

损失函数 / 训练策略

快权重的内层目标是相似度损失 \(L = -\langle\cdot,\cdot\rangle_F\),对应更新式 \(W_{down}^{(i)} = W_{down}^{(i-1)} + \eta\hat{V}_{[i]}^\top Z_{[i]}\)。外层训练上:drop-in 实验对 Qwen3-4B-Base 做两阶段继续训练(先 ∼20B token / 32k 上下文,再 ∼15B token / 128k 上下文),并用 YaRN 扩展 RoPE;from-scratch 实验在 32k(500M/1.5B)或 8k、120B token(4B)上从零训练。

实验关键数据

主实验

Qwen3-4B-Base 上即插即用 In-Place TTT,在 RULER 长上下文 benchmark 上随着上下文变长优势越来越大,并能外推到 256k:

上下文 Baseline In-Place TTT 提升
16k 92.1 92.7 +0.6
32k 88.7 89.3 +0.6
64k 74.3 78.7 +4.4
128k 74.8 77.0 +2.2
256k(外推) 41.7 43.9 +2.2

跨模型族同样有效(RULER,64k 处增益最明显):

模型 方法 32k 64k
LLaMA-3.1-8B Baseline 91.1 81.6
LLaMA-3.1-8B In-Place TTT 91.7 83.7(+2.1)
Qwen3-14B Baseline 90.7 67.9
Qwen3-14B In-Place TTT 91.2 70.6(+2.7)

从头训练对比(4B,共识推理 + 长上下文):In-Place TTT 在 Full Attention 与 SWA 两种 backbone 下都全面提升,长上下文增益尤其大:

架构 配置 MMLU RULER-8k RULER-16k
Full Attn. Baseline 36.43 38.09 6.58
Full Attn. In-Place TTT 37.42 43.82 19.99
SWA Baseline 36.06 9.91 5.07
SWA In-Place TTT 36.48 26.80 7.57

在 500M 和 1.5B 规模上,In-Place TTT 的 Sliding Window Perplexity 在 2k~32k 全程低于 SWA / GLA / DeltaNet / LaCT 所有竞品,且困惑度随上下文延长持续下降。

消融实验

配置 关键指标 说明
State size 4× vs 1× vs 0.5× RULER ↑ 状态越大性能越好
Chunk size C=256/512/1024/2048 512~1024 最优 太小太大都掉点,呈 trade-off
w Conv, Proj(完整) 最佳 LM 对齐目标完整版
w/o Conv 掉点 去掉卷积(未来 token 聚合)
w/o Proj 掉点 去掉可训练投影 \(W_{target}\)
w/o Conv, Proj(退化成重建) 最差 退回通用重建目标

关键发现

  • LM 对齐目标里卷积和投影缺一不可:把两者都去掉退化成重建目标时长上下文得分明显下降,印证了「目标对齐」而非「目标存在」才是关键。
  • 大 chunk 不仅不掉点反而更优(512~1024),这与传统 TTT「必须小 chunk 才保性能」截然相反,根源是注意力承担了细粒度 mixing、TTT 只做互补压缩。
  • 增益随上下文长度单调放大:短上下文几乎持平,64k/128k 才显著拉开差距,说明它真正改善的是长程上下文利用,而非短文本能力。

亮点与洞察

  • 「快权重无约束」这一句话被用到极致:既然任意参数都能当快权重,那就别造新层,直接征用已经训好的 MLP 下投影——一招同时绕开了架构不兼容和从头重训两个最硬的坎,是非常漂亮的「免费午餐」式洞察。
  • 把效率枷锁的根源归到「定位」上:作者点破 TTT 低效的真因不是算法本身,而是「想取代注意力」这个野心强加的逐 token + 小 chunk 约束;改成「互补注意力」后,大 chunk 并行水到渠成。这个 reframing 可迁移到其它「想替换核心组件」的工作。
  • 目标对齐有理论背书:用 induction head 给出 logit 单调上升的下界,把「重建 vs NTP 对齐」从经验直觉抬到可证明,且实践中自然过渡到 Multi-Token Prediction。

局限与展望

  • 论文主要用语言建模/困惑度和 RULER 作为「长程演进任务」的代理,真实的持续学习 / 流式经验学习场景没有直接评测,离「像人一样从无界经验流中学习」的愿景仍有距离。
  • 损失函数和优化器只用了最简单的相似度 + 一步梯度,作者自承核心框架与具体 loss/优化器正交,更强的内层优化器(如带动量、更复杂的记忆参数化)留作未来工作,当前版本未必榨干性能。
  • 快权重只放在 MLP 下投影一处,是否该在更多矩阵 / 更多层上展开、状态规模与计算成本如何权衡,仍是开放问题(消融显示 state 越大越好,但成本边界没充分讨论)。
  • drop-in 仍需 ∼35B token 的继续训练,并非真正「零成本」插上即用。

相关工作与启发

  • vs 经典 TTT(Sun et al. 2020/2024):经典 TTT 用专用层替换注意力、逐 token 更新、重建目标;本文复用 MLP 不动注意力、大 chunk 并行、NTP 对齐目标——三处都反着来,换来即插即用与高吞吐。
  • vs LaCT(Large Chunk TTT):两者都走大 chunk,但 LaCT 仍作为独立 TTT 层建在 SWA 上;本文是原地复用 MLP,且在同样 SWA backbone 下困惑度更低。
  • vs 线性注意力类(GLA / DeltaNet):它们是 sub-quadratic 的注意力替代品做 token mixing;本文不替代注意力,而是给标准 Transformer 的 MLP 加在线更新,在 500M/1.5B 困惑度上全面更优。
  • vs YaRN / RoPE 外推:YaRN 改的是位置编码,本文改的是权重动态性,二者正交可叠加(实验中 64k+YaRN 仍有增益)。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 「复用 MLP 下投影当快权重」+「NTP 对齐目标」是干净且少见的组合,把 TTT 从重构性改造变成即插即用。
  • 实验充分度: ⭐⭐⭐⭐ drop-in(4B~14B 三个模型)+ from-scratch(500M~4B)+ 多消融,但缺真实持续学习场景与更大规模验证。
  • 写作质量: ⭐⭐⭐⭐⭐ 三道坎→三处设计→理论保证的逻辑链非常清晰,desiderata 框架尤其好读。
  • 价值: ⭐⭐⭐⭐⭐ 让现成 LLM 低成本获得长上下文/在线适应能力,对落地非常实用,且指向 LLM 持续学习这一大方向。