跳转至

Dispersion Loss Counteracts Embedding Condensation and Improves Generalization in Small Language Models

会议: ICML 2026
arXiv: 2602.00217
代码: https://github.com/KrishnaswamyLab/LM-Dispersion
领域: 模型压缩 / 表征学习 / 小模型训练
关键词: 嵌入坍缩, 分散损失, 小模型泛化, 知识蒸馏, GPT2 / Qwen3

一句话总结

本文系统观测到 "小语言模型的 token 嵌入会随深度坍缩到一个窄锥体"(embedding condensation)这个普遍现象——大模型反而不会——并设计了一个角度分散损失 \(\mathcal{L}_{\text{disp}}\) 直接逼嵌入散开,无须加参数就让 Qwen3 / GPT2 在 10 个 benchmark 上平均提升 3.3%。

研究背景与动机

领域现状:LLM 的能力随规模 scaling 持续提升,但训练 / 部署成本飞涨,迫切需要 "用小模型复现大模型的关键性质"。已有压缩路线主要是蒸馏、量化、剪枝,都侧重模仿大模型的输出分布。

现有痛点:作者从 representation geometry 视角发现,小模型(GPT2-small、Qwen3-0.6B)的 token 嵌入在深层会几乎全部对齐到同一方向,pairwise cosine similarity 趋近 1;大模型(GPT2-xl、Qwen3-32B)则维持嵌入分散。Geshkovski 2025 的理论也证明 Transformer 在层数趋无穷时嵌入会坍缩成一个点,但经验上没人系统验证过和性能的关系。

核心矛盾:嵌入坍缩意味着模型实际可用的 "表征方向" 越来越少,表达能力被几何上锁死。即使蒸馏从大教师学到 logit 分布,也无法把大模型的几何性质继承下来——因为蒸馏目标只约束输出,不约束中间嵌入。

本文目标:(1) 把 embedding condensation 现象量化测量,确认 "大模型抗坍缩" 是普适规律;(2) 验证蒸馏不能缓解;(3) 设计直接作用于几何的辅助损失,让小模型主动散开嵌入。

切入角度:既然大模型 "自动" 维持分散,那分散本身可能就是性能的瓶颈条件。与其堆参数让模型 "自然" 分散,不如显式加一个目标函数强制分散。

核心 idea:用基于角度的分散损失 \(\mathcal{L}_{\text{disp}} = \log \sum_{i \neq j} \exp(-\arccos(\cos\text{sim}(z_i, z_j)) / \pi\tau)\) 把所有 token 嵌入推向单位超球面上的均匀分布,零额外参数。

方法详解

整体框架

方法分两阶段:(1) 诊断阶段——量化嵌入坍缩;(2) 干预阶段——把分散损失加进训练。诊断用 Spearman \(\rho\) 和 Kendall \(\tau\) 衡量 layer-wise 平均余弦相似度的单调上升程度。干预把 \(\mathcal{L}_{\text{disp}}\) 当正则项加到原训练目标:\(\mathcal{L} = \mathcal{L}_{\text{train}} + \lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\)

关键设计

  1. 角度分散损失(核心 dispersion loss):

    • 功能:在单位超球面上把所有 token 嵌入推开。
    • 核心思路:对每层每个 token pair \((z_i, z_j)\),先把余弦相似度映射到角度距离 \(D(z_i, z_j) = \arccos(\cos\text{sim}(z_i, z_j)) / \pi \in [0, 1]\);然后用 log-sum-exp 聚合 \(\mathcal{L}_{\text{disp}} = \log \sum_{i \neq j} \exp(-D(z_i, z_j)/\tau)\)。距离小(同向)时 \(\exp\) 项大、损失大,被推开;距离大(接近正交)时几乎为零。所有层的损失求和,每 batch 复杂度 \(\mathcal{O}(N^2 F)\),可通过 token 维度子采样减负。
    • 设计动机:(1) 用 \(\arccos\) 而非裸 cosine 是为了数值稳定,避免在两端饱和;(2) log-sum-exp 比 mean 在数值上更稳健,差一个加性常数不影响梯度;(3) 显式排除对角项防止 \(z_i\) 自相似爆梯度;(4) 选择 angular 而非欧式距离是因为坍缩本质是方向问题不是长度问题。
  2. 三个备选公式(消融用):

    • 功能:验证 "分散" 这一目标的不同实现方式是否都有效,从而隔离 "角度分散" 与其他变体的优劣。
    • 核心思路:(a) Decorrelation——最小化嵌入协方差矩阵的非对角元,间接降低不同特征维度间的耦合;(b) \(\ell_2\)-repel——直接拉大 token 间的欧式距离,但要配 norm 正则 \(\lambda_{\text{norm}} \|\mathcal{Z}\|_2^2\) 防止靠膨胀 norm 作弊;(c) Orthogonalization——铰链式损失 \(\max(0, 1/2 - D(z_i, z_j))^2\),只惩罚距离 < 1/2(锐角对),让钝角对自由生长。
    • 设计动机:dispersion 本身是个抽象诉求,作者通过四种实现的对比想说明 "在角度空间均匀分散" 比 "在特征维度上去相关" 或 "在欧氏空间排斥" 更直接、更有效,强化主方法的合理性。
  3. 覆盖 mid-training + full pre-training 的应用策略:

    • 功能:把分散损失嵌入两种实际训练流程,证明既能 retrofit 既有模型也能从零训练。
    • 核心思路:Mid-training 阶段——拿现成的 GPT2 / Qwen3 在 wikitext-103 上继续训 200M token,单 A100 即可跑完;Full pre-training 阶段——Qwen3 在 C4 上从零训 156B token,用 640 GPU。两种场景下都把 \(\lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\) 加到 cross-entropy 上,loss 在每个 forward 同时计算多层嵌入的 dispersion 并加权。
    • 设计动机:mid-training 是低成本的 proof-of-concept 和超参扫描场所;full pre-training 验证从零开始时该信号能塑造更好的几何,从根本上改变模型容量。

损失函数 / 训练策略

训练目标 \(\mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\)。温度 \(\tau\) 与权重 \(\lambda_{\text{disp}}\) 是主要超参,论文 Appendix 给了扫描结果。Mid-training 用 3 个 seed 报均值方差,full pre-training 跑单 seed 但 token 量足够稳定。

实验关键数据

主实验

GPT2 mid-training(10 个 benchmark 平均):

配置 训练成本 平均分↑ rank↓ 显著性
GPT2 原版(无 mid-training) 34.35 6.1 p<0.0001
+ \(\mathcal{L}_{\text{CE}}\) only 1.122 A100h 34.95 6.2 p<0.01
+ noisy embedding 1.122 35.15 4.3 p<0.01
+ active forgetting 1.127 35.36 3.2 n.s.
+ Dispersion loss 1.13 (1.01×) 35.52+ 最佳

Qwen3 full pre-training(156B token from scratch):加 dispersion loss 平均提升 +1.17 分(3.3% 相对增益),所有 benchmark 上稳定提升。

消融实验

四种 dispersion 变体对比:

变体 平均分 备注
Decorrelation 35.1 间接,受 feature dim 影响
\(\ell_2\)-repel 35.0 需 norm 正则才稳定
Orthogonalization 35.2 只惩罚锐角
Dispersion (canonical) 35.5+ 角度均匀分散,最优

规模对照(confounder-controlled):从零训四个 GPT2-like 模型,只改 MLP 维度其余固定,larger MLP → 更小 condensation,验证 "大模型抗坍缩" 不是其他因素带来的伪相关。

关键发现

  • 蒸馏救不了坍缩:Qwen2.5 系列蒸馏后嵌入几何与从零训几乎一样,因为 KD loss 只约束输出 logits,不规范中间表征——这成为该论文最直观的动机证据。
  • 坍缩从初始化就有,但训练能缓解:Olmo-3-7B 的 checkpoint 显示 condensation 指标在初始化时正且大,随训练降到负值,说明 SGD 本身就在抵抗坍缩,分散损失只是加速 + 强化。
  • 小模型增益更大:Qwen3-0.6B 提升最明显,Qwen3-32B 几乎无收益,符合 "大模型本来就分散" 的假说。
  • mid-training 即可生效:不需要重训,对已有 checkpoint 加 200M token + dispersion 即获明显收益,落地成本极低。
  • 代价 < 1% 训练时间:1.13 vs 1.122 A100h,因为 \(N^2\) 配对可子采样。

亮点与洞察

  • "小模型的瓶颈是几何不是参数":把性能差距归因于表征几何而非容量本身,是该论文最有想象力的命题——意味着可以在不加参数的前提下逼近大模型性能上限。
  • 角度而非欧氏距离:用 \(\arccos\) 把 cosine 映成均匀的角度距离,避开 cosine 在两端的饱和,是工程上很关键的稳定化 trick。
  • 理论 → 经验 → 干预 → 验证 闭环:从 Geshkovski 2025 的理论坍缩定理出发,提供大规模经验证据,设计干预,再用 confounder 控制实验证实,整篇文章的论证链非常完整。
  • 零参数增量:与剪枝、量化、LoRA 等需要改架构的方法不同,dispersion loss 只是训练时附加项,可与任何主流 LM 训练 pipeline 即插即用。

局限与展望

  • 未扩展到对齐 / 推理任务:实验全在 zero/few-shot 通用 NLU,未测 reasoning / math / code 等需要更复杂表征几何的任务。
  • 大模型几乎无增益:32B 上没看到提升,分散假说是否对大模型成立未确认;也许大模型的瓶颈在别处。
  • token 子采样的代价:实际跑大模型时 \(N^2\) 还是要子采样,子采样如何影响收敛和最终性能没仔细消融。
  • 未与 anti-collapse 技术(如 SimSiam stop-gradient、BarlowTwins)比较:自监督表征学习里有大量去坍缩方法,应横向比对。
  • 未解释为何 \(\lambda_{\text{disp}}\) 不需要复杂调度:直觉上训练早期需要更强分散、后期可减弱,但论文用固定权重就 work,这是个有趣的现象但未深入分析。
  • 架构敏感性:在 RMSNorm(无 affine LN)、不同位置编码下 dispersion 是否仍生效未测。

相关工作与启发

  • vs Wang & He 2025(diffusion 中的 dispersion): 那篇把 dispersion 用在生成模型上,本文把它迁移到语言模型并改成角度形式 + 显式去对角,是同一思路的领域适配。
  • vs noisy embedding / active forgetting: 这些 trick 也试图增加表征多样性,但都是间接手段且无几何解释;dispersion 直接、可解释,效果也更好。
  • vs 蒸馏类压缩:蒸馏只能传输输出行为,dispersion 传输的是表征几何这一更基本的性质;二者可叠加。
  • vs Cai 2021、Bis 2021(isotropy 研究): 都关注嵌入空间的各向同性 / 异性,本文给出了第一个明确的训练目标使之可控。
  • 启发:这个思路可以推广到 (1) vision encoder 的 patch token 嵌入;(2) 多模态对齐时不同模态嵌入之间的分散;(3) MoE 专家激活模式的分散,避免专家坍缩。

评分

  • 新颖性: ⭐⭐⭐⭐ 把理论坍缩定理转化为可训练的辅助损失,并配套提出 4 种变体,属于 "已知现象 + 干净干预" 类新颖性,不是石破天惊但完整可信。
  • 实验充分度: ⭐⭐⭐⭐⭐ mid-training + full pre-training(156B token / 640 GPU)双重验证,10 个 benchmark + 3 seed + confounder 控制 + 4 种 dispersion 变体消融,少见的扎实。
  • 写作质量: ⭐⭐⭐⭐⭐ 论证链 "理论 → 经验观测 → 反例(蒸馏无效)→ 干预 → 验证" 非常顺,图表清晰,把抽象的几何现象讲得像故事一样。
  • 价值: ⭐⭐⭐⭐ 几乎零成本即插即用,对小模型训练社区直接受益;只是大模型上无收益、未测推理 / 代码任务,长期影响力有上限。