Towards Long-Horizon Interpretability: Efficient and Faithful Multi-Token Attribution for Reasoning LLMs¶
会议: ICML 2026
arXiv: 2602.01914
代码: https://github.com/wbopan/flashtrace
领域: 可解释性 / LLM 推理 / Token 归因
关键词: token attribution, reasoning LLM, span-wise aggregation, recursive attribution, long-context interpretability
一句话总结¶
针对推理 LLM 长思维链场景下逐 token 归因 \(\mathcal{O}(M\cdot N)\) 慢且归因质量被中间推理 token 吸光的问题,本文提出 FlashTrace:用 span-wise 聚合一次过算完整段目标 token 的归因,再用递归归因把重要性从输出经推理链回溯到原始输入,5k 目标 span 上比最强基线 IFR 快 130 倍以上,同时在 RULER / MATH / MoreHopQA 上 faithfulness 全面占优。
研究背景与动机¶
领域现状:Token attribution 是解释 LLM 输出的主流可解释性手段,主流路线包括 perturbation-based(REAGENT/CLP)、gradient-based(Integrated Gradients)、attention+relevance 传播(IFR、AttnLRP)等。它们都假设要解释的目标是单个 token,把 context 中每个 token 对它的因果贡献算成一个分布。
现有痛点:现代 reasoning LLM(o1、DeepSeek-R1、Qwen-3)会先吐几千 token 的 chain-of-thought 再给答案,让 token 归因落到两个具体问题上: - 效率瓶颈:要解释长度 \(M\) 的输出 span,必须对每个 token 各跑一次归因,复杂度从 \(\mathcal{O}(N)\) 变成 \(\mathcal{O}(M\cdot N)\),5k 输出用 IG 要 10 小时以上、用最快的 IFR 也要 38 分钟,根本没法在 agent workflow 里用。 - 忠实度下降(information absorption):自回归模型下一个 token 直接由前一个 token 触发,所以推理 token \(\mathbf{T}\) 会吸掉绝大部分 attribution mass。文中 Figure 1 量化了这件事——CoT 一开,分到 \(\mathbf{T}\) 上的 mass 从 ~80% 涨到 >90%,而 ground-truth 输入 token 的 recovery rate 从 26% 暴跌到 <10%。最终解释只告诉你"答案是被上一句推理决定的",根本回不到 prompt 里那条真正的证据。
核心矛盾:现有方法只刻画 input→output 的直接依赖,而推理 LLM 的因果链是 \(\mathbf{I}\to\mathbf{T}\to\mathbf{O}\) 三段;既要绕过中间桥 \(\mathbf{T}\) 把重要性继续传回 \(\mathbf{I}\),又要避免对 \(\mathbf{T}\) 里每个 token 都暴力跑一遍归因。换句话说,"多 token 目标"和"多跳传播"必须同时解决,否则效率/忠实度二选一。
本文目标:定义 multi-token attribution 问题,把它拆成两个子问题——(i) 给定 span \(S\),一次算完所有源 token 对 \(S\) 的贡献;(ii) 把推理 token 吸走的 mass 沿因果链反向传回原始输入。
切入角度:作者注意到 ALTI/IFR 框架下,attention head 对单个 target 位置 \(i\) 的贡献写成 \(\mathbf{f}_{j\to i}(\mathbf{x}_j)=\alpha_{i,j}^h \cdot (\mathbf{x}_j W_V^h W_O^h)\),其中 \(\mathbf{v}_j = \mathbf{x}_j W_V^h W_O^h\) 只跟源 token 有关,与目标位置 \(i\) 解耦。把这一观察推到整段目标 span 上,"算所有源 token 对整段 span 的贡献"就可以被代数因式分解掉。
核心 idea:用 span-wise 聚合把"对整段 span 的归因"压成一次前向,再用 recursive attribution 把上一跳分到推理 token 上的分数当成下一跳的"加权 target",从而在不显著加成本的前提下,把重要性沿 \(\mathbf{O}\to\mathbf{T}\to\mathbf{I}\) 流回去。
方法详解¶
整体框架¶
输入:一段完整 context \(\mathbf{S}=\mathbf{I}\circ\mathbf{T}\circ\mathbf{O}\)(用户输入 + 模型生成的推理 + 模型最终输出),要解释的目标是输出 span \(\mathbf{O}\)。
输出:每个 context token 对 \(\mathbf{O}\) 的重要性分数 \(\mathbf{w}_{final}\),按预期能精准定位回原始输入 \(\mathbf{I}\) 里真正决定答案的几个 token。
中间流程: 1. Hop 0:把 \(\mathbf{O}\) 作为目标 span,用 span-wise 聚合一次过算出 context 上的归因分布 \(\mathbf{w}^{(0)}\),得到落在输入的部分 \(\mathbf{w}_{\mathbf{I}}^{(0)}\) 和落在推理 token 的部分 \(\mathbf{w}_{\mathbf{T}}^{(0)}\)。 2. Hop \(k\ge 1\):把 \(\mathbf{w}_{\mathbf{T}}^{(k-1)}\) 当作推理 token 的权重,构造加权目标 span,再做一次 span-wise 聚合,得到新的 \(\mathbf{w}^{(k)}\)。 3. 聚合:按"信息流"语义把所有 hop 的输入部分按"剩余 mass"折算后相加,得到最终 \(\mathbf{w}_{final}\)。实验里 \(K=1\) 已经够用。
底层度量复用 ALTI 的 L1 proximity:\(\text{Proximity}(\mathbf{z},\mathbf{y}) = \max(0, -\|\mathbf{y}-\mathbf{z}\|_1 + \|\mathbf{y}\|_1)\),衡量"去掉贡献 \(\mathbf{z}\) 后目标向量 \(\mathbf{y}\) 的范数会缩多少"。这个度量在 Transformer 高维各向异性空间里比 cosine 稳。
关键设计¶
-
Span-wise Aggregation(一次过算整段 span 的归因):
- 功能:把"\(M\) 个目标 token 各跑一次归因"压成"对整段 span 跑一次",复杂度 \(\mathcal{O}(M\cdot N)\to\mathcal{O}(N)\)。
- 核心思路:先把整段目标的层级表示求和为 \(\mathbf{Y}_S=\sum_{i\in S}\mathbf{y}_i\),再把源 token \(j\) 对整段的贡献定义为 \(\mathbf{Z}_S=\sum_{i\in S}\mathbf{z}_{j\to i}\),套同一个 L1 proximity 算分。关键是利用 attention 的线性性把 attention head 贡献 \(\alpha_{i,j}^h \cdot \mathbf{v}_j\) 中的 \(\mathbf{v}_j = \mathbf{x}_j W_V^h W_O^h\) 提出来,得到 \(\mathbf{F}_{j\to S}=\mathbf{v}_j \cdot (\sum_{i\in S}\alpha_{i,j}^h)\);昂贵的 V/O 投影只算一次,每个 target 位置只多一次标量乘加。residual 流也一并按 span 求和,整个 pipeline 内存不再随 \(M\) 涨。
- 设计动机:直接解决"5k token 输出要跑 5k 遍"的效率瓶颈。同时因为只是代数重排、没有近似,理论上保留了 ALTI/IFR 的所有忠实度性质,为后续多跳传播留出预算。
-
Recursive Attribution(沿推理链反向回溯):
- 功能:把上一跳分给推理 token 的重要性 \(\mathbf{w}_{\mathbf{T}}^{(k-1)}\) 转化成下一跳的"加权目标",从而让 mass 不停留在 \(\mathbf{T}\)、继续向 \(\mathbf{I}\) 传播。
- 核心思路:把 span-wise 聚合从"01 mask"自然推广到加权 span,新目标 \(\mathbf{Y}^{(k)}=\sum_{j\in \mathbf{T}} w_j^{(k-1)} \cdot \mathbf{y}_j\),对应贡献 \(\mathbf{Z}^{(k)}=\sum_{j\in \mathbf{T}} w_j^{(k-1)} \cdot \mathbf{z}_{k\to j}\)。因式分解仍然成立——\(\mathbf{v}_k\) 只算一次、外面再点乘标量 \(\sum_j w_j^{(k-1)}\alpha_{j,k}^h\),所以每跳成本基本等于一次前向。重要性等价为"信息流概率":每跳剩余 mass \(\rho_k=\sum_{t\in\mathbf{T}}w_t^{(k)}\) 沿链传播。
- 设计动机:直接对症 information absorption——单跳归因只能解释"答案被上一句推理决定",多跳才能回答"那句推理又是被 prompt 里哪段决定的"。设计成"加权 span 上的同一个 span-wise op"既避免了 sentence-level 切分(vs CAGE),也保住了 \(\mathcal{O}(N)\) 复杂度。
-
跨跳概率流聚合(Probability-mass Aggregation across hops):
- 功能:把 \(K\) 跳里每跳分给输入的分量 \(\mathbf{w}_{\mathbf{I}}^{(k)}\) 合成单一最终分布 \(\mathbf{w}_{final}\),让多跳结果可比、可可视化。
- 核心思路:把整个递归过程当成 mass 的逐跳分流——每跳要么"沉降"到输入、要么"剩在推理链上等下一跳解释"。聚合公式为 \(\mathbf{w}_{final}=\mathbf{w}_{\mathbf{I}}^{(0)}+\sum_{k=1}^{K}(\prod_{j=0}^{k-1}\rho_j)\cdot \mathbf{w}_{\mathbf{I}}^{(k)}\),其中 \(\rho_j\) 是第 \(j\) 跳还留在 \(\mathbf{T}\) 上的剩余 mass。实验里 \(K=1\) 即可解掉绝大部分推理链依赖。
- 设计动机:用"剩余 mass 折算"保证各跳分布在同一概率尺度上合并,不会因为某跳推理链短就把它的输入贡献放大;这也提供了一个天然的早停条件——\(\rho_k\) 很小时再迭代意义不大。
损失函数 / 训练策略¶
本方法是 training-free 的事后可解释性算法,没有训练损失。唯一超参是递归跳数 \(K\)(实验默认 \(K=1\)),不修改模型权重,对底层 Transformer 也无侵入式假设——只用前向 attention 权重 + value/output 投影即可计算。
实验关键数据¶
主实验¶
RULER 系列(多需求 Needle-in-a-Haystack mq、Variable Tracking mv、长上下文 HotpotQA),评估 Qwen-3 8B Instruct,指标 Recovery Rate↑ / RISE↓ / MAS↓:
| 数据集(任务) | 指标 | FlashTrace | 最强基线 | 提升幅度 |
|---|---|---|---|---|
| mq q4(NIAH) | Recovery Rate ↑ | 0.413 | 0.328 (IFR) | +8.5 pp |
| mv v4(Variable Tracking) | Recovery Rate ↑ | 0.516 | 0.452 (IFR) | +6.4 pp |
| HotpotQA h4 c1 | Recovery Rate ↑ | 0.755 | 0.253 (IFR) / 0.229 (AttnLRP) | +50 pp |
| HotpotQA(1024) | RISE ↓ | 0.033 | 0.074 (IFR) | −55% |
| MATH | MAS ↓ | 0.446 | 0.490 (IFR) | −9% |
| MoreHopQA | MAS ↓ | 0.205 | 0.228 (IFR) | −10% |
| Aider Code Gen | MAS ↓ | 0.173 | 0.773 (IFR per-token avg) | −78% |
效率(5k token 目标 span,RULER):FlashTrace < 20 s,IFR > 38 min,130×+ 加速;同时 IG / IG-Attn / Perturbation 在长 context 直接 OOM(Figure 4 虚线)。
消融实验¶
| 配置 | 复杂度 | 时间 (s) | RISE ↓ | MAS ↓ | 说明 |
|---|---|---|---|---|---|
| Exhaustive Token-Level Rollout(理论上的精确多跳归因) | \(\mathcal{O}(M\cdot N)\) | 11.2 | 0.116 | 0.193 | MoreHopQA 上的暴力上界 |
| FlashTrace(span-wise + 递归) | \(\mathcal{O}(N)\) | 0.72 | 0.128 | 0.205 | 时间 ↓93.6%,忠实度只掉 ~10% |
| FlashTrace, K=0(关掉递归) | — | — | — | — | Figure 1 可视化:mass 卡在 \(\mathbf{T}\) 上,recovery rate 跌到 <10% |
| FlashTrace on LLaMA-3.1-8B-It(RULER Avg) | — | — | 0.171 | 0.231 | 换模型仍优于 IFR (0.206/0.298) 和 AttnLRP (0.398/0.683) |
关键发现¶
- 递归是质变而非微调:仅 \(K=1\) 一跳,HotpotQA Recovery Rate 从 IFR/AttnLRP 的 0.13–0.25 跃到 0.51–0.76(h2/h4/h6/h10 全段),证明 information absorption 不是"小修小补"的问题,必须显式建模 \(\mathbf{O}\to\mathbf{T}\to\mathbf{I}\) 的流向。Figure 3 的 hop1−hop2 差异图直观显示:第二跳里推理 token 普遍掉分(红)、输入 token 普遍涨分(绿)。
- Span-wise 是几乎免费的近似:vs Exhaustive Rollout,FlashTrace 在 MoreHopQA 上 RISE/MAS 只退化 6–10%,运行时间从 11.2 s 缩到 0.72 s(93.6% 提速),说明"把多 token 目标聚合成单一 span"在多跳推理上是高质量近似,不需要走 sentence-level 切分(如 CAGE)那种昂贵路线。
- 跨任务跨模型稳定:在代码生成 Aider 上 MAS 从 IFR 的 ~0.78 跌到 0.17,差出近 5 倍,说明 span-wise 归因不仅适合自然语言推理链,也能处理结构化中间产物(代码 diff);换到 LLaMA-3.1-8B-It 后相对优势保留,说明这套机制不挑模型族。
- 效率曲线决定可用性:Figure 4 显示 IG/Perturbation 类方法在 ~2–4k 输入或几百 token 输出处就 OOM,IFR 虽不 OOM 但时间随目标 span 线性涨;FlashTrace 内存 / 时间对目标 span 长度几乎平直——这是它能跑在 agent workflow 里的根本原因。
亮点与洞察¶
- 代数恒等式带来的免费午餐:把 \(\mathbf{F}_{j\to S}=\mathbf{v}_j \cdot (\sum_{i\in S}\alpha_{i,j}^h)\) 提取出来这一步看似简单,却是整篇文章的杠杆点——它让"对 span 的归因"在数学上变成对单 token 归因的标量加权,从而不引入任何近似就把 \(M\) 维循环消掉。这是 attention 线性性 + ALTI proximity 度量恰好可加性的合谋。
- 把多跳归因写成概率流:用 \(\rho_k\) 当"剩余 mass"、\((\prod \rho_j)\cdot \mathbf{w}_{\mathbf{I}}^{(k)}\) 当"第 \(k\) 跳沉降到输入的概率"这套语义,使得多跳聚合既有概率意义又天然支持早停,避免了"多跳越多越好/越多越乱"的玄学调参。这种"信息流分流"的视角可以迁移到任何需要把重要性沿生成链反向传播的解释器(如 tool-calling、ReAct trace、agent 长 memory 检索)。
- 正确诊断 = 半个解法:作者先用一组干净的 Figure 1 实验把"reasoning token 吸 mass→recovery rate 跌"这件事量化出来,把模糊的"CoT 让解释变差"变成两个可测指标;随后 method 章每一个设计都明确对应一个被诊断出来的痛点。这种"先做诊断 ablation 再上 method"的结构值得借鉴。
局限与展望¶
- proximity 是相关性而非反事实:作者自己强调 L1 proximity 衡量的是"源 token 表示有多少流进 target span",是 informational contribution;虽然他们用 RISE/MAS 这种 perturbation 指标验证了它能预测因果重要性,但严格的因果归因(counterfactual / mediation analysis)仍未做。
- 递归跳数固定为 \(K=1\) 的代价:默认 \(K=1\) 已能解掉大多数推理链,但 Appendix H 才讨论 \(K>1\) 的影响。对超长推理(几十次反思、多轮 tool-call 嵌套)是否需要自适应 \(K\)、\(\rho_k\) 阈值化早停,文中没在主文给出实操指南。
- 目标 span 必须连续:方法对"连续 span"做聚合很自然,但实际 reasoning trace 里关键步骤往往是非连续的(中间夹了大量噪声 token)。把"span"换成"任意权重子集"在公式上没问题,但如何自动挑选关键非连续目标(而不是简单地选 top-k)仍是 open 问题。
- 只验证到 8B 量级:实验只覆盖 Qwen-3 8B 和 LLaMA-3.1-8B-It,没在 30B+ 或 MoE 模型上跑。考虑到 MoE 的 attention 共享与 expert 路由会改变信息流动模式,generalization 还需验证。
相关工作与启发¶
- vs AttnLRP / IFR:同样基于 attention/relevance 传播,但 AttnLRP/IFR 只能对单 token target 算分,多 token 时只能"逐个跑+平均",又慢又会被中间推理吸 mass。FlashTrace 把目标从 token 升级为加权 span,并显式做多跳,是对这条技术路线的两个正交补丁。
- vs CAGE(concurrent work):CAGE 也用递归思路追多跳,但它在 sentence-level 切分目标、跑多次完整归因,成本随推理链段数线性涨,长 context 难用。FlashTrace 在 token-span 这一更细的粒度上做加权 span,配合 span-wise 代数化降到 \(\mathcal{O}(N)\),效率优势是数量级的。
- vs Integrated Gradients / Perturbation:IG 类方法理论解释力强、有 axiom guarantee,但梯度+多步积分让它在长 context 直接 OOM;Perturbation 类需要反复前向。FlashTrace 牺牲一点"完美的因果可解性"换来工程上可用的几秒级延迟,更接近 agent 时代真实部署需求。
- vs Circuit-tracer / Transcoder:那条线探的是模型内部 feature circuit,回答"哪个内部组件负责某种行为";FlashTrace 回答"哪个外部输入 token 触发了某段输出"。两条线互补:前者是机制可解释性,后者是数据可解释性,可以叠加用——先用 FlashTrace 定位关键输入,再用 circuit-tracer 看模型内部是怎么用的。
评分¶
- 新颖性: ⭐⭐⭐⭐ 把现有 ALTI/IFR 的代数性质用足 + 递归归因构造,思路朴素但精确踩中 reasoning LLM 时代的痛点。
- 实验充分度: ⭐⭐⭐⭐ 覆盖 NIAH / VT / HotpotQA / MATH / MoreHopQA / Aider 六类任务,含跨模型、跨任务、效率/忠实度 Pareto、与暴力 Exhaustive Rollout 的逼近度对比,主线很硬;扣分点是没上 30B+ 或 reasoning RL 模型。
- 写作质量: ⭐⭐⭐⭐⭐ 诊断(Sec 3.2)→ 方法(Sec 4)→ 实验的论证链非常顺,Figure 1/3/4 都直接服务于关键 claim,公式推导用最少符号交代清楚。
- 价值: ⭐⭐⭐⭐⭐ 长 context + 长 CoT 是当前 agent 部署的常态,能在秒级算出可信归因,对 debug、prompt 优化、agent trace 审计都是直接生产力,代码开源。