Long-Context Modeling with Dynamic Hierarchical Sparse Attention for On-Device LLMs¶
会议: NeurIPS 2025
arXiv: 2510.24606
代码: GitHub
领域: LLM效率 / 稀疏注意力 / 端侧部署
关键词: sparse attention, dynamic chunking, hierarchical sparsity prediction, on-device LLM, long context, chunk representation, boundary detection
一句话总结¶
提出动态分层稀疏注意力 (DHSA),通过自适应 chunk 分割 + chunk 级相似度预测 + 上采样到 token 级的分层框架,在不重训基座模型的前提下将密集注意力替换为稀疏注意力,在 Gemma2/3 上实现与密集注意力同等精度、20-60% prefill 延迟降低和 35% 峰值内存节省。
研究背景与动机¶
领域现状:长上下文建模是 LLM 的核心需求,但注意力机制的 \(O(L^2)\) 复杂度使得端侧设备难以处理长序列。稀疏注意力是主流优化方向。
现有痛点: - 静态稀疏方法(Longformer 的滑动窗口、BigBird 的全局 token)使用固定稀疏模式,无法适应不同输入的注意力分布变化 - 已有动态方法(MInference、LM-Infinite、H2O、Scissorhands)依赖预定义模板或启发式 KV cache 淘汰规则,缺乏通用性,会丢弃仍然重要的上下文 token
核心矛盾:高效性要求减少注意力计算,但精度要求保留关键 token 对交互。直接预测 \(L \times L\) token 级稀疏掩码本身就是 \(O(L^2)\),无法降低复杂度。
本文目标:设计一种无需重训的 plug-in 模块,动态预测注意力稀疏模式,同时适用于 prefill 和 decode 阶段。
切入角度:分层预测——先在 chunk 级 (\(N_c \times N_c\), \(N_c \ll L\)) 做粗粒度相似度估计,再上采样到 token 级做细粒度选择。
核心idea:chunk 级相似度可以用很低成本计算,且能有效代理 token 级重要性;配合自适应 chunk 边界预测和长度归一化,实现数据驱动的动态稀疏注意力。
方法详解¶
整体框架¶
DHSA 作为 plug-in 模块嵌入 Transformer 每一层。输入当前层的 token 嵌入,输出稀疏掩码 \(\mathbf{M} \in \{0,1\}^{L \times L}\)。核心流程:动态分割 chunk → 计算 chunk 级相似度 → 上采样到 token 级 → TopK 选择保留的 token 对。
关键设计¶
-
分层稀疏预测 (Hierarchical Sparsity Prediction)
- 功能:将序列分为 \(N_c\) 个不重叠的 chunk,计算 chunk 级相似度矩阵 \(\mathbf{S}_c \in \mathbb{R}^{N_c \times N_c}\),上采样为 token 级相似度矩阵 \(\mathbf{S}_t \in \mathbb{R}^{L \times L}\),对每个 query token 做 TopK 选择(预算 \(N_b\))
- 核心思路:\(N_c \ll L\) 使得 chunk 级计算代价极低(\(O(N_c^2)\) 替代 \(O(L^2)\));同一 chunk 对内的 token 对共享同一重要性分数
- 设计动机:直接预测 \(L \times L\) 掩码的代价等价于密集注意力,分层预测将复杂度降为 \(O(N_c^2 + L \cdot N_b)\)
-
动态边界检测 (Dynamic Boundary Detection)
- 功能:用轻量神经网络预测每个 token 位置是否为 chunk 边界。编码器用 MHA 聚合左右窗口的 key 向量,特征融合拼接 \([\mathbf{k}_{\text{left}}, \mathbf{k}_{\text{right}}, |\mathbf{k}_{\text{left}} - \mathbf{k}_{\text{right}}|, \mathbf{k}_{\text{left}} \odot \mathbf{k}_{\text{right}}, \text{sim}(\mathbf{k}_{\text{left}}, \mathbf{k}_{\text{right}})]\),MLP 输出二分类概率
- 核心思路:内容变化大的位置应成为 chunk 边界(语义分段),用左右窗口差异来检测
- 设计动机:固定大小 chunk 太死板,一刀切无法适应文档内部的语义段落结构变化。自适应分割让每个 chunk 内部语义更一致,chunk 级相似度对 token 级重要性的代理更准确
-
鲁棒 Chunk 表示 (Robust Chunk Representation)
- 功能:对 chunk 内 token 嵌入做平均池化,然后乘以 \(\sqrt{|\mathbf{C}|}\) 进行长度归一化
- 核心思路:\(\mathbf{q}_c = \sqrt{|\mathbf{C}|} \cdot \bar{\mathbf{q}}\),\(\mathbf{k}_c = \sqrt{|\mathbf{C}|} \cdot \bar{\mathbf{k}}\)。chunk 级相似度 \(\mathbf{S}_c = \mathbf{Q}_c \mathbf{K}_c^{\top}\)
- 设计动机:
- 直接 padding 后平均会被零值稀释表示质量
- 不同长度 chunk 的平均向量范数不同,导致相似度分数偏差。\(\sqrt{|\mathbf{C}|}\) 归一化消除长度对点积的影响
-
Prefill 与 Decode 阶段适配
- 功能:Prefill 阶段一次性预测全部边界并计算完整 \(\mathbf{S}_c\);Decode 阶段增量扩展边界并只计算新增行
- 核心思路:Decode 时将之前生成的 token 作为一个额外 chunk,当前 token 单独作为一个 chunk,只需计算最后一行的 chunk 相似度
- 设计动机:避免 decode 时重复计算已有 chunk 的相似度
损失函数 / 训练策略¶
- 边界检测器使用二分类交叉熵损失训练,正样本为真实语义边界位置
- DHSA 本身不需要重训基座模型,只需训练轻量的边界预测器
- 支持跨层共享边界 (boundary sharing) 以进一步降低开销,但会略微影响精度
实验关键数据¶
主实验 — LongBench (Gemma2-2b-it, budget=2k)¶
| 方法 | NrtvQA | Qasper | Mf-en | HotpotQA | 2WikiMQA | Musique | GovReport | QMSum | MultiNews | TriviaQA | SAMSum |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Dense | 22.37 | 35.32 | 37.32 | 41.63 | 32.05 | 19.05 | 27.08 | 21.08 | 25.48 | 87.00 | 41.26 |
| Block Sparse | 16.74 | 26.15 | 32.83 | 35.74 | 31.93 | 14.44 | 26.20 | 19.54 | 25.30 | 86.12 | 40.38 |
| DHSA | 20.69 | 30.20 | 34.98 | 38.78 | 31.96 | 15.90 | 26.75 | 20.74 | 25.38 | 87.03 | 41.46 |
消融实验 — 延迟与内存 (NarrativeQA, Gemma2)¶
| 注意力实现 | 方法 | 精度(%) | 延迟(s) | 峰值内存(GB) |
|---|---|---|---|---|
| eager | Dense | 21.15 | 1.65 | 10.72 |
| eager | Block Sparse | 17.04 | 1.00 | 9.08 |
| eager | DHSA | 20.12 | 1.19 | 6.91 |
| torch.sdpa | Dense | 22.37 | 1.10 | 6.33 |
| torch.sdpa | Block Sparse | 16.74 | 0.88 | 9.88 |
| torch.sdpa | DHSA | 19.37 | 0.91 | 6.99 |
关键发现¶
- 精度保持:在 LongBench 11 个子任务中,DHSA 在 10 个上优于 block sparse,在 TriviaQA 和 SAMSum 上甚至超过 dense attention。Needle-in-a-Haystack 测试中 DHSA (1k budget) 与 dense 表现完全一致。
- 内存优势显著:eager 模式下 DHSA 峰值内存仅 6.91GB,比 dense 降低 35.5%,比 block sparse (9.08GB) 也低 24%。
- 延迟竞争力:torch.sdpa 模式下 DHSA 延迟 0.91s 仅比 block sparse (0.88s) 慢 3%,但精度高出 2.6 个百分点。
- 长上下文扩展:16k 和 32k 序列长度下,dense eager OOM,DHSA 正常运行且延迟仅为 sdpa dense 的 ~40-60%。
- 边界共享权衡:跨层共享边界 (DHSA+bs) 进一步降低开销,但部分任务精度略降(如 Mf-en 从 34.98 降至 31.20)。
亮点与洞察¶
- 分层预测是核心创新:绕过了"预测 \(L^2\) 稀疏掩码本身就是 \(O(L^2)\)"的悖论。Chunk 级预测将搜索空间压缩了 \((L/N_c)^2\) 倍,是该方法能实际加速的关键。
- 完全数据驱动:不依赖预定义的注意力模式模板(如 A-shape、vertical-slash),通过学习自动发现输入相关的稀疏模式。这使得 DHSA 在不同任务间有更好的泛化性。
- \(\sqrt{|\mathbf{C}|}\) 归一化看似简单但很关键:解决了变长 chunk 的表示偏差问题。粗暴地用均值池化会让长 chunk 的相似度分数与短 chunk 不可比。
- Plug-in 设计:不修改原模型权重、不需要重训,对已部署的端侧模型有极大的实用价值。
局限与展望¶
- 延迟不是绝对最优:DHSA eager 模式 1.19s 比 block sparse 1.00s 慢 19%,主要是边界预测和 chunk 表示计算的开销。端侧部署需要进一步优化这部分算子。
- 未扩展最大上下文长度:作者提到 Gemma 系列缺乏可靠的上下文扩展实现,限制了进一步验证。
- 边界检测器训练:需要预先标注语义边界数据来训练,增加了部署门槛。如能无监督学习边界效果更佳。
- 超参数依赖:chunk budget \(N_b\) 和 chunk 大小仍需手动调节,自适应学习这些超参是重要方向。
- 仅验证小模型:实验限于 Gemma2-2b 和 Gemma3-1b,更大模型的效果待确认。
相关工作与启发¶
- Longformer / BigBird:经典静态稀疏注意力,用固定滑窗 + 全局 token,缺乏对输入的适应性
- MInference:动态稀疏但依赖预定义模式(A-shape、vertical-slash),本文完全数据驱动
- H2O / Scissorhands:基于 KV cache 淘汰的动态方法,但淘汰准则是启发式的
- PyramidKV:金字塔式 KV cache 压缩,思路互补
- Block Sparse Attention (Han Lab):MIT Han Lab 的块稀疏实现,是本文的主要 baseline
- 启发:分层预测的思想可以推广到其他注意力变体(如 cross-attention、multi-query attention);自适应分割也可应用于 RAG 中的文档分块
评分¶
- 新颖性: ⭐⭐⭐⭐ — 动态分层预测 + 自适应边界检测 + 长度归一化的组合设计新颖,每个组件都有清晰的设计动机
- 实验充分度: ⭐⭐⭐⭐ — Needle-in-a-Haystack + LongBench 多任务 + 延迟/内存分析 + 不同注意力实现的对比,评估维度全面
- 写作质量: ⭐⭐⭐⭐ — 方法动机和流程描述清晰,分层预测的直觉解释得好
- 价值: ⭐⭐⭐⭐ — 端侧长上下文 LLM 部署的实用方案,plug-in 设计降低了工程门槛