跳转至

Token Sparse Attention: Efficient Long-Context Inference with Interleaved Token Selection

会议: ICML 2026
arXiv: 2602.03216
代码: https://github.com/dongwonjo/Token-Sparse-Attention
领域: 模型压缩 / 长文本推理加速
关键词: 稀疏注意力, prefill 加速, 可逆 token 选择, FlashAttention 兼容, 动态稀疏

一句话总结

作者发现 token 的"重要性"在层间和头间剧烈变化,传统 token eviction 一次性删除是不可逆的早期决策错误;他们提出 Token Sparse Attention,每层每个 attention head 独立选 \(L' \ll L\) 个 token 做密集 attention,输出再 scatter 回原始序列长度,配上残差路径让被略过的 token 在下一层重新有机会被选中——既保留头/层级动态选择,又能直接调用 FlashAttention 等密集 kernel,在 128K 上下文上叠加 FlexPrefill 后达到 ×3.23 注意力加速、精度损失 <1%。

研究背景与动机

领域现状:LLM 上下文窗口动辄 100K+ 后,attention 的 \(O(L^2)\) 复杂度成主要瓶颈。两条加速路线:(i) 稀疏 attention(如 Minference、FlexPrefill),用块级稀疏模式跳过低重要性区域;(ii) token eviction(PyramidInfer、FastKV、GemFilter),在早期层选出 top-k token,深层只算这些。

现有痛点:稀疏 attention 是块级的,块内若混入低相关 token 也会被一起算,稀疏度被天花板限制;token eviction 在早期层 hard-decide 哪些 token 重要,被删的 token 在深层即使变重要也回不来——这违反了 token 重要性的真实动态性。

核心矛盾:作者用 LLaMA-3.1-8B-Instruct 实测发现:(i) 层间 top-1% token 的重叠率随层间距快速下降,重要性在层间漂移;(ii) 同一层不同 head 的 top token 排名差异很大,head 各自关心不同的语义。eviction 用"一刀切"的 token 集合既忽略层动态又忽略 head 动态。

本文目标:(i) 设计一个 token 级稀疏机制,既能用 head/层各自的 token 选择又能在被略过后保持可恢复;(ii) 必须能直接复用 FlashAttention 等优化好的密集 kernel,不写新 CUDA;(iii) 必须能和现有稀疏 attention(块级)正交叠加。

切入角度:与其在 attention map 上做稀疏(被块边界限制)或在 KV cache 上做删除(不可逆),不如对 \(Q, K, V\) 本身做可逆的压缩-解压:选 token 时压成短序列做密集 attention,输出再 scatter 回原长度并加残差。残差路径让"未选 token"的信息从上一层流入下一层,等价于给它们留一条复活通道。

核心 idea:用 "compress-then-decompress + 残差" 把 token-level 稀疏化变成可逆操作,让每层每 head 都能重新决策。

方法详解

整体框架

Token Sparse Attention 在每个被选中的稀疏层内分两步:(1) Stage 1 压缩:用 Dynamic Token Coverage 算法估出每个 head \(h\) 的 token 集 \(S_{H=h}\)(大小 \(L'\)),从 \(Q,K,V \in \mathbb R^{L\times d}\)\(S_h\) gather 出 \(\hat Q, \hat K, \hat V \in \mathbb R^{L'\times d}\),调 FlashAttention 在 \(L'\times L'\) 上做密集 attention 得到 \(\hat O\);(2) Stage 2 解压:把 \(\hat O\)\(S_h\) scatter 回零初始化的 \(\mathbb R^{L\times d}\),未选位置保持 0,等价于对未选 token 施加 hard mask;再加残差连接。复杂度从 \(O(L^2 d)\) 降到 \(O(L'^2 d)\)。哪些层做稀疏由 Inter-Layer Representation Drift 一次性预选(默认底 50% 漂移最小的层),不需要训练。

关键设计

  1. Compress-then-Decompress 可逆 token 稀疏化:

    • 功能:让每层每 head 独立选 token 做密集 attention,未选 token 通过残差路径在下一层重新有机会被选。
    • 核心思路:Stage 1 对每个 head \(h\) 独立选 \(S_h\),gather 出 \(\hat Q_h, \hat K_h, \hat V_h\);attention 在压缩空间 \(\mathbb R^{L'\times L'}\) 上由 FlashAttention 直接处理,输出 \(\hat O_h\)。Stage 2 用 scatter 把 \(\hat O_h\) 散回 \(\mathbb R^{L\times d}\) 的对应行(未选行 = 0),然后 \(X_{\ell+1} = X_\ell + \text{Decompress}(\hat O_h)\)。残差使被略过 token 的上层表示直接流到下一层,下一层若判定其重要可再次选中。
    • 设计动机:传统 token eviction 把 \(L\to L'\) 当作不可逆 KV 删除,下层就再也看不到被删的 token;compress-decompress 把它当临时不参与 attention 的输入,结构上不删任何东西,从而层/head 间的动态重要性得以保留。还有一个工程红利:压缩后的 \(\hat Q\hat K\hat V\) 是 dense 连续的,能直接喂任何现成 attention kernel(FlashAttention、FlexPrefill 等),不需要写新 CUDA。
  2. Dynamic Token Coverage(按攻击力分位定预算):

    • 功能:在推理时动态决定每层留多少 token(不是固定比例),并分头决定留哪些。
    • 核心思路:对每个 head 用 recent queries 与所有 keys 做 lightweight attention 得到 \(\hat A\),按列求和后 pool 得 head 级 token score \(s_h[t]\),按 head 汇总并归一化得到层级 score \(s_l\)。把 \(s_l\) 升序排,找最小的 \(k_{\text{sparse}}\) 使 \(\sum_{j=1}^{k_{\text{sparse}}} s_l[I[j]] \ge \tau\)(默认 \(\tau=0.005\)),即累计最不重要的 token 总权重不超过 \(\tau\),把这部分丢掉;保留 \(k_{\text{keep}} = L - k_{\text{sparse}}\) 个。每个 head 独立用 top-\(k_{\text{keep}}\) 取自己最关心的子集 \(S_h\)。用 Triton 写自定义 fused kernel 让打分本身的 I/O 开销可忽略。
    • 设计动机:固定保留比例会在不同上下文长度/任务上失配(信息密度变化大);按"累计 attention 噪声尾巴 ≤ \(\tau\)"分位的方式让稀疏度自适应——长上下文里 attention noise 多则稀疏度大、短上下文里则小。这背后假设:长 context attention 必然累积"长尾低权重 token",砍掉它们等于做结构正则化。
  3. Inter-Layer Representation Drift 选稀疏层(哪些层能扛):

    • 功能:发现哪些层做稀疏化对结果伤害最小,避免一刀切所有层。
    • 核心思路:定义层 \(\ell\) 的归一化漂移 \(R_\ell = \mathbb E_t[\|h_{\ell+1,t} - h_{\ell,t}\|_2 / (\|h_{\ell,t}\|_2 + \epsilon)]\),漂移小 = token 表示稳定 = 该层可承受稀疏化。先在校准数据上算出 \(R_\ell\),对其排名得到 \(\hat R_\ell\),取 \(\mathcal L_{\text{sparse}} = \{\ell | \hat R_\ell \le \delta\}\)(默认 \(\delta=0.5\),即漂移最小的 50% 层做稀疏)。这只在模型加载时跑一次。
    • 设计动机:实验显示对 200 个随机 3 层组合做稀疏,平均漂移和准确率高度相关——稳定层做稀疏不破坏 token 表示,不稳定层做则误差累积。把"哪层稀疏"从超参变成 data-driven 的预处理,省掉用户调参负担。

损失函数 / 训练策略

完全 training-free 推理时方法,不需要任何微调;只在模型加载时跑一次校准跑得到 \(\mathcal L_{\text{sparse}}\)。超参 \(\tau\):LLaMA-3.1-8B 用 0.005,Mistral-Nemo-12B 用 0.008。token scoring 用 Triton fused kernel,attention 用未修改的 FlashAttention。

实验关键数据

主实验

RULER benchmark 上叠加各 baseline 后的平均精度与 128K 加速比(LLaMA-3.1-8B-Instruct):

方法 4K 32K 128K Avg. 128K 加速
FlashAttention 95.82 84.87 74.15 87.01 ×1.00
+ Token Sparse 96.06 84.81 73.68 87.02 ×1.36
Minference 93.46 85.34 73.63 86.49 ×1.12
+ Token Sparse 93.05 85.10 72.18 86.05 ×1.38
FlexPrefill 95.48 87.20 73.75 87.27 ×2.44
+ Token Sparse 95.33 87.68 73.58 87.27 ×2.76

与 token eviction 方法在同加速比下对比(128K,LLaMA-3.1-8B):

方法 Avg. 精度 加速
FlashAttention 87.01 ×1.00
PyramidInfer 78.49 ×1.49
GemFilter 85.12 ×1.53
FastKV 85.64 ×1.50
Token Sparse Attention 86.84 ×1.51

消融实验

配置 关键发现 含义
Dynamic \(\tau=0.005\) vs Fixed \(s=0.3\) 同加速下 87.02 vs 86.91 动态预算优于固定比例
Dynamic \(\tau=0.010\) vs Fixed \(s=0.5\) 高稀疏度下 86.84 vs 85.43 稀疏越激进,动态优势越明显
加速分解(128K) scoring/compress/decompress 总 overhead <11% 工程实现轻量
稀疏度随上下文长度 4K: 17%, 128K: 54% 长 context 自然有更多可丢 token

关键发现

  • 与 FlashAttention 叠加:精度几乎无变化(87.01 → 87.02),单独贡献 ×1.36 加速。
  • 与块级稀疏(FlexPrefill)叠加最有价值:×2.44 → ×2.76,证明 token 级与块级稀疏度互补、不互相覆盖。
  • 同加速比下击败所有 token eviction 方法,差距在 4K 短上下文上尤其明显(PyramidInfer 比 FlashAttn 低 17 个点)。

亮点与洞察

  • Compress-then-Decompress 是一个非常优雅的"伪稀疏"机制:表面上算了 \(L'\times L'\) 的密集 attention 然后填回 \(L\times d\),但残差通道让被略过的 token 信息保留,等价于在每层做了一次轻量的、可逆的、head-specific 的 token 选择。这种"逻辑稀疏 + 物理密集"的设计可以迁移到 MoE、稀疏 expert routing 等场景。
  • 不用写新 kernel 是工程上的大杀器:直接调 FlashAttention/FlexPrefill 现成 kernel,对任何下游使用者零门槛。这与 token eviction 必须改 KV cache 结构相比,部署成本天差地别。
  • Drift 选层是个朴素但强力的 prior:把"哪些层能扛稀疏"从超参变成 data-driven 决策,可以推广到任何"层级压缩"任务(如 layer dropout、layer pruning)。

局限与展望

  • 仍依赖 recent queries 估 token score,这是一个 heuristic;如果模型本身用 sliding window 或 chunked attention,recent queries 的统计意义会被破坏。
  • 残差路径让"未选 token"信息保留,但每层 scatter 出来的 0 行其实丢失了被选 token 与未选 token 之间的 cross-attention 贡献;论文未量化这部分损失。
  • 头/层间稀疏度差异大时,batch 内不同 head 的 \(L'\) 不同会破坏 tensor 规整度(虽然 FlashAttention 支持 ragged,但效率受影响);论文没讨论 batch 多 sample 时的实际 throughput。
  • 只在 prefill 上验证,decoding 阶段没用;但 decoding 的瓶颈是 KV cache 加载而非 attention 计算,本方法天然不适合。
  • 改进方向:把 drift 选层做成自适应(每个 prompt 不同)、把 scoring 替换为 learnable router(end-to-end 训练)、与 KV cache quantization 联用。

相关工作与启发

  • vs Minference / FlexPrefill (块级稀疏):他们在 attention map 上按块跳,被块边界限制;本方法在 token 级别选择,可与他们正交叠加,FlexPrefill 上还能再加 ×1.13 加速。
  • vs PyramidInfer / FastKV / GemFilter (token eviction):他们在早期层 hard-decide 哪些 token 留下,深层无法恢复;本方法每层都可重选,同加速比下精度高 1-8 个点。
  • vs FlashAttention:FlashAttention 是 I/O 优化的密集 attention,复杂度仍 \(O(L^2)\);本方法在它之上做算法稀疏化,复杂度降到 \(O(L'^2)\) 且复用其 kernel。
  • vs KV cache quantization (KIVI/H2O):他们减 KV 内存载入开销,本方法减 attention 计算开销,两者完全正交,可联合使用。

评分

  • 新颖性: ⭐⭐⭐⭐ Compress-then-Decompress 的可逆设计 + head-specific token 选择是简洁但有效的新点,drift 选层也是干净的工程贡献。
  • 实验充分度: ⭐⭐⭐⭐ 两个模型 × 4 个 baseline × 多个长度 × 多个 benchmark(RULER/InfiniteBench),加上与 eviction 方法的同加速对比,覆盖度高。
  • 写作质量: ⭐⭐⭐⭐ 从"token 重要性动态性"两个观察直接推到方法设计,逻辑顺畅;图 3 把 compress-decompress 流程画得很清楚。
  • 价值: ⭐⭐⭐⭐ 可直接落地工业部署,对所有长 context LLM 推理服务都有价值;与现有稀疏方法正交可叠加是关键卖点。