MoM: Linear Sequence Modeling with Mixture-of-Memories¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=3PdOq8Rgue
代码: 待确认
领域: 线性序列建模 / 高效注意力架构
关键词: 线性注意力, Mixture-of-Memories, 召回密集任务, 记忆干扰, Gated DeltaNet, Test-Time Training
一句话总结¶
MoM 用一组相互独立的记忆状态 + 路由网络替换线性模型里那个唯一的固定大小记忆,让不同 token 只更新各自被分配的记忆,从而在保持线性复杂度的同时大幅扩容记忆、消除写入干扰,把召回密集任务做到逼近 Transformer。
研究背景与动机¶
- 领域现状:为了摆脱 Transformer 的 \(O(n^2)\) 复杂度,线性注意力、状态空间模型(Mamba)、线性 RNN 等方法把整条序列压进一个固定大小的矩阵记忆 \(M\),做到 \(O(n)\) 训练、\(O(1)\) 推理。它们普遍可写成递归形式 \(M_t = M_{t-1} + k_t^\top v_t,\ o_t = q_t M_t\)。
- 现有痛点:把整条序列塞进单个固定记忆带来两个硬伤——记忆容量有限和记忆干扰。新信息以累加方式覆盖旧记忆,正交或新颖的输入会污染已存内容,导致在 FDA、SWDE、SQuAD 这类召回密集任务上和 Transformer 差一大截。
- 核心矛盾:Transformer 之所以强,正是因为它给每个 token 保留独立 KV cache、几乎无干扰、容量近乎无限;而线性模型靠极致压缩省了算力,却也丢了这种"分而治之"的能力。单纯把单个 RNN state 调大(expand)治标不治本——一个膨胀的记忆仍难以同时承载多个互相正交的信息侧面。
- 本文目标:在 Transformer 的"显式 token 表示"和线性模型的"极致压缩"之间找平衡点,既要扩容、去干扰,又不能丢掉线性训练 / 常数推理的效率红利。
- 核心 idea:[受神经科学启发的多记忆架构] 借鉴海马体 theta-gamma 振荡的多项记忆编码(每个 gamma 子周期激活一组不同神经元,时间上分离记忆项以防干扰)和 MoE 的稀疏路由思想,提出 Mixture-of-Memories:维护多个独立记忆状态,路由器把每个 token 只送给 top-k 个记忆去更新,最后按重要度加权混合读出。
方法详解¶
整体框架¶
MoM 把线性层里"唯一记忆 \(M\)"换成"一组记忆 \(\{M^1,\dots,M^M\}\) + 路由器"。每个输入 token 先经路由器算出重要度分数、选出 top-k 个记忆;被选中的记忆各自用自己的 KV 投影做 RNN 式更新,未被选中的记忆原封不动(这正是去干扰的关键);读出时把激活的记忆按路由权重加权求和成"混合记忆"\(\tilde M_t\),再用共享的 query 查询。另外挂一条对全序列始终激活的共享记忆兜底长程依赖。整套机制对记忆更新规则不挑食,能即插即用各种线性模型(Linear Attn / GLA / DeltaNet / Gated DeltaNet / Mamba2 / RWKV…)。
flowchart LR
X[输入 token x_t] --> R[Router 打分 + TopK softmax]
X --> SM[共享记忆 KV Proj]
R -->|选中 top-k| KV1[KV Proj 1..N 各自投影]
KV1 --> U[各激活记忆独立更新<br/>M_t^m = update M_t-1^m, k_t^m, v_t^m]
U --> MIX[按路由权重加权混合<br/>M̃_t = Σ g_t^m M_t^m]
SM --> MIX
MIX --> Q[q_t · M̃_t]
Q --> O[归一化 + 线性变换 → 输出 o_t]
关键设计¶
1. 路由器:把 token 稀疏分派给记忆,让每个记忆只接收同类信息。 路由器是个简单线性层 \(W_g\in\mathbb{R}^{d\times M}\),对每个 token 算分数后 softmax、取 top-k 并归一化:\(\text{scores}_t=\text{TopK}(\text{softmax}(x_tW_g))\in\mathbb{R}^k,\ g_t=\text{scores}_t/\sum\text{scores}_t\)。这一步把"一个记忆扛全部信息"变成"不同记忆各管一摊",是容量扩张和干扰消除的源头——后文 UMAP 可视化显示路由确实把输入按特征聚成了若干簇,每个记忆专精一个子分布。
2. 独立记忆更新 + 未激活记忆冻结:干扰消除的机制本质。 对每个被激活的记忆 \(m\),用其专属投影 \(W_k^m,W_v^m\) 算出 \(k_t^m,v_t^m\),再做记忆更新 \(M_t^m = M_{t-1}^m + (k_t^m)^\top v_t^m\)。关键在于没被路由到的记忆这一步完全不动,于是当前 token 的新信息绝不会写进与它无关的记忆里——这与 MoE 的"专家"思想同构,但这里的"专家"不是独立网络,而是嵌在线性递归里的一个个 RNN 状态。更新规则可任意替换:论文给出一张对照表,把 RetNet 的衰减 \(\gamma M_{t-1}\)、GLA / HGRN2 的数据相关门控、DeltaNet 的 \((I-k_t^\top k_t)\) 删除项、Gated DeltaNet、Mamba2、RWKV7 等统一纳入 \(M_t = \text{(gate)}\,M_{t-1}+\text{(write)}\) 的框架,因此 MoM 是正交于这些工作的通用增强。
3. 加权混合读出 + 共享记忆:把"分散的记忆"重新聚合成可查询的整体。 更新完后用路由分数把激活记忆加权求和成混合记忆 \(\tilde M_t = \sum_m g_t^{(m)} M_t^m\),再用 query 读出 \(o_t = q_t\tilde M_t\)。值得注意的是,"先混合再乘 query" 与 "先各自乘 query 再混合" 数学等价,这给了硬件实现极大便利。同时一条始终激活的共享记忆看遍全序列,专门承接长程上下文,弥补稀疏路由可能漏掉的全局信息。
4. 硬件高效实现:把多记忆计算化简为 varlen 单核调用。 朴素实现会因为多记忆而成本翻倍。MoM 利用上述等价性,按路由结果重排 token使同一记忆的 token 连续排列,拼成变长(varlen)序列后用现成的 Triton 线性算子一次性处理,算完再按权重聚合、还原回原始顺序。形式上,对 batch \(b\)、记忆 \(m\) 收集索引集 \(I_{b,m}\)、拼成扁平序列 \(\tilde X\) 并记录累积边界 \(s\),对每段施加记忆专属核 \(F_m\),最后按 \(y_{b,t}=\sum_m \alpha_{b,t,m}\hat o_{b,t,m}\) 重建。这样 MoM 直接复用前人线性模型的高效算子,保持线性训练 / 常数推理。
实验关键数据¶
配置:以 Gated DeltaNet 作记忆更新机制,4 个记忆、每步激活 2 个、外加 1 个共享记忆;在 SlimPajama 上从零训练 380M(15B tokens)与 1.3B(100B tokens)。
主实验表格(召回密集任务,截断 2K,分数越高越好)¶
| Scale | Model | FDA | SWDE | SQuAD | NQ | TriviaQA | Drop | Avg. |
|---|---|---|---|---|---|---|---|---|
| 380M | Transformer++ | 46.14 | 25.87 | 33.22 | 18.94 | 45.97 | 20.03 | 31.70 |
| 380M | Gated DeltaNet | 20.53 | 23.24 | 28.55 | 14.98 | 44.91 | 16.48 | 24.78 |
| 380M | MoM | 22.98 | 29.90 | 29.69 | 16.60 | 48.82 | 20.99 | 28.16 |
| 1.3B | Transformer++† | 44.32 | 32.43 | 42.59 | 24.49 | 58.47 | 21.56 | 37.31 |
| 1.3B | Gated DeltaNet | 30.25 | 27.65 | 34.06 | 23.22 | 58.23 | 20.36 | 32.30 |
| 1.3B | MoM | 41.14 | 34.30 | 37.08 | 24.11 | 58.59 | 21.03 | 36.04 |
MoM 在两个规模上都显著超越所有线性基线;1.3B 时平均分 36.04 已逼近 Transformer++ 的 37.31,把线性模型与 Transformer 在召回任务上的差距几乎抹平。LongBench 上 MoM 平均 15.64,也优于 GSA(14.61)、Gated DeltaNet(13.98) 等。
消融实验表格(混合记忆 vs 单记忆扩容,召回密集任务 Avg.)¶
| Model | Params | Recall Avg. |
|---|---|---|
| GLA expanded | 425M | 22.87 |
| GLA MoM | 395M | 23.53 |
| Gated DeltaNet expanded | 550M | 26.32 |
| Gated DeltaNet MoM | 444M | 28.16 |
在常识推理任务上同样如此(Gated DeltaNet MoM 444M 取 41.97,胜过 expanded 550M 的 41.32)。关键是:MoM 用更少的参数(444M vs 550M)打赢了"单纯把单记忆调大"的做法,证明增益来自"分而存之"的去干扰,而非单纯堆容量。即便在严格对齐激活参数(均 400M)下,MoM 召回 Avg. 26.51 仍高于 Gated DeltaNet 的 24.78。
关键发现¶
- 去干扰才是真增益:相同(甚至更少)参数下,分立记忆稳定优于膨胀单记忆,说明性能来自消除写入干扰而非扩容本身。
- 效率保持线性:推理时延 / 显存随序列长度线性增长,Transformer++ 在长序列直接 OOM,MoM 仍可扩展到 512K。
- 外推更稳:在 2K 上训练、外推到 32K 测 ppl,Transformer++ 急剧上升,MoM 在所有线性模型中 ppl 最低。
- 记忆自发专精:UMAP 显示路由把 token 隐状态聚成清晰簇,每个记忆专攻一个子分布;作者据此给出 TTT 视角——MoM 等价于一种"测试时集成学习",每个记忆只需拟合更简单的 \(k\to v\) 子映射,配 auxiliary loss 后各记忆负载基本均衡。
亮点与洞察¶
- 范式上的差异化:现有线性模型靠"门控/删除"被动减少干扰(丢信息),MoM 靠"分离存储"主动避免干扰(保信息),这是一个新的、与门控正交的方向。
- 通用插件:不绑定任何特定更新规则,能给 GLA、DeltaNet、Gated DeltaNet、Mamba2、RWKV 等普遍套上,落地成本低。
- 生物 + 工程双驱动:theta-gamma 多项记忆编码给了直觉,MoE 稀疏路由给了实现,varlen Triton 重排让"多记忆"不增算力,三者串得很顺。
- TTT 视角自洽:把"路由 = 动态聚类、每个记忆 = 子分布专家"解释清楚,给方法补上了理论叙事。
局限与展望¶
- 不是纯零成本:稀疏激活只加在 KV 投影上,激活参数仍有小幅上升(作者在附录专门讨论公平性),严格等参对比下召回增益相对常识推理增益更明显,长程摘要类任务(LongBench Sum)提升有限。
- 超参敏感性:记忆数、top-k、共享记忆的最优配置主要在 4/2/1 上验证,更大规模、更多记忆下的 scaling 行为与负载均衡还需更系统的探索。
- 依赖 auxiliary loss 保均衡:去掉负载均衡损失后路由是否退化、是否出现记忆坍塌,文中未充分压力测试。
- 规模上限:实验最大到 1.3B / 100B tokens,能否在 7B+ 真正逼平 Transformer 的召回能力仍是开放问题。
相关工作与启发¶
- 线性序列建模谱系:Linear Attention、RetNet、GLA、HGRN2、DeltaNet、Gated DeltaNet、Mamba2、RWKV6/7 等,被本文统一成记忆更新规则视角,是 MoM 的"宿主"。
- Mixture-of-Experts:Switch Transformer 等的 top-k 路由 + auxiliary loss 直接被借用,但 MoM 的"专家"是 RNN 状态而非独立 FFN。
- Test-Time Training / Titans:MoM 的 UMAP 专精分析与 TTT 集成解释,把它接到"测试时拟合 \(k\to v\)"这条线上。
- 启发:①"扩容不如分治"对所有压缩式记忆架构都是提醒;②路由 + varlen 重排是一种把稀疏结构低成本落到线性算子上的通用工程范式,可迁移到其他 SSM / 线性 RNN 变体。
评分¶
- 新颖性: ⭐⭐⭐⭐ — 用多记忆 + 稀疏路由替换单一记忆,开辟了不同于门控的去干扰新范式,且把记忆更新规则统一成可插拔框架。
- 实验充分度: ⭐⭐⭐⭐ — 召回 / 长上下文 / 常识 / 外推 / 效率 / 负载均衡 / 专精分析覆盖全面,且做了等参、单记忆扩容等关键对照;规模止于 1.3B 略欠。
- 写作质量: ⭐⭐⭐⭐ — 动机(生物启发 + 工程)清晰,方法递进自然,硬件实现讲得明白,图表到位。
- 价值: ⭐⭐⭐⭐ — 把线性模型在召回密集任务上做到逼近 Transformer,同时保住线性效率,对高效长序列架构是实打实的推进,且即插即用易被后续工作采纳。