Dispersion Loss Counteracts Embedding Condensation and Improves Generalization in Small Language Models¶
会议: ICML 2026
arXiv: 2602.00217
代码: https://github.com/KrishnaswamyLab/LM-Dispersion
领域: 模型压缩 / 表征学习 / 小模型训练
关键词: 嵌入坍缩, 分散损失, 小模型泛化, 知识蒸馏, GPT2 / Qwen3
一句话总结¶
本文系统观测到 "小语言模型的 token 嵌入会随深度坍缩到一个窄锥体"(embedding condensation)这个普遍现象——大模型反而不会——并设计了一个角度分散损失 \(\mathcal{L}_{\text{disp}}\) 直接逼嵌入散开,无须加参数就让 Qwen3 / GPT2 在 10 个 benchmark 上平均提升 3.3%。
研究背景与动机¶
领域现状:LLM 的能力随规模 scaling 持续提升,但训练 / 部署成本飞涨,迫切需要 "用小模型复现大模型的关键性质"。已有压缩路线主要是蒸馏、量化、剪枝,都侧重模仿大模型的输出分布。
现有痛点:作者从 representation geometry 视角发现,小模型(GPT2-small、Qwen3-0.6B)的 token 嵌入在深层会几乎全部对齐到同一方向,pairwise cosine similarity 趋近 1;大模型(GPT2-xl、Qwen3-32B)则维持嵌入分散。Geshkovski 2025 的理论也证明 Transformer 在层数趋无穷时嵌入会坍缩成一个点,但经验上没人系统验证过和性能的关系。
核心矛盾:嵌入坍缩意味着模型实际可用的 "表征方向" 越来越少,表达能力被几何上锁死。即使蒸馏从大教师学到 logit 分布,也无法把大模型的几何性质继承下来——因为蒸馏目标只约束输出,不约束中间嵌入。
本文目标:(1) 把 embedding condensation 现象量化测量,确认 "大模型抗坍缩" 是普适规律;(2) 验证蒸馏不能缓解;(3) 设计直接作用于几何的辅助损失,让小模型主动散开嵌入。
切入角度:既然大模型 "自动" 维持分散,那分散本身可能就是性能的瓶颈条件。与其堆参数让模型 "自然" 分散,不如显式加一个目标函数强制分散。
核心 idea:用基于角度的分散损失 \(\mathcal{L}_{\text{disp}} = \log \sum_{i \neq j} \exp(-\arccos(\cos\text{sim}(z_i, z_j)) / \pi\tau)\) 把所有 token 嵌入推向单位超球面上的均匀分布,零额外参数。
方法详解¶
整体框架¶
方法分两阶段:(1) 诊断阶段——量化嵌入坍缩;(2) 干预阶段——把分散损失加进训练。诊断用 Spearman \(\rho\) 和 Kendall \(\tau\) 衡量 layer-wise 平均余弦相似度的单调上升程度。干预把 \(\mathcal{L}_{\text{disp}}\) 当正则项加到原训练目标:\(\mathcal{L} = \mathcal{L}_{\text{train}} + \lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\)。
关键设计¶
-
角度分散损失(核心 dispersion loss):
- 功能:在单位超球面上把所有 token 嵌入推开。
- 核心思路:对每层每个 token pair \((z_i, z_j)\),先把余弦相似度映射到角度距离 \(D(z_i, z_j) = \arccos(\cos\text{sim}(z_i, z_j)) / \pi \in [0, 1]\);然后用 log-sum-exp 聚合 \(\mathcal{L}_{\text{disp}} = \log \sum_{i \neq j} \exp(-D(z_i, z_j)/\tau)\)。距离小(同向)时 \(\exp\) 项大、损失大,被推开;距离大(接近正交)时几乎为零。所有层的损失求和,每 batch 复杂度 \(\mathcal{O}(N^2 F)\),可通过 token 维度子采样减负。
- 设计动机:(1) 用 \(\arccos\) 而非裸 cosine 是为了数值稳定,避免在两端饱和;(2) log-sum-exp 比 mean 在数值上更稳健,差一个加性常数不影响梯度;(3) 显式排除对角项防止 \(z_i\) 自相似爆梯度;(4) 选择 angular 而非欧式距离是因为坍缩本质是方向问题不是长度问题。
-
三个备选公式(消融用):
- 功能:验证 "分散" 这一目标的不同实现方式是否都有效,从而隔离 "角度分散" 与其他变体的优劣。
- 核心思路:(a) Decorrelation——最小化嵌入协方差矩阵的非对角元,间接降低不同特征维度间的耦合;(b) \(\ell_2\)-repel——直接拉大 token 间的欧式距离,但要配 norm 正则 \(\lambda_{\text{norm}} \|\mathcal{Z}\|_2^2\) 防止靠膨胀 norm 作弊;(c) Orthogonalization——铰链式损失 \(\max(0, 1/2 - D(z_i, z_j))^2\),只惩罚距离 < 1/2(锐角对),让钝角对自由生长。
- 设计动机:dispersion 本身是个抽象诉求,作者通过四种实现的对比想说明 "在角度空间均匀分散" 比 "在特征维度上去相关" 或 "在欧氏空间排斥" 更直接、更有效,强化主方法的合理性。
-
覆盖 mid-training + full pre-training 的应用策略:
- 功能:把分散损失嵌入两种实际训练流程,证明既能 retrofit 既有模型也能从零训练。
- 核心思路:Mid-training 阶段——拿现成的 GPT2 / Qwen3 在 wikitext-103 上继续训 200M token,单 A100 即可跑完;Full pre-training 阶段——Qwen3 在 C4 上从零训 156B token,用 640 GPU。两种场景下都把 \(\lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\) 加到 cross-entropy 上,loss 在每个 forward 同时计算多层嵌入的 dispersion 并加权。
- 设计动机:mid-training 是低成本的 proof-of-concept 和超参扫描场所;full pre-training 验证从零开始时该信号能塑造更好的几何,从根本上改变模型容量。
损失函数 / 训练策略¶
训练目标 \(\mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda_{\text{disp}} \cdot \mathcal{L}_{\text{disp}}\)。温度 \(\tau\) 与权重 \(\lambda_{\text{disp}}\) 是主要超参,论文 Appendix 给了扫描结果。Mid-training 用 3 个 seed 报均值方差,full pre-training 跑单 seed 但 token 量足够稳定。
实验关键数据¶
主实验¶
GPT2 mid-training(10 个 benchmark 平均):
| 配置 | 训练成本 | 平均分↑ | rank↓ | 显著性 |
|---|---|---|---|---|
| GPT2 原版(无 mid-training) | — | 34.35 | 6.1 | p<0.0001 |
| + \(\mathcal{L}_{\text{CE}}\) only | 1.122 A100h | 34.95 | 6.2 | p<0.01 |
| + noisy embedding | 1.122 | 35.15 | 4.3 | p<0.01 |
| + active forgetting | 1.127 | 35.36 | 3.2 | n.s. |
| + Dispersion loss | 1.13 (1.01×) | 35.52+ | 最佳 | — |
Qwen3 full pre-training(156B token from scratch):加 dispersion loss 平均提升 +1.17 分(3.3% 相对增益),所有 benchmark 上稳定提升。
消融实验¶
四种 dispersion 变体对比:
| 变体 | 平均分 | 备注 |
|---|---|---|
| Decorrelation | 35.1 | 间接,受 feature dim 影响 |
| \(\ell_2\)-repel | 35.0 | 需 norm 正则才稳定 |
| Orthogonalization | 35.2 | 只惩罚锐角 |
| Dispersion (canonical) | 35.5+ | 角度均匀分散,最优 |
规模对照(confounder-controlled):从零训四个 GPT2-like 模型,只改 MLP 维度其余固定,larger MLP → 更小 condensation,验证 "大模型抗坍缩" 不是其他因素带来的伪相关。
关键发现¶
- 蒸馏救不了坍缩:Qwen2.5 系列蒸馏后嵌入几何与从零训几乎一样,因为 KD loss 只约束输出 logits,不规范中间表征——这成为该论文最直观的动机证据。
- 坍缩从初始化就有,但训练能缓解:Olmo-3-7B 的 checkpoint 显示 condensation 指标在初始化时正且大,随训练降到负值,说明 SGD 本身就在抵抗坍缩,分散损失只是加速 + 强化。
- 小模型增益更大:Qwen3-0.6B 提升最明显,Qwen3-32B 几乎无收益,符合 "大模型本来就分散" 的假说。
- mid-training 即可生效:不需要重训,对已有 checkpoint 加 200M token + dispersion 即获明显收益,落地成本极低。
- 代价 < 1% 训练时间:1.13 vs 1.122 A100h,因为 \(N^2\) 配对可子采样。
亮点与洞察¶
- "小模型的瓶颈是几何不是参数":把性能差距归因于表征几何而非容量本身,是该论文最有想象力的命题——意味着可以在不加参数的前提下逼近大模型性能上限。
- 角度而非欧氏距离:用 \(\arccos\) 把 cosine 映成均匀的角度距离,避开 cosine 在两端的饱和,是工程上很关键的稳定化 trick。
- 理论 → 经验 → 干预 → 验证 闭环:从 Geshkovski 2025 的理论坍缩定理出发,提供大规模经验证据,设计干预,再用 confounder 控制实验证实,整篇文章的论证链非常完整。
- 零参数增量:与剪枝、量化、LoRA 等需要改架构的方法不同,dispersion loss 只是训练时附加项,可与任何主流 LM 训练 pipeline 即插即用。
局限与展望¶
- 未扩展到对齐 / 推理任务:实验全在 zero/few-shot 通用 NLU,未测 reasoning / math / code 等需要更复杂表征几何的任务。
- 大模型几乎无增益:32B 上没看到提升,分散假说是否对大模型成立未确认;也许大模型的瓶颈在别处。
- token 子采样的代价:实际跑大模型时 \(N^2\) 还是要子采样,子采样如何影响收敛和最终性能没仔细消融。
- 未与 anti-collapse 技术(如 SimSiam stop-gradient、BarlowTwins)比较:自监督表征学习里有大量去坍缩方法,应横向比对。
- 未解释为何 \(\lambda_{\text{disp}}\) 不需要复杂调度:直觉上训练早期需要更强分散、后期可减弱,但论文用固定权重就 work,这是个有趣的现象但未深入分析。
- 架构敏感性:在 RMSNorm(无 affine LN)、不同位置编码下 dispersion 是否仍生效未测。
相关工作与启发¶
- vs Wang & He 2025(diffusion 中的 dispersion): 那篇把 dispersion 用在生成模型上,本文把它迁移到语言模型并改成角度形式 + 显式去对角,是同一思路的领域适配。
- vs noisy embedding / active forgetting: 这些 trick 也试图增加表征多样性,但都是间接手段且无几何解释;dispersion 直接、可解释,效果也更好。
- vs 蒸馏类压缩:蒸馏只能传输输出行为,dispersion 传输的是表征几何这一更基本的性质;二者可叠加。
- vs Cai 2021、Bis 2021(isotropy 研究): 都关注嵌入空间的各向同性 / 异性,本文给出了第一个明确的训练目标使之可控。
- 启发:这个思路可以推广到 (1) vision encoder 的 patch token 嵌入;(2) 多模态对齐时不同模态嵌入之间的分散;(3) MoE 专家激活模式的分散,避免专家坍缩。
评分¶
- 新颖性: ⭐⭐⭐⭐ 把理论坍缩定理转化为可训练的辅助损失,并配套提出 4 种变体,属于 "已知现象 + 干净干预" 类新颖性,不是石破天惊但完整可信。
- 实验充分度: ⭐⭐⭐⭐⭐ mid-training + full pre-training(156B token / 640 GPU)双重验证,10 个 benchmark + 3 seed + confounder 控制 + 4 种 dispersion 变体消融,少见的扎实。
- 写作质量: ⭐⭐⭐⭐⭐ 论证链 "理论 → 经验观测 → 反例(蒸馏无效)→ 干预 → 验证" 非常顺,图表清晰,把抽象的几何现象讲得像故事一样。
- 价值: ⭐⭐⭐⭐ 几乎零成本即插即用,对小模型训练社区直接受益;只是大模型上无收益、未测推理 / 代码任务,长期影响力有上限。