跳转至

Long-Context Modeling with Dynamic Hierarchical Sparse Attention for On-Device LLMs

会议: NeurIPS 2025
arXiv: 2510.24606
代码: GitHub
领域: LLM效率 / 稀疏注意力 / 端侧部署
关键词: sparse attention, dynamic chunking, hierarchical sparsity prediction, on-device LLM, long context, chunk representation, boundary detection

一句话总结

提出动态分层稀疏注意力 (DHSA),通过自适应 chunk 分割 + chunk 级相似度预测 + 上采样到 token 级的分层框架,在不重训基座模型的前提下将密集注意力替换为稀疏注意力,在 Gemma2/3 上实现与密集注意力同等精度、20-60% prefill 延迟降低和 35% 峰值内存节省。

研究背景与动机

领域现状:长上下文建模是 LLM 的核心需求,但注意力机制的 \(O(L^2)\) 复杂度使得端侧设备难以处理长序列。稀疏注意力是主流优化方向。

现有痛点: - 静态稀疏方法(Longformer 的滑动窗口、BigBird 的全局 token)使用固定稀疏模式,无法适应不同输入的注意力分布变化 - 已有动态方法(MInference、LM-Infinite、H2O、Scissorhands)依赖预定义模板或启发式 KV cache 淘汰规则,缺乏通用性,会丢弃仍然重要的上下文 token

核心矛盾:高效性要求减少注意力计算,但精度要求保留关键 token 对交互。直接预测 \(L \times L\) token 级稀疏掩码本身就是 \(O(L^2)\),无法降低复杂度。

本文目标:设计一种无需重训的 plug-in 模块,动态预测注意力稀疏模式,同时适用于 prefill 和 decode 阶段。

切入角度:分层预测——先在 chunk 级 (\(N_c \times N_c\), \(N_c \ll L\)) 做粗粒度相似度估计,再上采样到 token 级做细粒度选择。

核心idea:chunk 级相似度可以用很低成本计算,且能有效代理 token 级重要性;配合自适应 chunk 边界预测和长度归一化,实现数据驱动的动态稀疏注意力。

方法详解

整体框架

DHSA 作为 plug-in 模块嵌入 Transformer 每一层。输入当前层的 token 嵌入,输出稀疏掩码 \(\mathbf{M} \in \{0,1\}^{L \times L}\)。核心流程:动态分割 chunk → 计算 chunk 级相似度 → 上采样到 token 级 → TopK 选择保留的 token 对。

关键设计

  1. 分层稀疏预测 (Hierarchical Sparsity Prediction)

    • 功能:将序列分为 \(N_c\) 个不重叠的 chunk,计算 chunk 级相似度矩阵 \(\mathbf{S}_c \in \mathbb{R}^{N_c \times N_c}\),上采样为 token 级相似度矩阵 \(\mathbf{S}_t \in \mathbb{R}^{L \times L}\),对每个 query token 做 TopK 选择(预算 \(N_b\)
    • 核心思路\(N_c \ll L\) 使得 chunk 级计算代价极低(\(O(N_c^2)\) 替代 \(O(L^2)\));同一 chunk 对内的 token 对共享同一重要性分数
    • 设计动机:直接预测 \(L \times L\) 掩码的代价等价于密集注意力,分层预测将复杂度降为 \(O(N_c^2 + L \cdot N_b)\)
  2. 动态边界检测 (Dynamic Boundary Detection)

    • 功能:用轻量神经网络预测每个 token 位置是否为 chunk 边界。编码器用 MHA 聚合左右窗口的 key 向量,特征融合拼接 \([\mathbf{k}_{\text{left}}, \mathbf{k}_{\text{right}}, |\mathbf{k}_{\text{left}} - \mathbf{k}_{\text{right}}|, \mathbf{k}_{\text{left}} \odot \mathbf{k}_{\text{right}}, \text{sim}(\mathbf{k}_{\text{left}}, \mathbf{k}_{\text{right}})]\),MLP 输出二分类概率
    • 核心思路:内容变化大的位置应成为 chunk 边界(语义分段),用左右窗口差异来检测
    • 设计动机:固定大小 chunk 太死板,一刀切无法适应文档内部的语义段落结构变化。自适应分割让每个 chunk 内部语义更一致,chunk 级相似度对 token 级重要性的代理更准确
  3. 鲁棒 Chunk 表示 (Robust Chunk Representation)

    • 功能:对 chunk 内 token 嵌入做平均池化,然后乘以 \(\sqrt{|\mathbf{C}|}\) 进行长度归一化
    • 核心思路\(\mathbf{q}_c = \sqrt{|\mathbf{C}|} \cdot \bar{\mathbf{q}}\)\(\mathbf{k}_c = \sqrt{|\mathbf{C}|} \cdot \bar{\mathbf{k}}\)。chunk 级相似度 \(\mathbf{S}_c = \mathbf{Q}_c \mathbf{K}_c^{\top}\)
    • 设计动机
      • 直接 padding 后平均会被零值稀释表示质量
      • 不同长度 chunk 的平均向量范数不同,导致相似度分数偏差。\(\sqrt{|\mathbf{C}|}\) 归一化消除长度对点积的影响
  4. Prefill 与 Decode 阶段适配

    • 功能:Prefill 阶段一次性预测全部边界并计算完整 \(\mathbf{S}_c\);Decode 阶段增量扩展边界并只计算新增行
    • 核心思路:Decode 时将之前生成的 token 作为一个额外 chunk,当前 token 单独作为一个 chunk,只需计算最后一行的 chunk 相似度
    • 设计动机:避免 decode 时重复计算已有 chunk 的相似度

损失函数 / 训练策略

  • 边界检测器使用二分类交叉熵损失训练,正样本为真实语义边界位置
  • DHSA 本身不需要重训基座模型,只需训练轻量的边界预测器
  • 支持跨层共享边界 (boundary sharing) 以进一步降低开销,但会略微影响精度

实验关键数据

主实验 — LongBench (Gemma2-2b-it, budget=2k)

方法 NrtvQA Qasper Mf-en HotpotQA 2WikiMQA Musique GovReport QMSum MultiNews TriviaQA SAMSum
Dense 22.37 35.32 37.32 41.63 32.05 19.05 27.08 21.08 25.48 87.00 41.26
Block Sparse 16.74 26.15 32.83 35.74 31.93 14.44 26.20 19.54 25.30 86.12 40.38
DHSA 20.69 30.20 34.98 38.78 31.96 15.90 26.75 20.74 25.38 87.03 41.46

消融实验 — 延迟与内存 (NarrativeQA, Gemma2)

注意力实现 方法 精度(%) 延迟(s) 峰值内存(GB)
eager Dense 21.15 1.65 10.72
eager Block Sparse 17.04 1.00 9.08
eager DHSA 20.12 1.19 6.91
torch.sdpa Dense 22.37 1.10 6.33
torch.sdpa Block Sparse 16.74 0.88 9.88
torch.sdpa DHSA 19.37 0.91 6.99

关键发现

  1. 精度保持:在 LongBench 11 个子任务中,DHSA 在 10 个上优于 block sparse,在 TriviaQA 和 SAMSum 上甚至超过 dense attention。Needle-in-a-Haystack 测试中 DHSA (1k budget) 与 dense 表现完全一致。
  2. 内存优势显著:eager 模式下 DHSA 峰值内存仅 6.91GB,比 dense 降低 35.5%,比 block sparse (9.08GB) 也低 24%。
  3. 延迟竞争力:torch.sdpa 模式下 DHSA 延迟 0.91s 仅比 block sparse (0.88s) 慢 3%,但精度高出 2.6 个百分点。
  4. 长上下文扩展:16k 和 32k 序列长度下,dense eager OOM,DHSA 正常运行且延迟仅为 sdpa dense 的 ~40-60%。
  5. 边界共享权衡:跨层共享边界 (DHSA+bs) 进一步降低开销,但部分任务精度略降(如 Mf-en 从 34.98 降至 31.20)。

亮点与洞察

  • 分层预测是核心创新:绕过了"预测 \(L^2\) 稀疏掩码本身就是 \(O(L^2)\)"的悖论。Chunk 级预测将搜索空间压缩了 \((L/N_c)^2\) 倍,是该方法能实际加速的关键。
  • 完全数据驱动:不依赖预定义的注意力模式模板(如 A-shape、vertical-slash),通过学习自动发现输入相关的稀疏模式。这使得 DHSA 在不同任务间有更好的泛化性。
  • \(\sqrt{|\mathbf{C}|}\) 归一化看似简单但很关键:解决了变长 chunk 的表示偏差问题。粗暴地用均值池化会让长 chunk 的相似度分数与短 chunk 不可比。
  • Plug-in 设计:不修改原模型权重、不需要重训,对已部署的端侧模型有极大的实用价值。

局限与展望

  1. 延迟不是绝对最优:DHSA eager 模式 1.19s 比 block sparse 1.00s 慢 19%,主要是边界预测和 chunk 表示计算的开销。端侧部署需要进一步优化这部分算子。
  2. 未扩展最大上下文长度:作者提到 Gemma 系列缺乏可靠的上下文扩展实现,限制了进一步验证。
  3. 边界检测器训练:需要预先标注语义边界数据来训练,增加了部署门槛。如能无监督学习边界效果更佳。
  4. 超参数依赖:chunk budget \(N_b\) 和 chunk 大小仍需手动调节,自适应学习这些超参是重要方向。
  5. 仅验证小模型:实验限于 Gemma2-2b 和 Gemma3-1b,更大模型的效果待确认。

相关工作与启发

  • Longformer / BigBird:经典静态稀疏注意力,用固定滑窗 + 全局 token,缺乏对输入的适应性
  • MInference:动态稀疏但依赖预定义模式(A-shape、vertical-slash),本文完全数据驱动
  • H2O / Scissorhands:基于 KV cache 淘汰的动态方法,但淘汰准则是启发式的
  • PyramidKV:金字塔式 KV cache 压缩,思路互补
  • Block Sparse Attention (Han Lab):MIT Han Lab 的块稀疏实现,是本文的主要 baseline
  • 启发:分层预测的思想可以推广到其他注意力变体(如 cross-attention、multi-query attention);自适应分割也可应用于 RAG 中的文档分块

评分

  • 新颖性: ⭐⭐⭐⭐ — 动态分层预测 + 自适应边界检测 + 长度归一化的组合设计新颖,每个组件都有清晰的设计动机
  • 实验充分度: ⭐⭐⭐⭐ — Needle-in-a-Haystack + LongBench 多任务 + 延迟/内存分析 + 不同注意力实现的对比,评估维度全面
  • 写作质量: ⭐⭐⭐⭐ — 方法动机和流程描述清晰,分层预测的直觉解释得好
  • 价值: ⭐⭐⭐⭐ — 端侧长上下文 LLM 部署的实用方案,plug-in 设计降低了工程门槛