Short Window Attention Enables Long-Term Memorization¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=btgVfhudI1
代码: 待确认
领域: LLM效率 / 长上下文 / 混合架构
关键词: 滑动窗口注意力, xLSTM, 线性 RNN, 混合架构, 长上下文记忆
一句话总结¶
本文用「滑动窗口注意力 + xLSTM 线性 RNN」交替的混合架构 SWAX 研究短/长程记忆的分工,发现一个反直觉结论——滑动窗口越短,长上下文检索反而越好(因为短窗口逼着线性 RNN 去学长程依赖),并据此提出随机窗口训练(每个 batch 随机用 128 或 2048 的窗口),让模型在短上下文和长上下文任务上同时拿到最优。
研究背景与动机¶
领域现状:现代 LLM 靠 softmax 注意力的 KV Cache 当工作记忆,长上下文性能很强,但 KV Cache 随序列长度线性膨胀,算力和显存都失控。另一条路是线性 RNN(SSM、线性注意力、xLSTM 等),用一个固定大小的隐状态迭代更新,每个 token 的算力和显存恒定、与序列长度无关,但召回(recall)精度一直比不上 Transformer。近期的主流折中是把两者拼成混合架构:大部分层用固定状态大小的组件(滑动窗口注意力 SWA 或线性注意力),少数层保留全局 softmax 注意力。
现有痛点:保留全局注意力层的混合架构,仍然带着那少数层 \(O(S)\) 的状态和 FLOPs 增长;而纯粹由「固定状态」组件构成的混合架构(如 De et al. 2024 把线性注意力和 SWA 拼起来)虽然算力恒定,却有一个被忽视的问题——滑动窗口长度怎么选。已有工作只用验证困惑度(PPL)来挑窗口大小,得出「窗口越长越好,选窗口纯粹是性能 vs 算力的取舍」的结论,但从没考察过窗口长度对长上下文检索能力的影响。
核心矛盾:在 SWA + 线性 RNN 的混合架构里,两类层的「记忆分工」是隐式形成的。如果 SWA 窗口足够长,训练时绝大多数依赖都落在窗口内,模型会偷懒——优先用更精确的局部 softmax 注意力,而很少训练线性 RNN 去建模长程依赖。结果是:PPL 和短上下文任务看起来都不错,但一旦测试序列超出窗口长度,模型「从没学过靠线性 RNN 做长程检索」,长上下文性能直接崩。
本文目标:① 系统刻画滑动窗口长度对短/长上下文任务的真实影响;② 找到一种训练方式,既保住长窗口带来的短上下文精度,又拿到短窗口带来的长上下文外推能力。
核心 idea:用短窗口(甚至随机切换的窗口)当一种「正则」,逼迫线性 RNN 层接收更多长程依赖的监督信号、专心建模长期记忆,而不是把活全包给局部 softmax 注意力。
方法详解¶
整体框架¶
本文的研究对象是 SWAX——一种交替堆叠滑动窗口注意力(SWA)层和 xLSTM(mLSTM 矩阵记忆)层的混合架构,两类层按 1:1 比例交替。SWA 层是固定窗口 \(w\) 的 softmax 注意力,负责高精度建模最近 \(w\) 个 token 的局部依赖;xLSTM 层是固定状态大小的线性 RNN,靠一个矩阵记忆 \(H_t\) 迭代更新,负责承载无限感受野的长程依赖。两者都是「固定状态、固定每 token 算力」的组件,所以整个模型的算力和显存不随序列长度增长。
围绕这个骨干,本文真正的贡献是怎么训练它:先揭示「窗口越短、长上下文越好」这个反直觉规律,再用随机窗口训练把短/长窗口的好处合二为一——训练时每个 batch 随机决定用短窗口(128)还是长窗口(2048),并在训练末段退火(停止采样短窗口)以恢复长窗口的短上下文精度;测试时统一用长窗口 2048。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["输入序列<br/>(训练 16k, 测试外推到 131k)"] --> B["SWAX 混合骨干<br/>SWA 层 ⇄ xLSTM 层 1:1 交替"]
B --> C["短窗口逼出长程记忆<br/>窗口越小, xLSTM 越被迫学长程"]
C --> D["随机窗口训练<br/>每 batch 采样 128/2048 + 末段退火"]
D --> E["输出:短上下文 & 长上下文兼优"]
关键设计¶
1. SWAX 混合骨干:用「局部 softmax + 全局线性记忆」做记忆分工
针对「纯线性 RNN 召回差、纯 softmax 注意力算力爆炸」的两难,SWAX 把两类都只有固定状态大小的组件交替堆叠:SWA 层是窗口为 \(w\) 的 softmax 注意力,把每 token 的复杂度从 \(O(S)\) 降到 \(O(w)\);xLSTM(mLSTM 单元)维护矩阵记忆,读写规则可写成线性注意力的形式——记忆 \(H_t = \sum_t \phi(k_t) v_t^\top\) 增量更新,读出 \(y_t = \phi(q_t)^\top H_t\),算力恒定 \(O(d_{qk}\times d_v)\)、与序列长度无关。本文选 xLSTM 是因为它已扩到 7B、有高效 Triton kernel,且在语言任务上 mLSTM 优于 sLSTM。
这个分工之所以有效:大多数预测下一个 token 所需的信息来自局部依赖,纯线性模型不得不把大部分层都耗在建模局部依赖上、留给长程的层很少;而在 SWAX 里,局部依赖被路由给更精确的 SWA 层,于是 xLSTM 层得以专门化去建模长程依赖。实验也复现了 De et al. (2024) 的反直觉现象:混合架构虽然全局感受野的层更少,长上下文召回却优于纯 SWA 或纯 xLSTM。
2. 短窗口逼出长程记忆:窗口长度决定线性层收到多少长程监督
这是本文的核心发现,直指痛点——窗口太长会让 xLSTM 层在训练时「吃不到」长程任务的监督。作者用 \(\{128,256,512,1024,2048\}\) 五种固定窗口训练 SWAX,发现一个被以往「窗口越长越好」结论完全掩盖的规律:用 PPL 和短上下文 reasoning 评估,窗口 2048 确实最好;但一测长上下文检索(RULER NIAH),窗口 2048 反而掉得最狠。在 131k 序列长度上,短窗口(128/256/512)的 NIAH 召回还有约 30%,而窗口 2048 的召回接近 0%。平均到所有序列长度和 NIAH 子任务,窗口 128 反而是所有窗口里最好的,比 2048 高 16 个准确率点(相对提升 88.9%)。
机理是:窗口为 2048 时,训练里绝大多数依赖都落在窗口内,模型选择用更精确的局部 softmax 而非线性层去建模这些依赖,于是从没学会靠 xLSTM 做长程检索,一旦测试序列让依赖跑出窗口就无法外推;反之窗口短时,很多依赖落在窗口外,模型被迫让 xLSTM 去传递信息,长程记忆因此被训练出来。这也纠正了一个常见误解——短窗口不只是为了省算力或提硬件利用率,它本身就是让线性层承担长程建模的关键监督来源。
3. 随机窗口训练 + 末段退火:短长窗口的好处兼得
短窗口虽利于长上下文,但对短上下文 reasoning 有害——窗口 128 连不少短任务的 prompt 都装不下,短上下文得分明显偏低;而要在测试时用大窗口(拿短上下文精度),又必须在训练时见过大窗口(否则朴素放大窗口会因 RoPE 而灾难性崩溃)。本文用随机窗口训练化解这个两难:训练时每个新 batch 以概率 \(p\) 把窗口设为 128、否则用默认的 2048(1.4B 用 \(p=0.5\),7B 用 \(p=0.75\)),逼模型既不过度依赖长 SWA 窗口、又保留用大窗口的能力。再加一个退火——训练最后 10% 不再采样短窗口、固定用 2048,显著提升短上下文性能而不损害长上下文。测试统一用窗口 2048。
作者把这种「随机降低窗口容量、增强鲁棒性」类比为对注意力机制做 dropout。效果上,随机窗口在短上下文上与固定 2048 相当甚至更好,在长上下文上与固定短窗口 128 相当甚至更好,真正做到了「鱼与熊掌兼得」。这也反证了长窗口混合架构长上下文差不是测试时大窗口的锅,而是训练流程的锅——只要训练时被迫间歇性用短窗口,线性层就会被调动起来。
实验关键数据¶
实验聚焦语言建模,主力用 1.4B 模型(24 层、维度 2048)、并在 7B(32 层、维度 4096)上验证;全部从头在 16k 序列长度上训练 150B token,不做任何长上下文专门微调。长上下文用 RULER 的 needle-in-a-haystack(NIAH)评测,短上下文用一组 reasoning / commonsense / 代码 benchmark。
主实验¶
不同固定窗口的 1.4B SWAX 在「PPL / 短上下文 / 长上下文」上的对照(节选 Table 1 + Figure 5/6):
| 配置 | 验证 PPL ↓ | 短上下文均分 ↑ | 长上下文 NIAH(@131k) |
|---|---|---|---|
| xLSTM(纯线性) | 2.602 | 38.93 | 较弱 |
| SWAX:128 | 2.551 | 39.81 | 约 30%(最佳) |
| SWAX:512 | 2.546 | 40.69 | 较好 |
| SWAX:2048 | 2.523 | 40.88 | 接近 0%(最差) |
可以看到一个清晰的矛盾:PPL 和短上下文均分都是窗口越长越好(2048 最优),但长上下文召回完全相反——窗口 128 在 131k 上还有约 30%,窗口 2048 几乎归零,平均 NIAH 上 128 比 2048 高 16 个点。
随机窗口训练把两端拉齐(节选 Table 2):
| 模型 | 训练窗口 | 测试窗口 | 短上下文均分 ↑ | 长上下文 |
|---|---|---|---|---|
| SWAX 1.4B | 128 | 128 | 39.81 | 好 |
| SWAX 1.4B | stochastic | 2048 | 40.81 | 与 128 持平 |
| SWAX 1.4B | 2048 | 2048 | 40.88 | 差 |
| SWAX 7B | stochastic | 2048 | 49.52 | 优于固定 2048 |
| SWAX 7B | 2048 | 2048 | 49.32 | 差 |
随机窗口的验证 PPL(1.4B 为 2.502)甚至低于所有固定窗口,短上下文与固定 2048 相当或更好,长上下文则与固定 128 相当——1.4B 和 7B 上都成立。
消融实验¶
| 配置 | 关键现象 | 说明 |
|---|---|---|
| 训练窗口 vs 测试窗口(Figure 6) | 测试窗口 > 训练窗口 → 灾难性崩溃 | RoPE 下朴素放大窗口不可行,必须训练时见过大窗口 |
| 退火(末段 10% 固定 2048) | 短上下文显著提升、长上下文不掉 | 让模型学会用大测试窗口 |
| 采样概率 \(p\)(附录 B) | 1.4B 用 0.5、7B 用 0.75 | 控制短窗口被采样的比例 |
| 换 Gated DeltaNet 当线性层(Table 3) | 同样规律成立 | 结论不绑定 xLSTM |
| local-global(SWA + 全注意力交替,Figure 9) | 短层应小才不拖累长层 | 结论可推广到另一类混合架构 |
关键发现¶
- 窗口长度是「长程监督量」的开关:窗口越短,越多依赖落在窗口外,xLSTM 被迫学长程,长上下文外推越好;这与 PPL/短上下文给出的「窗口越长越好」完全相反。
- 随机窗口 ≈ 注意力 dropout:间歇性强制用短窗口防止过度依赖 SWA,等价于随机削减模型容量来提升鲁棒性。
- 结论可迁移:换成 Gated DeltaNet 线性层、换成 local-global(SWA+全注意力)架构,「短窗口/随机窗口更利于长上下文」的结论都成立。
- 在更真实的 LongBench / LongBench2 / Babilong 上,随机训练多数场景胜出,但个别场景固定 2048 更好——对这么小的模型这些任务整体偏难。
亮点与洞察¶
- 把「窗口大小」从单纯的算力旋钮重新定义为「记忆分工旋钮」:以前大家以为短窗口只是省算力/提硬件利用率,本文证明它实质上决定了线性层能收到多少长程监督——这是一个观念层面的纠偏。
- 反直觉但机理清晰:PPL 越好 ≠ 长上下文越好,因为 PPL 测的是局部预测,长程外推靠的是被「逼出来」的线性记忆;这提醒社区别再只用 PPL 挑架构超参。
- 随机窗口 = 一行训练 trick:不改架构、不加参数、不加测试开销,只在训练时随机切窗口 + 末段退火,就同时拿下短/长上下文,工程上极易复用到任意「SWA + 线性/全局」的混合架构。
局限与展望¶
- 主要在 1.4B / 7B、150B token、16k 训练长度的设定下验证,更大规模、更长训练序列下规律是否依旧需进一步确认。
- 在真实长上下文 benchmark(LongBench/LongBench2/Babilong)上随机训练并非全面占优,部分任务固定长窗口更好,说明合成的 NIAH 与真实长文任务存在 gap。
- 随机窗口的概率 \(p\)、退火比例、短/长窗口的具体取值(128/2048)都是经验选的,缺少一套自适应或可解释的选择准则。
- 结论建立在「1:1 交替、inter-layer 混合」这一简单设计上,其他混合比例 / intra-layer 混合方式是否同样适用未充分探索。
相关工作与启发¶
- vs De et al. (2024):他们同样混合线性注意力与 SWA,但只用 PPL 挑窗口、得出「窗口越长越好」;本文指出他们漏掉了长上下文这一维度,并给出相反结论——长上下文上短窗口更好。
- vs 保留全局注意力的混合架构(如 local-global):那类架构仍带着全局层 \(O(S)\) 的状态/算力增长;SWAX 全部用固定状态组件,每 token 算力恒定,本文还把短窗口结论推广验证到了 local-global 架构。
- vs Memory Mosaic(Zhang & Bottou, 2025):他们也用随机注意力掩码,但作用在长期记忆层上;本文把随机窗口显式作用在 SWA 层,目的就是削弱模型对 SWA 做长程召回的过度依赖,动机更对症。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 把窗口大小重新诠释为记忆分工旋钮,并给出与主流相反的清晰结论
- 实验充分度: ⭐⭐⭐⭐ 覆盖多窗口、1.4B/7B、多 benchmark 及两类线性层/两类混合架构,但更大规模与真实长文仍有 gap
- 写作质量: ⭐⭐⭐⭐⭐ 机理叙述清楚,反直觉发现层层递进、论证扎实
- 价值: ⭐⭐⭐⭐⭐ 一行训练 trick 即插即用,对长上下文高效架构的设计有直接指导意义