SRA: Span Representation Alignment for Large Language Model Distillation¶
会议: ACL 2026
arXiv: 2605.01205
代码: 论文未提供公开仓库
领域: 模型压缩 / 知识蒸馏 / Cross-Tokenizer Distillation
关键词: 跨分词器蒸馏、span alignment、center of mass、几何正则、LLM压缩
一句话总结¶
SRA 把跨分词器 LLM 蒸馏的基本对齐单元从易碎的 token 换成 tokenizer-agnostic 的文本 span,通过 LCS 字符偏移匹配、注意力加权 center-of-mass 表示、几何结构正则和共享词表 span logit 蒸馏,在多组 teacher-student 压缩实验中稳定超过 ULD、MinED、DSKD 和 MultiLevelOT。
研究背景与动机¶
领域现状:知识蒸馏是把大语言模型能力转移到小模型的常用压缩技术。传统 KD 往往假设 teacher 和 student 使用同一个 tokenizer,可以直接对齐 token 或 logit 分布;但真实部署中,不同模型家族常常使用不同词表和切分规则。
现有痛点:Cross-Tokenizer Knowledge Distillation 需要跨 tokenizer 对齐。已有方法要么用编辑距离、动态规划或 OT 处理 token 序列,要么把不同词表映射到统一空间,但 token 级别对齐容易受到分词粒度差异影响:同一段文本在 teacher 中可能是一个 token,在 student 中可能被切成多个 token。
核心矛盾:蒸馏想转移的是语义和表示动态,而 tokenizer mismatch 使 token 序列不再是一一对应的稳定单位。直接对齐 token 会把“分词差异”误当成“知识差异”。
本文目标:作者希望构造一个跨分词器也稳定的蒸馏单元,使 teacher 和 student 可以在相同文本 span 上对齐 hidden states、几何结构和预测分布。
切入角度:论文借用 Transformer 作为 Multi-Particle Dynamical System 的物理视角:token hidden states 像粒子位置,span 可以看作粒子簇,span 表示则是 attention-weighted center of mass。
核心 idea:先用字符偏移找到 teacher 和 student 都覆盖的文本 span,再用注意力加权聚合成 span 表示,并在 span 级别做 hidden-state、几何结构和 logit 蒸馏。
方法详解¶
SRA 的设计可以理解为“先找共同语义单位,再转移表示动态”。它避免直接在不同 tokenizer 的 token 序列之间硬匹配,而是回到原始字符串,用字符 offset 找到两边都能解释的 span。然后,SRA 不只让 student 的 span 向 teacher 的 span 靠近,还要求 span 与 span 之间的相对几何关系尽量保持。
整体框架¶
给定同一句文本,teacher tokenizer 和 student tokenizer 分别输出 token 序列和字符 offset。SRA 用 offset 序列的最长公共子序列构造对齐 span。对每个 span,模型从最后一层 hidden states 中通过注意力加权 pooling 得到 span representation。训练时,student 同时优化标准 CE、span hidden-state loss、几何结构正则和 span-level logit KD loss。
关键设计¶
-
基于 LCS 的 span mapping:
- 功能:在不同 tokenizer 之间建立可比的文本片段单位。
- 核心思路:对 teacher 和 student 的 token offset 序列计算 LCS,匹配相同字符边界形成 span pair,并忽略 offset 为 0 的特殊 token。这样得到的 span 覆盖原始文本中的共同子片段,而不强迫 token 数一致。
- 设计动机:token-level alignment 在跨 tokenizer 场景下很脆弱;字符 span 是原始文本层面的稳定单位,更适合作为知识转移载体。
-
注意力加权 Center-of-Mass span 表示:
- 功能:把一个 span 内多个 token 的 hidden states 聚合成单个语义表示。
- 核心思路:SRA 使用最后 token 对各 token 的注意力作为 token 重要性,归一化后计算 span 的加权平均。形式上,span representation 类似 \(C_i=\sum_{t=s_i}^{e_i} w_t H_t\),其中 \(w_t\) 来自最后层多头注意力聚合。
- 设计动机:简单 mean pooling 会稀释关键信息。CoM 类比强调“质量更大的粒子”更影响整体中心,对应到文本中就是更受关注的 token 对 span 表示贡献更大。
-
span-level hidden/logit 蒸馏与几何正则:
- 功能:让 student 同时学习 teacher 的局部 span 表示、span 间相对结构和共享词表预测分布。
- 核心思路:hidden-state loss 用 weighted cosine 对齐 teacher 与 student span 表示,并加入几何正则 \(L_{Geo}\) 保持 span 间 cosine distance;logit loss 将 teacher 和 student 的 span logits 投影到共享词表子空间 \(V_T\cap V_S\) 后做 KL 蒸馏。
- 设计动机:仅对齐点的位置可能被线性投影扭曲,几何正则能保留表示空间结构;仅对齐 hidden states 又可能缺少词汇预测知识,共享词表 logit loss 提供互补监督。
损失函数 / 训练策略¶
总目标为 \(L_{overall}=\alpha L_{CE}+(1-\alpha)(L_{HS}^{Span}+L_{KD}^{Span})\)。其中 \(L_{HS}^{Span}\) 包含加权 cosine loss 和几何结构正则,\(L_{KD}^{Span}\) 在共享词表空间对齐 span logits。训练数据使用 Databricks-Dolly-15k,评估覆盖 Dolly、VicunaEval、SelfInst、S-NI 和 DialogSum,指标为 ROUGE-L,并对结果取 5 个随机种子的平均值。
实验关键数据¶
主实验¶
| Teacher → Student | 最强非SRA基线 Avg | SRA Avg | 主要观察 |
|---|---|---|---|
| Qwen1.5-1.8B → GPT-2 120M | DSKD 15.35 | 17.97 | 小学生模型上提升最明显 |
| Qwen1.5-1.8B → GPT-2 340M | DSKD 15.57 | 18.10 | S-NI 从 17.18 提到 24.49 |
| Qwen2.5-7B → GPT-2 1.5B | DSKD 19.27 | 20.99 | 大 teacher 到 GPT-2 仍有效 |
| Qwen2.5-7B → OPT-2.7B | DSKD 20.15 | 20.92 | OPT 学生上保持领先 |
| Mistral-7B → TinyLLaMA-1.1B | DSKD 21.33 | 22.52 | 跨架构、跨词表仍稳健 |
| GPT-2 1.5B → GPT-2 120M | AKL 17.03 | 19.24 | 同 tokenizer 场景也能受益 |
消融实验¶
| 配置 | Qwen1.5→GPT-2 340M Avg | Qwen1.5→GPT-2 120M Avg | 说明 |
|---|---|---|---|
| 仅 span logit KD | 17.36 | 17.10 | 共享词表蒸馏已有收益 |
| span logit KD + 几何正则 | 17.94 | 17.72 | 几何结构保持带来稳定增益 |
| span logit KD + cosine | 17.54 | 17.32 | 表示点对齐有帮助但不如几何充分 |
| cosine + 几何正则 | 17.48 | 16.04 | 缺少 logit KD 时不够稳 |
| 完整 SRA | 18.10 | 17.97 | 三类信号互补效果最好 |
| WSL / WSP 配置 | GPT-2 340M Avg | GPT-2 120M Avg | 说明 |
|---|---|---|---|
| 去掉 WSL 与 WSP | 16.99 | 14.85 | span 表示质量明显下降 |
| 仅 WSL | 17.11 | 15.77 | 加权损失有一定帮助 |
| 仅 WSP | 17.36 | 15.89 | 加权 pooling 比 mean pooling 更重要 |
| WSL + WSP | 18.10 | 17.97 | 显示 span 权重设计是核心组件 |
关键发现¶
- SRA 在所有 teacher-student 配置上都取得最高平均 ROUGE-L,说明 span-level 对齐对跨 tokenizer 蒸馏是稳定收益,而不是某个模型对的偶然现象。
- 几何正则和注意力加权不是装饰项:去掉 WSP 或 WSL 都会掉点,尤其在 GPT-2 120M 这类小学生模型上更明显。
- 训练效率表显示 SRA 单步时间为 0.2754s,快于 DSKD 0.3520s、MinED 0.4244s 和 ULD 0.4393s;代价是显存 21.96GB,略高于 MinED/ULD 的 19.63GB。
亮点与洞察¶
- 最巧妙的是把 tokenizer mismatch 从离散 token 对齐问题转成连续 span 表示问题。span 来自原始字符边界,天然比 token 更接近语义单位。
- Multi-Particle / Center-of-Mass 类比虽然听起来偏理论,但落到实现上就是“用注意力重要性做 span pooling + 保留相对几何结构”,可操作性很强。
- SRA 不只适用于不同 tokenizer,同 tokenizer 蒸馏也有收益,说明它捕捉的不只是词表重叠问题,还包括 teacher-student 表示空间结构差异。
局限与展望¶
- 当前 logit mapping 是静态的,只在共享词表子空间中对齐,可能忽略非共享词表中携带的细粒度知识。
- span 表示对齐需要在线 teacher inference,若想预计算所有 teacher span embeddings,存储成本会非常高。
- 实验受计算预算限制,主要集中在固定 benchmark 和 decoder-to-decoder 设置;embedding 模型、encoder-decoder 模型和更长上下文任务还需要验证。
- LCS 基于字符 offset,面对语言混排、复杂 Unicode 切分或高度形态变化语言时,span 匹配质量可能成为瓶颈。
相关工作与启发¶
- vs ULD / MinED: 这些方法更关注 token 或编辑距离层面的对齐,SRA 退回到文本 span 层面,减少 tokenizer 粒度差异带来的噪声。
- vs DSKD: DSKD 通过统一空间做分布对齐,SRA 同时对齐 hidden geometry 和 span logits,知识通道更丰富。
- vs MultiLevelOT: OT 方法能处理分布匹配但计算和对齐复杂度较高,SRA 的 LCS + span pooling 更轻量,实验中单步时间更短。
- vs 同 tokenizer KD 方法: SeqKD、RKL、JS、SKL、AKL 等假设词表一致,SRA 在同 tokenizer 下仍能提升,提示 span 几何蒸馏可以作为通用 KD 组件。
评分¶
- 新颖性: ⭐⭐⭐⭐☆ span-level CoM 表示和几何正则组合很有辨识度,物理视角也带来清晰设计动机。
- 实验充分度: ⭐⭐⭐⭐☆ teacher-student 组合、同/跨 tokenizer、消融和效率都覆盖较好;更大规模和更多任务还可扩展。
- 写作质量: ⭐⭐⭐⭐☆ 方法链条完整,公式较多但与实现对应清晰。
- 价值: ⭐⭐⭐⭐⭐ 对跨模型家族蒸馏和小模型部署非常实用,尤其适合 tokenizer 不一致的真实压缩场景。