跳转至

Hardware-aligned Hierarchical Sparse Attention for Efficient Long-term Memory Access

会议: NEURIPS2025
arXiv: 2504.16795
代码: ant-research/long-context-modeling
领域: LLM效率
关键词: sparse attention, RNN, Mamba, long context, length generalization, chunk selection, hardware-aligned kernel

一句话总结

提出层次化稀疏注意力(HSA)及 RAMba 架构,通过两阶段 token-to-chunk 相关性学习与硬件对齐 kernel 设计,让 Mamba 获得高效长程随机访问能力,仅在 4K 上下文预训练即可在 64M passkey retrieval 上达到 100% 准确率。

研究背景与动机

  • RNN(如 Mamba)具有线性复杂度优势,但受限于固定维度隐状态的信息瓶颈,无法随机访问历史上下文
  • Transformer 的全注意力机制训练和推理代价随序列长度呈二次增长,且难以外推至超训练长度
  • 现有稀疏注意力(NSA、MoBA)采用 chunk 选择策略,但以 token 级梯度学习 chunk 重要性(chunk-unaware),导致 chunk 选择不准确
  • 将注意力直接嫁接到 RNN 上会破坏其效率优势,产生"效率-随机访问-长度泛化"三难困境
  • 推理阶段 KV cache 随序列长度线性增长,内存开销巨大,限制长上下文的实际部署
  • 已有方法(Landmark Attention、DCA)在外推超过 32× 训练长度后困难度骤增;GCA 虽可外推至 16M 但每 64 token 才检索一次,灵活性不足

方法详解

整体架构(RAMba)

  • 基于 Mamba-2,分为上下两个 decoder(各 L/2 层),中间共享一个 chunk selection 层
  • 下层 decoder 输出被分成大小为 S 的 chunk,经双向 Transformer encoder 编码为 chunk memory
  • 上层 decoder 交替使用 HSA 层和 G 个 Mamba 层,HSA 层从选中的 chunk 中做稀疏注意力

层次化稀疏注意力(HSA)

  1. Chunk Selection:每个 token 用下层 decoder 输出计算 query,与 chunk 的 mean-pooled key 做点积,选 top-K 个 chunk(每个 query group 独立选择)
  2. 第一阶段(token-level attention):对每个选中 chunk 内部独立做标准注意力,获得 chunk 级信息表示 O_{t,k}
  3. 第二阶段(chunk-level attention):用 stick-breaking 权重(无位置编码,仅依赖相关性分数的 sigmoid)对各 chunk 信息加权融合
  4. 端到端学习:反向传播时 chunk 权重根据整个 chunk 对 next-token prediction 的贡献调整,实现 chunk-aware 学习

硬件对齐 Kernel

  • 基于 Triton 实现:每个 GPU 线程处理一个 token 的 query 及其对应 K 个 chunk 的 KV 对
  • 使用 softmax-off-by-one 允许当前 chunk 内 token 忽略检索到的 token
  • 反向传播分两阶段:先累积 Q 和 w 的梯度,再累积 K 和 V 的梯度
  • 前向过程中每个线程初始化 O'=0,循环 K 个 chunk:加载 chunk 的 K/V 到 SRAM,计算 softmax 注意力,用 chunk 权重 w 加权累加
  • 反向 Phase 1(Algorithm 2):每个线程 t 遍历 K 个 chunk 计算 ∇Q 和 ∇w
  • 反向 Phase 2(Algorithm 3):每个线程 i 遍历所有选中该 chunk 的 token 累积 ∇K 和 ∇V
  • 整体设计避免了 naive 实现中每 token 对应不同 K 个 chunk 导致的巨大内存开销

训练与推理优化

  • Memory Reset:训练时随机替换前段 RNN 隐状态为零/随机段末态,打断 shortcut,增强长度泛化
  • Truncated BPTT:序列的初始 state 用上一条序列的末态初始化,配合 memory reset 提供多样化训练信号
  • KV Cache Offloading:推理时将 token 级 KV cache 卸载到 CPU,GPU 只保留紧凑的 chunk 级表示 K^slc(维度仅 ⌊L/S⌋×d)用于 chunk 选择;每步仅加载选中 chunk 的 KV
  • 共享 KV Cache:所有 HSA 层共享同一份源自中间层的 KV cache,每步仅需一次 chunk 选择和 CPU-GPU 交换,显著减少通信开销
  • 理论无限内存:K^slc 也可进一步卸载至 FAISS 数据库实现恒定 GPU 内存,但实际中其内存占用很小无需如此

实验关键数据

长程语言建模(370M 模型,4K 预训练)

模型 (370M) PG19 PPL (4K) PG19 PPL (64K) ArXiv PPL (64K) Code PPL (64K)
Transformer (full attn) 18.61 >10⁴ >10⁴ 2865.51
Mamba-2 17.92 17.30 3.86 3.05
Mamba + NSA (w/ m.r.) 17.87 17.31 3.87 3.05
Mamba + NSA (w/o m.r.) 17.74 17.62 4.35 3.28
RAMba (w/ m.r.) 17.82 17.01 3.65 3.07
RAMba (w/o m.r.) 17.63 17.11 3.87 3.21

下游任务与检索

任务 RAMba (w/ m.r.) Mamba-2 Transformer
Passkey Retrieval 64M 100% ~0% ~0%
RULER S-N 64K 85.07 10.45 0.00
RULER MQ-N 64K 55.22 0.00 0.00
RULER VT 64K 55.22 8.96 0.00
LongBench Overall 25.7 22.4 24.8
下游任务 AVG (SFT) 33.64 31.03 32.82
SQuaD EM/F1 48.24/59.17 41.33/52.03 45.50/56.13
HotpotQA EM/F1 22.30/30.53 18.70/26.20 21.90/29.49

亮点

  • 首个在 64M 上下文 passkey retrieval 达到 100% 准确率的 Mamba 模型,且仅用 4K 上下文预训练
  • HSA 前向速度比 NSA 快 3×,比 full attention 快 5-25×(16K+ 上下文)
  • 推理时内存几乎恒定(KV cache offload),理论上可处理无限长上下文
  • chunk-aware 两阶段注意力机制使 chunk 选择即使在超训练长度 10000× 仍然准确
  • Memory Reset 训练策略简单有效,同时适用于 NSA 和 HSA
  • 2.7B 规模验证 RAMba 仍保持显著优势(RULER 64K 多任务大幅领先 Mamba-2)
  • Stick-breaking 权重替代位置编码,天然适合长度外推
  • 消融实验充分:chunk encoder 的重要性、memory reset 的效果、softmax vs stick-breaking 的对比

局限与展望

  • 370M 规模实验为主,2.7B 实验仅展示部分结果,更大规模(7B/13B)验证不足
  • RULER 上更难的检索任务(MQ-N、VT)在超 256K 时性能显著下降,极长上下文精准 chunk 选择仍是开放问题
  • 双向 Transformer encoder 额外引入 5.4% 参数,chunk 编码增加预填充时间
  • CPU-GPU 内存交换虽可接受但仍可能成为大规模部署瓶颈(每步交换 g×dh×K×S 参数)
  • 模板/任务格式对长度泛化影响很大(passkey 64M vs S-N 仅 4M),鲁棒性有待提升
  • Chunk size S=64 固定设置,未探索自适应 chunk 大小的可能性
  • Memory Reset 虽然增强外推但牺牲了 in-domain 性能(4K PPL 略高)
  • 仅在 Pile 数据集上预训练,未在大规模多样化语料上验证

评分

  • 新颖性: ⭐⭐⭐⭐⭐ (两阶段层次化注意力 + chunk-aware 学习 + stick-breaking 权重,思路新颖且有心理学启发)
  • 实验充分度: ⭐⭐⭐⭐ (多任务多长度评测全面,消融实验丰富,但大模型规模验证偏少)
  • 写作质量: ⭐⭐⭐⭐ (结构清晰,算法伪代码详尽,kernel 设计解释充分,与 NSA/MoBA 的本质区别分析到位)
  • 价值: ⭐⭐⭐⭐⭐ (解决 RNN 长上下文随机访问核心瓶颈,64M 外推能力突破性,开源代码可复现)