跳转至

Stacked From One: Multi-Scale Self-Injection for Context Window Extension

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=lh3Aa1u7kU
代码: https://github.com/Clement25/SharedLLM
领域: LLM效率
关键词: 上下文窗口扩展, 长上下文压缩, KV共享, 上下文树, 自注入

一句话总结

SHAREDLLM 把同一个短上下文 LLM 拆成「下层压缩器 + 上层解码器」两份堆叠模型,下层把长输入压成粗到细的多粒度上下文树并只在最底部几层把 KV「自注入」给上层,于是仅用 8K 序列训练就能外推到 128K,速度比流式快 2 倍、比 encoder-decoder 快 3 倍,且性能持平或更优。

研究背景与动机

领域现状:把 LLM 的上下文窗口从几 K 扩到 128K,目前主流有三条路。一是在长语料上继续预训练并配合 RoPE 插值类位置编码(PI、YaRN),靠「train short, test long」做外推;二是 prompt 压缩,用语义 token 替换长 prompt;三是改造 transformer 做流式处理(StreamingLLM、Activation Beacon),维持一个常数大小的滑动窗口记忆。

现有痛点:第一条路代价高得离谱——YaRN 要做到 128K,得在 64K 长度上预训练,数据获取和算力成本都难以承受。Prompt 压缩只能加速推理,并不能真正扩展窗口,适用场景也窄。流式方法虽然把内存压成常数,但它们的专用注意力模式往往和 FlashAttention 这类高性能实现不兼容,输入越长反而越慢。还有一类 encoder-decoder 思路(CEPE)把过去上下文喂给一个独立 encoder(如 24 层 RoBERTa),但 encoder 与 decoder 的隐空间不一致,必须额外加一个 warmup 阶段对齐,训练链路又长又重。

核心矛盾:扩展上下文窗口本质是在「效率」和「性能」之间拉扯。要省内存就得压缩或流式,但压缩/流式要么丢信息、要么破坏注意力的硬件友好性;要保性能就得喂全量 token,又回到二次复杂度。已有方法没能同时拿下两端。

本文目标:在不做昂贵长序列预训练、不引入异构编码器的前提下,既把长输入压到可控内存,又保住下游长上下文任务的性能,还要跑得快。

切入角度:作者观察到一个被忽视的事实——如果压缩器和解码器干脆用同一个 LLM 的权重初始化,两者隐空间天然一致,就不再需要 warmup 对齐;再加上长文本里「任务相关信息分布不均」(摘要看主题句、找密钥看细节),可以用一棵树做粗到细的多粒度压缩,让模型按 query 自适应地决定哪里该细、哪里粗一笔带过。

核心 idea:把一个短上下文 LLM「堆成两份」——下层当压缩器把长上下文编成多粒度上下文树、上层当解码器,二者只在最底部 \(M\) 层通过共享 KV 做「自注入(self-injection)」,用极少可训练参数实现长上下文外推。

方法详解

整体框架

SHAREDLLM 要解决的是「怎么把超长上下文塞进一个短窗口 LLM 还不掉性能」。它的做法是把输入序列 \(X\) 切成两段 \(X=\text{concat}([X_C; X_D])\):过去上下文 \(X_C\) 交给下层模型(压缩器),当前运行文本 \(X_D\)(如问题)交给上层模型(解码器)。下层把 \(X_C\) 再切成 \(n\) 个不重叠 chunk \(\{C_i\}\),每个 chunk 并行地编成一棵上下文树,按 query 动态展开相关节点、对保留节点的 KV 做层级下采样,得到高度压缩的多粒度 KV。这些 KV 只在上层的最底部 \(M\) 层经由 cross-attention 注入,上层据此自回归生成。整套流程把自注意力复杂度从 \(O(T^2)\) 降到 \(O(n\cdot(T/n)^2 + T_D\cdot|S'|)\),其中 \(|S'|\) 是压缩后的 KV 长度。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["输入序列 X<br/>过去上下文 X_C + 运行文本 X_D"] --> B["切成 n 个 chunk<br/>并行送入下层模型"]
    B --> C["共享模型·自注入<br/>上下层同一 LLM 权重"]
    C --> D["多粒度上下文树<br/>粗到细 KV 压缩"]
    D -->|query 驱动| E["查询相关动态构建与搜索<br/>只展开相关节点"]
    E --> F["位置感知 Tree Cross-attention<br/>仅注入底部 M 层"]
    F --> G["上层解码器自回归生成"]

关键设计

1. 自注入:上下层共用同一个 LLM,只在底部 M 层交换 KV

最直击痛点的设计是「自注入(self-injection)」——下层压缩器就是目标 LLM 的前 \(M\) 个浅层,上层解码器是同一个 LLM 的全层版本,两者从同一个现成 checkpoint 初始化。这一刀解决的是 CEPE 那种 encoder-decoder 架构的隐空间错位问题:因为权重同源,下层产出的 KV 和上层的隐空间本就在同一个语义坐标系里,无需任何 warmup 对齐阶段就能直接微调,训练成本大幅下降。

更关键的是信息只在最底部 \(M\)注入(语言建模设 \(M=4\),SFT 设 \(M=16\))。下层只跑前 \(M\) 层就停,省掉了完整 forward 的漫长前传;上层也只在底部 \(M\) 层插入 cross-attention,绕过了冗余的层间投影。这条「浅层压缩 + 浅层注入」的短路径让 SHAREDLLM 既不像 CEPE 要把 chunk 过完 24 层 RoBERTa,也不像流式方法破坏注意力实现,从而拿到 2× / 3× 的提速。消融里把注入层从「连续底部」换成「连续顶部」或「交错」,性能都明显变差,且底部方案前传/反传路径最短、效率最高。

2. 多粒度上下文树:用一棵二叉树做粗到细的 KV 压缩

针对「长文本里任务相关信息分布不均」这个观察,作者为每个 chunk 建一棵二叉上下文树。根节点装整个 chunk,非叶节点 \(\{x_{u+1},...,x_{u+l}\}\) 按式 \(C_{\text{left}}=\{x_{u+k}\}_{k=1}^{b},\ C_{\text{right}}=\{x_{u+k}\}_{k=b+1}^{l}\) 一分为二,切点 \(b=\lfloor l/2-\epsilon\rfloor\),其中 \(\epsilon\sim N(0,\sigma^2)\) 是随机扰动。这个噪声有两重作用:训练时充当结构化数据增强,防止模型过拟合到固定切分位置;同时降低硬切语义边界(如把一个实体名拦腰斩断)的风险,逼模型学到对边界鲁棒的表示。测试时 \(\epsilon\) 固定为 0 做确定性切分。

树建好后对每个保留节点取出全 \(M\) 层 KV \(S=\{K,V\}\in\mathbb{R}^{M\times l\times d}\),沿长度维做均匀下采样(用分数步长选择保证等距保留,不引入额外池化参数)得到 \(S'=\{K',V'\}\),单节点压缩率 \(\alpha=l/l'\)。妙处在于层级递减的压缩率:第 \(w\) 层用统一压缩率 \(\alpha_w\),且自上而下逐层衰减 \(\alpha_w=2\alpha_{w+1}\)。这就造出粗到细的语义分布——高层节点子序列长、压得狠,只保留粗粒度信息;底层节点压得轻,保留细粒度细节。整棵树的全局压缩率 \(\beta=\frac{|C|}{\sum_w l'_w n_w}\)\(n_w\) 是第 \(w\) 层保留节点数),实验表明 \(\beta\) 可高达 8 而性能不崩。

3. 查询相关的动态树构建与搜索:只为相关节点花算力

如果对每个 chunk 都建满整棵静态树,GPU 内存和时间都浪费在与 query 无关的节点上。于是作者改成查询驱动的动态树:从根开始做深度优先的「边切边选」,每个节点先按式 (1) 切成左右子序列,再用一个非参数策略 \(\pi((\vec{x}_{\text{left}},\vec{x}_{\text{right}}),\vec{y})\to\text{left or right}\) 决定继续展开哪个孩子,未选中的兄弟节点标记为「保留」、不再展开(根节点恒被选中)。

策略 \(\pi\) 是任务相关的。语言建模(非 SFT)时没有显式 query,固定选右分支 \(\pi\equiv\text{right}\) 来模拟有用的 \(\Lambda\) 形注意力模式;指令跟随(SFT)时 query 显式可得,则选与 query 语义相似度更高的节点:\(\pi=\arg\max_{\phi\in\{\text{left},\text{right}\}}\text{sim}(\vec{h}_{\vec{x}_\phi},\vec{h}_{\vec{y}})\),相似度取末位 token 隐向量的余弦相似度,\(\vec{h}_{\vec{x}_\phi}\) 由下层一层自注意力前传得到、\(\vec{h}_{\vec{y}}\) 由上层得到。一路递归到叶节点,左右孩子都标「保留」。这样只对真正相关的路径做细粒度展开,无关节点停在粗粒度,省下大量内存和时间。消融显示,去掉这个 query-aware 机制在 query 驱动任务上掉点最多,是三项辅助设计里最关键的一个。

4. 位置感知的 Tree Cross-attention:把打乱的 chunk KV 重新按时序对齐

下层并行处理各 chunk,产出的 KV 序列其实是「打乱」的,若直接喂给上层,会丢失原文的全局时序。作者在 cross-attention 里给 query \(Q\) 和压缩后的 key \(K\) 赋 chunk 级位置索引:\(P_Q\) 全取最大值 \(n\)(因为 \(Q\) 来自排在所有上下文之后的 \(X_D\)),\(P_K\) 则按 chunk 顺序给每个 chunk 的压缩 KV 块依次赋 \(0,1,...,n-1\),再按这些块索引施加 RoPE,让 cross-attention 自然尊重 query 与各压缩 chunk 之间的相对距离。最终输出以残差方式融入上层自注意力状态:\(O=\text{cross\_attn}(Q,\text{concat}([K'_1;...;K'_n]),\text{concat}([V'_1;...;V'_n]))\)。消融里去掉 chunk 位置索引(w/o chunk pid)性能下降,印证了「解码器必须知道 chunk 时序」这一点。

损失函数 / 训练策略

训练用标准语言建模损失 \(L=-\sum_{x_t\in X_{\text{tar}}}\log P(x_t\mid X_C; x_{<t})\)。语言建模数据里 \(X_{\text{tar}}=X_D\)(除首 token 外的全部运行文本);指令跟随数据里 \(X_D\) 含指令 \(X_{\text{inst}}\) 与标注回复 \(X_{\text{res}}\),此时设 \(X_{\text{tar}}=X_{\text{res}}\),即只对回复 token 计损失、指令文本被 mask。cross-attention 层全程可训,语言建模阶段额外训练上层顶部 \(N-M\) 个自注意力层做 post-injection 聚合以加速收敛。数据用 RedPajama 采样 20B(1%)token 截断到 8192,8× A800 训练。

实验关键数据

主实验

语言建模困惑度(continual pretraining 设置,LLaMA-2 基座,越低越好):

数据集 长度 SHAREDLLM CEPE YaRN
Arxiv 32K 2.46 2.51 2.58
Arxiv 128K 2.91 2.97 OOM
PG19 128K 5.96 6.10 OOM
ProofPile 128K 2.40 2.39 OOM

仅在 8K 训练却能在 128K 不发生困惑度爆炸,外推能力强;除 ProofPile-128K 外几乎全面优于 CEPE,且无需 CEPE 的额外预训练 + warmup。

长上下文理解(SFT,LongBench 14 任务 + InfBench 三任务,LLaMA-2 基座):

方法 MD-QA Summ. Math.F Ret.N
Activation Beacon 28.44 25.15 12.14 80.58
LongAlpaca-16K 28.10 27.80 6.23 4.87
SHAREDLLM 30.93 25.76 13.82 82.79

五大类任务上持平或超过强基线,在极长输入的数值检索(Ret.N)等任务优势明显。

消融实验

配置 arxiv (ppl↓) MD-QA (F1↑) 说明
Default(底部注入) 2.46 30.93 完整模型
Continuous Top 2.61 28.66 改成顶部注入
Interleaving 2.57 29.15 交错注入
w/o query-aware 29.27 去掉 query 相关展开
w/o noise 2.51 30.08 去掉切分噪声
w/o chunk pid 2.49 29.81 去掉 chunk 位置索引

关键发现

  • 底部注入最优:连续底部注入既性能最好,又因前传/反传路径最短而效率最高,顶部/交错都更差。
  • query-aware 贡献最大:在 query 驱动任务上去掉它掉点最多,是三项辅助设计里最关键的。
  • 效率优势显著:内存近常数,推理速度比流式(Activation Beacon)快约 2×、比 encoder-decoder(CEPE)快约 3×;YaRN 因 \(O(L^2)\) 复杂度在 128K 直接 OOM。
  • 超参敏感区:树高 < 3 或压缩率 < 8 时趋势不稳定——树太矮则欠分割只剩粗信息、丢细节,树太高则过分割叶子只剩琐碎细节、丢全局视角;压缩率取 8 是性能与效率的甜点。

亮点与洞察

  • 「自注入」是最优雅的一笔:同一个 LLM 既当压缩器又当解码器,用权重同源直接消除了 encoder-decoder 的隐空间对齐难题,省掉 warmup,这是它比 CEPE 训练更省、推理更快的根因。
  • 多粒度树 + 层级递减压缩率把「信息分布不均」这个直觉变成了可计算的结构:高层粗、底层细,再叠加 query 驱动的动态展开,等于让模型按需分配压缩预算。
  • 只在底部 M 层交换 KV 是兼顾效率与硬件友好的关键——既缩短计算路径,又不引入像流式方法那样和 FlashAttention 不兼容的专用注意力。
  • 这套「同模型堆叠 + 浅层 KV 注入」的思路可迁移到其他需要长输入的场景(检索增强、长文档摘要),核心可复用 trick 是「用同源权重避免跨模块对齐」。

局限与展望

  • 训练集因版权问题剔除了 books3 子集,作者在附录单独分析其影响,但这意味着主结果的语料覆盖与早期工作不完全可比。
  • 性能对树高和压缩率较敏感(树高 < 3、压缩率 < 8 时不稳定),需要按任务调参才能稳定到甜点区。
  • 策略 \(\pi\) 在语言建模时简单固定选右分支来模拟 \(\Lambda\) 形模式,这是个较粗的启发式,对非 \(\Lambda\) 形注意力分布的任务未必最优。
  • 部分基线按其默认设置做了中间截断,可能反而降低了任务难度、抬高其分数,横向比较需带这个 caveat(原文亦指出)。

相关工作与启发

  • vs CEPE(encoder-decoder):CEPE 用独立的 RoBERTa encoder,chunk 要过完 24 层 + 层间线性投影,隐空间还要 warmup 对齐;SHAREDLLM 用同源浅层做压缩、无需对齐,故训练更省、推理快约 3×。
  • vs Activation Beacon / StreamingLLM(流式):流式维持常数滑动窗口但专用注意力与 FlashAttention 不兼容,输入越长越慢;SHAREDLLM 用标准注意力 + 浅层注入,速度约 2× 且性能更优。
  • vs YaRN / PI(位置编码外推):它们靠 RoPE 重缩放做外推,但仍是 \(O(L^2)\) 全注意力,128K 直接 OOM;SHAREDLLM 靠层级压缩把 KV 长度压到 \(|S'|\),避免内存爆炸。

评分

  • 新颖性: ⭐⭐⭐⭐ 「自注入 + 多粒度上下文树」组合点子巧,把权重同源用到了刀刃上。
  • 实验充分度: ⭐⭐⭐⭐⭐ 覆盖语言建模/理解/效率/消融,多基座(LLaMA-2/3、Mistral)多基线对比扎实。
  • 写作质量: ⭐⭐⭐⭐ 方法叙述清晰、动机问题驱动,个别符号与图示偏密。
  • 价值: ⭐⭐⭐⭐ 仅 8K 训练外推到 128K 且 2×/3× 提速,对低成本长上下文落地很实用。