跳转至

Sparser Block-Sparse Attention via Token Permutation

会议: ICML 2026
arXiv: 2510.21270
代码: https://github.com/xinghaow99/pbs-attn (有)
领域: LLM效率 / 长上下文 / 稀疏注意力
关键词: 块稀疏注意力, Token Permutation, 长上下文 Prefilling, FlashAttention, Heavy Hitter

一句话总结

本文提出 PBS-Attn,利用注意力的置换不变性,先按"全局重要性"对 key 在段内重排,把散落各处的 heavy hitter 聚拢成连续高密度块,再做块稀疏计算,从而在保持精度近乎追平 full attention 的同时,把长上下文 prefilling 端到端加速最高 2.75 倍。

研究背景与动机

领域现状:长上下文 LLM 的瓶颈是 self-attention 的 \(O(N^2)\) 复杂度。FlashAttention 通过分块和在线 softmax 解决了显存问题,但 FLOPs 仍是平方级。块稀疏注意力(MInference / FlexPrefill / XAttention 等)在 FlashAttention 的 tiling 之上再加一层 "block mask",对预测为低权重的整块直接跳过计算,是目前主流加速路径。

现有痛点:块稀疏方法被注意力矩阵的原始结构绑住了。query 在某个 block 里关心的关键 key("heavy hitter"),实际是按重尾分布零散撒在整条序列上的;要把它们覆盖住,就得选很多个 block,而每个被选中的 block 里真正有用的 token 又很少,造成"取了一筐石头去淘几粒金子"。

核心矛盾:现有方法只在给定的混乱矩阵里被动挑选 block(优化 \(\mathbb{C}_{\text{sel}}\)),却没有人去优化注意力矩阵本身的结构。这是一条被忽略的优化轴。

本文目标:在保持模型精度和因果性的前提下,主动重塑 Q/K/V 的排列,让 block 级稀疏度从 30%-40% 提到 60%+,并端到端落到墙钟加速上。

切入角度:注意力对 key-value 的置换是不变的\(\text{Attn}(Q, P_\pi K, P_\pi V) = \text{Attn}(Q, K, V)\))。这意味着可以自由重排 key 的顺序,把散落的 heavy hitter 物理上聚到一起,而不改变数学输出。难点只剩两个:① 怎么定义"重要性"来排序;② 怎么和因果 mask 共存。

核心 idea:用最后一个 query block 当 proxy 估出每个 key 的全局重要性分数,然后在段内按分数降序重排 key,段间保持原序以维持因果性 —— 把"挑块"变成"先整理再挑块"。

方法详解

整体框架

PBS-Attn 是 plug-and-play 的 prefilling 加速模块,pipeline 分四步:

  1. 打分:用序列最后一个 query block \(\mathbf{Q}_{\text{last\_block}}\) 与全部 K 做一次小矩阵乘 + softmax + 按行求均值,得到长度为 \(N\) 的全局重要性分数 \(\mathbf{s}\)(开销 \(O(N \cdot B \cdot d)\),相对 \(O(N^2 d)\) 可忽略)。
  2. 段内重排:把序列切成大小为 \(S\) 的段,在每段内按 \(\mathbf{s}\) 降序对 K(和对应 V)做局部 permutation \(\pi_i\);段间保持原序。Query 保持原序(\(\mathbf{P}_\sigma = \mathbf{I}\))。
  3. 块选择 + 稀疏计算:在重排后的 \((\mathbf{Q}, \mathbf{K}', \mathbf{V}')\) 上用 mean-pooling 估每对 (query block, key block) 的重要性,得到稀疏 mask \(\mathbf{M}\);只对 \(\mathbf{M}_{i,j}=1\) 的块跑 FlashAttention 在线 softmax。
  4. 逆置换:由于 query 没动,输出无需做 \(\mathbf{P}_\sigma^T\) 反置换,直接得到与原始顺序一致的 \(\mathbf{O}\)

关键设计

  1. Segmented Permutation(段内置换 + 段间因果):

    • 功能:在不破坏 causal mask 的前提下做 key 重排。
    • 核心思路:把前 \(\lfloor N/S \rfloor \cdot S\) 个 token 切成 \(G\) 个长度 \(S\) 的段,全局置换矩阵 \(\mathbf{P}_\pi = \text{diag}(\mathbf{P}_{\pi_1}, \dots, \mathbf{P}_{\pi_G}, \mathbf{I})\) 写成块对角形式。段间相对顺序不变,因此 query \(q_i\) 仍然只能"看到"它所在段及之前所有段的 key —— 这些段不论内部怎么打乱,都还在 \(q_i\) 的可见范围里。对角线段(query 段 = key 段)保留因果三角,对角线以下的段整块要么全选要么全跳。
    • 设计动机:一次性全局 permutation 会把因果三角彻底打散,让原本被天然跳过的上三角块变成需要计算(block density 从 \(\frac{T_c+1}{2T_c}\) 涨到 1),收益反而是负的。段化是保因果与提稀疏度之间的最小折中。
  2. Global-Importance-based Key Permutation(用 last-block query 做 proxy 排序):

    • 功能:定义"key 有多重要",作为段内排序的依据。
    • 核心思路:分数向量 \(\mathbf{s} = \text{mean}_{\text{rows}}(\text{softmax}(\mathbf{Q}_{\text{last\_block}} \mathbf{K}^T / \sqrt{d}))\),每段内 \(\pi_i = \text{argsort}(-\mathbf{s}_{[(i-1)S+1 : iS]})\) 降序排列。作者通过 16K 上的对照实验(Figure 1)验证了:随机 permutation 反而掉点(说明原序里有局部结构),fine-grained 的 greedy 局部对齐略好但不如全局;而用"任意一小撮 query"作为 proxy 估全局重要性,效果最佳 —— 因为 heavy hitter(如 attention sink、vertical line pattern)对不同 query 几乎是一致的。
    • 设计动机:直接对完整 \(Q K^T\) 排序是 \(O(N^2)\),得不偿失;用最后 \(B\) 个 query 做 proxy 把代价压到 \(O(NBd)\) 线性,且实测和"全 query 平均"几乎一致。这把"为什么 permutation 能 work"从经验观察落到一个可解释的归纳偏置上:稀疏注意力的关键不在精细对齐,而在把全局重要 token 聚成簇。
  3. Permuted-FlashAttention Triton 内核(只重排 K,避免 GQA 复制开销):

    • 功能:把段化 permutation 嵌进 FlashAttention 的 tile 调度里,让重排逻辑不打断 SRAM 上的在线 softmax。
    • 核心思路:先在 HBM 上做一次性的 \(\mathbf{K}' = \mathbf{P}_\pi \mathbf{K}\)\(\mathbf{V}' = \mathbf{P}_\pi \mathbf{V}\) 重排,然后块选择 mask \(\mathbf{M}\) 指引哪些 \((i,j)\) tile 跳过;选中的 tile 走标准 FlashAttention 流程更新 \(\mathbf{m}_i^{(j)}, \mathbf{l}_i^{(j)}, \mathbf{O}_i^{(j)}\),跳过的 tile 直接继承前一状态。Query 不重排还有一个隐藏好处:在 GQA 下,一个 query head 对应多个 key head 时,可以把 permutation 共享/独立两种策略都做(默认独立以最大化稀疏度,附录 G 也评估了共享方案以省显存)。
    • 设计动机:query permutation 的收益边际(Figure 6a),但代价是要逆置换输出且在 GQA 下需要重新组织 query tile —— 不值。只动 K/V 是性价比最高的切法。

损失函数 / 训练策略

PBS-Attn 是 training-free 的 inference 加速方法,不引入任何额外参数和训练。默认配置 \(B=128\), \(S=256\), 块选择阈值 0.9(即累计 attention mass 覆盖 90% 时停止选 block)。可与 antidiagonal scoring(XAttention 的策略)组合得到 PBS-Attn+。

实验关键数据

主实验

LongBench 平均分(Llama-3.1-8B-Instruct,越接近 Full 越好):

方法 Single-Doc QA Multi-Doc QA Few-shot Synthetic Avg 说明
Full Attention 48.80 41.80 29.73 66.82 38.28 上限 oracle
MInference 47.21 40.93 29.36 62.36 37.06 离线 pattern 搜索
FlexPrefill 47.03 38.57 30.38 24.71 30.56 Synthetic 任务崩了
XAttention 48.26 40.23 31.35 54.64 36.42 antidiagonal 评分
MeanPooling(无 perm) 46.61 40.66 30.64 58.14 36.67 同选块器但不重排
PBS-Attn 48.00 42.09 28.36 63.80 37.37 距 Full 仅差 0.91

RULER 128K 上 Llama-3.1-8B-Instruct 平均分:Full 75.30 / MeanPooling 59.32 / PBS-Attn 66.98 / PBS-Attn+ 72.09 —— 越长上下文,permutation 的相对收益越大(128K 上对 MeanPooling 提升 7.66 分)。

效率:在 H100 上测 TTFT,256K 上下文上 PBS-Attn 相对 FlashAttention 实现 2.75× 端到端加速,且在 8K-512K 全程都是最快或并列最快;对比 MInference 直到 128K 才有加速、XAttention 在 128K 后停止增长。

消融实验

配置 现象 说明
只 permute K(默认) 性能-密度曲线最优 主方案
只 permute Q 边际略优但 GQA 下效率低 不采用
Q 和 K 都 permute 无显著改进 排除
大段 \(S\) 性能-密度曲线更平 段内排序信息更充分,但对角线段计算量也大
不做 permutation(MeanPooling) LongBenchv2-Qwen 上掉 31% 相对分 验证 permutation 本身的价值
Random Permutation 显著掉点 证实原序里确有局部结构需要尊重
Greedy 局部对齐 不如全局 heavy-hitter 排序 全局簇 ≻ 局部精细

关键发现

  • 越长越受益:8K 上 sparsity 绝对提升 7%,128K 上 selected block 数下降 14.4%,RULER 128K 上对 MeanPooling 提升 7.66 分 —— 长上下文里 fragmentation 越严重,permutation 越值。
  • Heavy hitter 是 query-agnostic 的:用随机 query 子集 vs 最后一个 block 做 proxy,差距可忽略。这暗示稀疏注意力中真正重要的 key 是序列固有属性,而非和特定 query 强相关 —— 这一观察让 proxy 排序的 \(O(N B d)\) 开销显得理所当然。
  • Permutation 与块选择算法正交:把 antidiagonal scoring(XAttention)换进 PBS-Attn 得到 PBS-Attn+,进一步把 RULER 平均分推到接近 full attention(Llama 上仅差 3.21);说明 permutation 的收益是结构层面的,不与具体选块器耦合。
  • 失败模式有界:在 Llama-3.1-8B 的 1024 个 head 上,permutation 在 97.5% 覆盖率下让 70.8% 的 head 稀疏度变好,只让 5.2% 的 head 变差;对应那些天然就是"对角带"或"垂直线整齐排列"的 head。

亮点与洞察

  • 把"挑"换成"先整理再挑"是一个很优雅的视角切换:之前所有块稀疏工作都在卷选块策略,本文换了一根优化轴 —— 注意力矩阵本身可以被无损改写。这种"打开新优化维度"比榨干旧维度更值钱。
  • Permutation 的因果性处理可以套用到其他稀疏机制:段内置换 + 段间保序的块对角形式,本质上是给"必须保留某种全局顺序"的场景提供了一个通用的局部重排框架。比如 KV cache eviction、prefix caching、speculative decoding 的 verify 阶段,都可能借用同一思路把"看似不能动"的 token 顺序变成"段内可调"。
  • 用极小 proxy 估全局重要性的思路可迁移:last-block query 作为 proxy 的代价只有 \(O(NBd)\),但能稳定排出 heavy hitter —— 这种"花 1% 算力换 30% 结构优化"的范式,可以用在 KV 量化粒度选择、token pruning 排序、layer skipping 决策等任何"哪些维度值得保留全精度"的子问题上。

局限与展望

  • 只覆盖 prefilling,没动 decoding。decoding 阶段每步只产生一个 query,proxy 排序的逻辑不再适用;KV cache 的 permutation 需要更精细的增量维护策略。
  • 打分依赖 last-block query 这个 proxy,对超长 sequence 且最后一段语义与前文严重断裂的场景(如多文档混合)可能失真;论文没给极端 mismatch 场景的鲁棒性分析。
  • block 选择阈值 0.9 是手工设的;不同任务(如 RULER 上 KV 检索任务)需要切换到 antidiagonal 评分才不掉点,说明"一套阈值打天下"在 synthetic 任务上还有缺口。
  • GQA 下默认要把 K/V 在 group 内复制以最大化稀疏度,会增加 HBM 占用;附录 G 的 share-permutation 方案省内存但稀疏度降低,二者之间还没有一个自适应折中。
  • 改进思路:① 把 last-block proxy 换成"动态采样若干 query block 的并集"做更鲁棒的估计;② 让 segment size \(S\) 随 layer/head 自适应(不同 head 的 fragmentation 程度差异很大);③ 把 permutation 推到 decoding 阶段,配合分段 KV cache 做增量 re-sort。

相关工作与启发

  • vs MInference:MInference 离线搜索 attention pattern 再固定使用;PBS-Attn 在线根据输入决定 permutation,泛化性更好(MInference 在 RULER 128K 上掉到 70.47,PBS-Attn 66.98 但 PBS-Attn+ 72.09)。
  • vs FlexPrefill:FlexPrefill 用 \(\gamma=0.95, \tau=0.1\) 的动态阈值挑块,速度接近 PBS-Attn 但精度严重下降(LongBench Synthetic 24.71 vs Full 66.82,几乎崩了)。说明"光挑得快"不够,得让被挑的内容真正密集。
  • vs XAttention:XAttention 用对角线评分挑块,是当前最强 baseline 之一;PBS-Attn 的 permutation 与之正交,PBS-Attn+ 直接把 XAttention 当块选择器、外加 permutation,进一步把 LongBench 推到 36.87(比 XAttention 高 0.45),证明 permutation 是 plug-in 收益。
  • vs Heavy Hitter Oracle (H2O):H2O 在 decoding 阶段保留重要 token;本文在 prefilling 阶段把它们聚拢但保留全量计算。可视为同一"heavy hitter"信念在两个阶段的不同利用方式 —— 一个是"保留谁",一个是"重排谁"。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 第一个把 attention 的 permutation invariance 当作主动优化轴用在块稀疏加速上,且段内置换 + 段间保序的因果处理足够干净。
  • 实验充分度: ⭐⭐⭐⭐ LongBench / LongBenchv2 / RULER 三套数据 + 两个主流长上下文模型 + 端到端 TTFT 测量 + 段大小/块大小/置换对象/proxy 选择多维消融;唯一可惜是缺 70B+ 量级和 decoding 阶段的探讨。
  • 写作质量: ⭐⭐⭐⭐⭐ 从 information fragmentation 的现象出发,先观察—再理论(三条 lemma + 一条 theorem)—再算法—再实验,逻辑链非常顺;Figure 1 的 coverage-density trade-off 图把核心动机讲得一目了然。
  • 价值: ⭐⭐⭐⭐⭐ training-free,plug-and-play,开源 Triton 内核,2.75× 端到端加速对长上下文推理服务有直接落地价值;而且"用 permutation 重塑稀疏结构"的思路大概率会被后续 KV cache 压缩 / decoding 加速工作复用。