跳转至

NITP: Next Implicit Token Prediction for LLM Pre-training

会议: ICML 2026
arXiv: 2605.24956
代码: 待确认
领域: LLM 预训练 / 表示学习
关键词: NTP 表示退化, 隐式目标, 浅层监督, 余弦相似度

一句话总结

NITP 通过用浅层表示作为隐式目标为最后隐藏状态提供连续的表示空间监督——补充标准 NTP 防止隐藏表示退化为低维各向异性配置,在 9B MoE 上 MMLU-Pro 提升 5.7%、推理任务普遍提升 4-6%,额外计算开销仅 ~2%。

研究背景与动机

领域现状:标准下一 token 预测(NTP)是 LLM 预训练的主流范式。NTP 本质上是在输出 logit 空间提供离散、独热的监督。

现有痛点:虽然梯度通过输出投影反向传播到隐藏状态,但 NTP 目标主要沿目标 logit 方向约束表示,在潜在空间留下大量弱约束的自由度。这导致表示退化——基于似然的训练会将学到的表示压缩到狭窄的各向异性圆锥体内,严重限制表达能力并与下游性能下降有关。

核心矛盾:NTP 定义了"预测什么",但没有约束"如何表示";隐藏状态可以采用多种几何不同的配置,但实际中会陷入表示退化——牺牲语义丰富性而获得判别效率。

本文目标:解决 NTP 在隐藏表示几何上的盲区,通过显式的表示级监督引导隐藏状态保持结构化、语义丰富的配置。

切入角度:不在离散 token 空间中工作,而是在连续表示空间中进行监督——让模型预测下一 token 的隐式语义表示(用模型自身的浅层表示作为自监督目标)。浅层之所以合适,是因为保留了丰富的词汇和局部语义细节。

核心 idea:NITP = NTP(离散监督)+ NITP(连续表示空间监督);用浅层的下一 token 表示作为隐式目标,通过余弦相似度损失强制最后隐藏状态与之对齐,参数高效(隐式目标来自已计算的中间激活,无需额外前向传播)。

方法详解

整体框架

双监督机制——(1)标准 NTP \(\mathcal{L}_{\text{NTP}}\);(2)NITP 辅助目标 \(\mathcal{L}_{\text{NITP}} = 1 - \frac{\mathcal{P}(h_t)^\top z_{t+1}}{\|\mathcal{P}(h_t)\|_2 \cdot \|z_{t+1}\|_2}\);(3)联合优化 \(\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{NTP}} + \lambda \mathcal{L}_{\text{NITP}}\)。在标准前向传播中完成:先通过前 \(k\) 层(如第 4 层)计算 \(t+1\) 位置的隐式目标 \(z_{t+1}\)(停止梯度),再从最后层的隐藏状态 \(h_t\) 通过投影头预测它。

关键设计

  1. 隐式目标构造:

    • 功能:为最后隐藏状态提供上文相关的、语义丰富的监督信号。
    • 核心思路:用模型自身浅层(第 4 层,~20% 模型深度)在位置 \(t+1\) 的表示 \(z_{t+1} = \text{sg}[E_{\text{shallow}}(x_{\leq t+1})^{(t+1)}]\) 作为隐式目标。停止梯度保持稳定(浅层收敛更快);不引入额外计算。
    • 设计动机:浅层的语义丰富性强制深层表示必须维持足够的表达能力才能预测它,防止各向异性坍陷。
  2. 余弦相似度损失:

    • 功能:在表示空间中对齐预测状态与隐式目标。
    • 核心思路:最小化余弦相似度损失,\(\mathcal{P}(\cdot)\) 是简单投影头(MLP);余弦相似度在 \([-1, 1]\) 上对称,对尺度不敏感。
    • 设计动机:消融显示余弦损失比 MSE、Smooth-\(\ell_1\)、KL 散度更稳定——MSE 因二次惩罚会放大层间尺度失配;KL 将向量当分布处理会引入几何扭曲。
  3. 自监督设计 + 停止梯度:

    • 功能:无需外部数据或编码器,自动生成监督信号并保证训练稳定。
    • 核心思路:隐式目标通过 sg 停止梯度,梯度只流向最后层和投影头,不反向传播到浅层;浅层作为稳定的"语义锚"。
    • 设计动机:降低计算成本(~2% 额外 FLOPs)、提高训练稳定性、完全自监督。

理论分析:语义流形的规则化

NTP 目标对 \(h_t\) 的约束主要来自其与目标 token embedding 的点积——Hessian 秩亏、允许表示在零空间中任意漂移。NITP 通过引入正曲率规则化这些方向:NITP Hessian \(H_{\text{NITP}}(h) \approx \frac{1}{r^2} P_{\perp u}\)(超球面切空间投影);在所有正交方向引入严格正曲率,强制表示维持结构化几何。

实验关键数据

主实验

模型 方法 MMLU MMLU-Pro C3 CommonsenseQA 平均提升
1.9B MoE (0.3B active) NTP 31.05 7.14 32.21 25.38
1.9B MoE NITP 31.68 7.47 29.69 26.61 +0.8
3B MoE NTP 34.60 11.00 39.06 34.15
3B MoE NITP 37.37 12.29 44.38 37.92 +2.1
9B MoE NTP 43.71 15.29 56.65 45.70
9B MoE NITP 46.14 21.00 63.01 49.96 +2.7

9B 上 MMLU-Pro 绝对提升 5.7%;阅读理解和常识推理分别增长 6.4% 和 4.3%。

消融实验

配置 MMLU MMLU-Pro CommonsenseQA BBH 平均
基线 NTP 34.60 11.00 34.15 21.92 25.42
浅层(L₄) 37.37 12.29 37.92 26.14 28.43
中层(L₈) 35.33 11.57 34.72 22.07 25.92
深层(L₁₄) 35.79 10.43 38.90 23.25 27.09
当前位置 t→t 33.09 8.14 29.15 20.96 22.84
MSE 损失 32.77 10.29 30.38 21.55 23.75
余弦正则(无预测) 34.45 10.14 33.25 22.29 25.03

关键发现

  • 浅层选择的必要性:使用浅层表示(~20% 模型深度)比中层 / 深层都好——浅层保留更丰富的词汇和局部语义。
  • 时间结构至关重要:预测下一 token 隐式表示(\(t \to t+1\))比当前位置对齐(\(t \to t\))性能高 5.6 个百分点。
  • 损失函数的稳定性差异:MSE 会导致梯度尖峰和临时发散;只有余弦相似度完全稳定且性能最好。
  • 正则化不等于预测:通用余弦正则化虽约束表示几何但不提升性能——收益来自"预测对齐的"语义监督。
  • 计算效率:额外 FLOPs 仅 ~2%;\(\lambda = 1.0\) 最鲁棒。

亮点与洞察

  • 诊断表示退化的根本原因:通过有效秩和余弦相似度可视化清晰展示 NTP 如何导致表示向低维各向异性配置漂移;理论分析用 Hessian 谱解释根本机制。
  • 自监督隐式目标的巧妙设计:浅层表示作为"语义锚"既不需要外部数据或模型,又因为浅层语义信息最丰富而成为理想监督信号。
  • 通用性和可迁移性:NITP 在 MoE 和稠密模型、0.5B 到 9B 参数范围、多种评估基准上都有效。
  • 最小计算开销下的显著收益:~2% 额外训练 FLOPs 代价下获得 5%+ 的知识理解提升和 6%+ 的推理能力提升。

局限与展望

  • NITP 引入额外超参数(目标层、NITP 权重 \(\lambda\)),不同模型上的稳定性需进一步验证。
  • 当前位置对齐完全失效的解释有待深化。
  • 对于更大规模模型(> 100B)、不同架构、多模态模型的适用性需要验证。
  • 隐式目标的浅层选择(第 4 层)可能对不同模型深度不是最优的。

相关工作与启发

  • vs 多 token 预测(MTP):MTP 在离散 token 空间扩展预测范围,NITP 在表示空间做监督,两者可互补。
  • vs 层蒸馏:蒸馏对齐两个不同模型的表示,NITP 在同一模型内用浅层指导深层,避免外部分布偏移。
  • vs 自监督对比学习(BYOL):对比学习鼓励不同视角间的一致性,NITP 聚焦于时间维度的预测。
  • 启发:表示级别监督可以是解决 LLM 预训练目标不完全性的方向。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 用浅层隐式目标补充 NTP、从 Hessian 角度理论解释表示退化——想法简洁但深刻。
  • 实验充分度: ⭐⭐⭐⭐⭐ 多个模型规模、两种架构、丰富消融、理论分析与经验验证相结合。
  • 写作质量: ⭐⭐⭐⭐ 逻辑清晰,理论部分稍显抽象但重点突出。
  • 价值: ⭐⭐⭐⭐⭐ 直接改进 LLM 预训练效率和性能,2% 额外成本获 5%+ 收益,工业应用价值大。