Trust Functions: Near-Lossless Weak-to-Strong Generalization by Learning When to Trust the Weak Teacher¶
会议: ICML 2026
arXiv: 2606.01000
代码: 论文中提及(Code / Website 链接)
领域: 对齐RLHF / 弱监督学习 / 数据选择
关键词: 弱到强泛化, 信任函数, 数据筛选, 教师隐藏状态, 超对齐
一句话总结¶
本文把"弱到强泛化(Weak-to-Strong Generalization)"重新框架成一个数据选择问题,提出"信任函数(Trust Function)"用一个轻量 MLP 读取教师模型最后一层隐藏状态、预测弱标签是否可靠,然后只挑高信任样本去训练强学生,从而在多任务上实现近无损甚至超越 ground-truth 的监督效果,并可迭代成"弱到强链"放大收益。
研究背景与动机¶
领域现状:随着 LLM 在复杂任务上逼近甚至超过人类水平,传统"人类提供可靠监督"的假设崩塌,超对齐(Superalignment)转向用一个弱教师 \(\pi_{\mathcal{W}}\) 去训练更强的学生 \(\pi_{\mathcal{S}}\)。Burns 等人的开创性工作显示弱监督可以让学生超过教师,但始终留有一段无法弥合的差距(与 GT 监督相比)。
现有痛点:弱教师产出的伪标签包含两类系统性错误——(i) 错误标签会沿着数据几何结构被强模型继承下来;(ii) 任务相关方向若不在弱教师表征空间里则无法被传递。结果是弱监督在分布漂移下经常带来不稳定甚至退化,难以闭合到 GT 水平。
核心矛盾:现有"选数据"的尝试普遍用 输出层启发式——例如熵、多模型一致性、自我评估——这些信号本身在复杂任务上就标定差(confident error 高分、correct-but-uncertain 低分),在分布漂移下尤其脆弱。问题的根本在于:输出层信号不足以判断弱标签的可靠性。
本文目标:在固定架构和训练算法前提下,找出弱标注池中"真正能让学生变强"的子集,并把"如何判断标签可靠"这一问题统一形式化。
切入角度:作者注意到先前工作(Kadavath et al. 2022; Kuhn et al. 2023)发现中间表征本身就编码了"答案是否正确"的可分信号,只是被解码层抹平了。因此应该回到隐藏状态去训练一个判别器,而不是去信解码后的概率。
核心 idea:用一个小 MLP \(\tau\) 直接从弱教师的隐藏状态预测"这条弱标签到底对不对",只用高信任样本做 SFT/GRPO,再把训出来的学生当作下一轮的教师,叠成"弱到强链"。
方法详解¶
整体框架¶
框架名为 Learning to Trust (L2T),需要两份数据:有标签源集 \(\mathcal{D}_{\ell}=\{(x_i,y_i)\}\)、无标签目标集 \(\mathcal{D}_u=\{x_j\}\),二者不必同分布。流程分四步:
- 弱教师 \(\pi_{\mathcal{W}}\) 在 \(\mathcal{D}_u\) 上跑前向,产出弱标签 \(\hat{y}=\pi_{\mathcal{W}}(x)\),并同步缓存最后一层、最后生成 token 的隐藏状态 \(g_{\pi_{\mathcal{W}}}(x,\hat{y})\)。
- 在 \(\mathcal{D}_{\ell}\) 上根据真值是否匹配,把 \((g_{\pi_{\mathcal{W}}}(x,\hat{y}),\,\text{is\_correct})\) 当作二分类样本训练 Neural Trust Function(NTF)\(\tau\)。
- 用 \(\tau\) 对 \(\mathcal{D}_u\) 每条样本打信任分 \(t=\tau(g_{\pi_{\mathcal{W}}}(x,\hat{y}))\),按 top-\(n\) 选出高信任子集 \(\tilde{\mathcal{D}}_u\)。
- 用 \(\tilde{\mathcal{D}}_u\) 上的弱标签 SFT 或 GRPO 训练强学生 \(\pi_{\mathcal{S}}\)。整个过程不需要 \(\mathcal{D}_u\) 的真值。
链式版本则把当代学生作为下一代教师,重复上述四步,逐步放大收益。
关键设计¶
-
基于隐藏状态的 Neural Trust Function(NTF):
- 功能:把弱教师内部表征 \(g_{\pi_{\mathcal{W}}}(x,\hat{y})\in\mathbb{R}^d\) 映射成一个 \([0,1]\) 的信任分 \(\tau(\cdot)\),用于估计该弱标签为真的概率。
- 核心思路:输入用最后一层、最终生成 token 的隐藏向量(该 token 已通过 attention 聚合了 prefix 与中间推理)。\(\tau\) 用残差 MLP 堆叠 RMSNorm-SwiGLU 块(Dropout + stochastic depth),最后接 RMSNorm + 线性头出 logit,sigmoid 转概率;loss 用类重加权的 BCE 处理标签不平衡。训练样本由 \(\mathcal{D}_{\ell}\) 上"弱预测 vs 真值"自动构造,规则为任务相关 exact match(MCQA / 数学)或 best-move 匹配(象棋)。
- 设计动机:输出层 confidence 在难题上系统性失准;中间层早就包含"我大概是不是答对了"的可分信号。把判别器搬到隐藏空间能避开 confident-but-wrong 陷阱;同时整条管线主要成本是弱教师前向(生成弱标签时本来就要跑),\(\tau\) 是个小 MLP 训练/打分几乎零开销,\(C_{\text{total}}=O(\bar{C}_{\text{teacher}}(|\mathcal{D}_{\ell}|+|\mathcal{D}_u|)+C_{\text{NTF}}(e|\mathcal{D}_{\ell}|+|\mathcal{D}_u|))\) 实际被教师项主导。
-
In-domain 分布漂移下的零样本部署:
- 功能:放宽"必须有目标域标签"的强假设——\(\tau\) 在源分布上训练,但实际部署到任务接口相同、数据分布不同的目标域上零样本打分。
- 核心思路:作者把泛化场景显式分成三档:ID(同 benchmark held-out)、OOD\(_{\text{dist}}\)(同任务接口、不同数据分布,如 MMLU \(\to\) ARC-Easy)、OOD\(_{\text{domain}}\)(连任务接口都换,如 MCQA \(\to\) 象棋)。文中所有"零样本迁移"声明默认指 OOD\(_{\text{dist}}\)。Table 1 显示 NTF 在 ID 与 OOD\(_{\text{dist}}\) 上 AUC 达 0.83–0.92,纯度 0.69–0.98。
- 设计动机:现实里标签分布严重不均衡——MMLU/MATH 这种大标注集很容易拿到,AIME 之类的目标域则非常稀缺。允许 \(\tau\) 在源域有标签上训练、目标域无标签部署,把信任函数的成本摊到一次性即可;同时显示 OOD\(_{\text{domain}}\) 会退化也老实承认了方法边界。
-
Weak-to-Strong Chain(弱到强链):
- 功能:把 L2T 训出的学生 \(\pi_{\mathcal{S}}^{(1)}\) 作为下一轮的教师 \(\pi_{\mathcal{W}}^{(2)}\),再用同样的 NTF 信任筛选训练更大的学生 \(\pi_{\mathcal{S}}^{(2)}\),迭代多代。
- 核心思路:每一代学生因为接受了高纯度弱标签训练,自身在目标域上的准确率单调升高(论文称之为 snowballing);新学生作为下一代教师时,产出的弱标签纯度也水涨船高,因此即便用同一个 \(\tau\) 打分,可用样本量与平均准确率都会扩大。Figure 1 右下展示了链式累积增益曲线。
- 设计动机:单代 L2T 已经能逼近 GT 监督,但学生规模逐渐放大时仍有空间;链式结构无需引入新组件即可放大收益,且每一代的训练步骤都是同一套 L2T 协议,便于规模化。
损失函数 / 训练策略¶
- NTF 训练:类重加权 BCE + AdamW(带 weight decay),评估用 AUC / ECE / Brier / Purity(top-trust 子集中真正正确的比例)。
- 强学生训练:MCQA 任务用 LoRA-SFT 在 top-\(n\) 高信任样本上拟合弱标签;数学推理用 GRPO 在高信任 rollout 上做 RL。Recovery 指标定义为 \(\text{Recovery}=\frac{\text{Baseline}-\text{Base}}{\text{GT}-\text{Base}}\times 100\%\),衡量"相对 GT 训练的恢复比例"。
实验关键数据¶
主实验¶
World Knowledge(5 个 MCQA benchmark 平均准确率,括号内为 Recovery%):
| 教师 \(\to\) 学生 | Naive | I-Confidence | ICL+I-Conf | Reward Model | NTF(本文) | Ground Truth |
|---|---|---|---|---|---|---|
| OLMo2-1B \(\to\) OLMo2-7B | 69.3 (48.3) | 69.2 (47.1) | 72.0 (79.3) | 68.8 (42.5) | 73.7 (98.9) | 73.8 |
| OLMo2-1B \(\to\) OLMo2-13B | 74.7 (12.2) | 75.1 (17.6) | 77.9 (55.4) | 78.4 (62.2) | 80.9 (95.9) | 81.2 |
| Qwen3-0.6B \(\to\) Qwen3-1.7B | 74.0 (86.0) | 74.3 (91.2) | 74.4 (93.0) | 71.7 (45.6) | 75.0 (103.5) | 74.8 |
| Qwen3-0.6B \(\to\) Qwen3-14B | 86.0 (86.8) | 85.7 (82.9) | 86.5 (93.4) | 86.1 (88.2) | 87.1 (101.3) | 87.0 |
8 个 setting 中 NTF 与 GT 在 5 个上统计无差异(near-lossless),1 个上显著优于 GT(super-recovery),始终强于所有 baseline。
消融实验¶
NTF 在不同领域上的标定指标(Table 1,World Knowledge 与 Strategy Games 用 Qwen3-0.6B,Quantitative Reasoning 用 Qwen3-1.7B / Gemma3-1B):
| 领域 | AUC ↑ | ECE ↓ | Brier ↓ | Purity ↑ |
|---|---|---|---|---|
| World Knowledge | 0.92 | 0.03 | 0.07 | 0.98 |
| Quantitative Reasoning (Omni) | 0.83 | 0.11 | 0.13 | 0.69 |
| Quantitative Reasoning (MATH) | 0.84 | 0.14 | 0.17 | 0.95 |
| Strategy Games | 0.91 | 0.02 | 0.11 | 0.95 |
关键发现¶
- 收益来源不是简单"过滤掉错标签":作者归因到三条机制——保留了诱导 easy-first 隐式课程的样本;有时还能"修正"GT 中本来就次优的标签(在 MATH 等任务上观察到);筛后样本的梯度方向更对齐。
- NTF 对极弱教师依然有效:Qwen3-1.7B 在 AIME 上裸 acc <5%,但配 NTF 后仍能实现近无损 GT 恢复,说明信任函数本身在低纯度池里也能抓住稀有可靠样本。
- OOD\(_{\text{domain}}\)(任务接口都换)会显著退化,说明"信任"与任务接口/输出空间是耦合的,跨接口迁移仍是开放问题。
亮点与洞察¶
- 重定义问题:把弱到强泛化从"如何设计 loss/算法"转向"如何挑数据",并提出 trust function 这个保护伞概念把熵、agreement、自我评估、reward model 等已有做法都装进同一框架,便于横向比较。
- 几乎零额外算力:NTF 只是一个小 MLP,输入是 anyway 都要算的隐藏状态;相比依赖外部 reward model 的 baseline,部署成本低且效果更好,是非常实用的工程优势。
- 链式放大:链式弱到强等于免费把数据筛选当成迭代式自训练,类似"自我对弈式"地把弱监督慢慢提纯,给超对齐场景提供了一种可持续的 bootstrap 路径。
局限与展望¶
- 仍依赖一份源域标签:虽然不需要目标域 GT,但需要"任务接口相同的有标签源域"——极端 superalignment(连源域都没有可靠标签)下并不直接可用。
- 跨接口(OOD\(_{\text{domain}}\))失效:信任函数与任务接口紧耦合,迁到完全不同的任务(如 MCQA \(\to\) 数学)会退化,未来或需共享接口表征或多接口联合训练。
- 评估限于中等规模模型(OLMo2 / Qwen3 1B–14B),最大学生 14B,是否在 70B+ 级别仍保持 near-lossless 还需进一步验证。
- 链式增益的上界与稳定性:论文展示了 snowballing,但缺少崩溃点分析——多少代以后链式会失稳?
相关工作与启发¶
- vs Burns et al. 2023(开创性 W2S): 后者关注训练目标本身(如 confidence loss);本文不动 loss/架构,只在数据侧筛选,但闭合到 GT 的速度更快。
- vs Internal/Verbalized Confidence: 同样想衡量教师可靠度,但只用输出层信号;本文证明隐藏层 + 小判别器在难题上系统性更稳定,且能跨 benchmark 零样本迁移。
- vs Reward Model 过滤: RM 是通用判别器,但泛 reward 信号与"弱标签是否对"并不一一对应;NTF 直接对"correctness"建模,更贴合 W2S 场景。
- vs 自训练 / pseudo-labeling: 经典自训练用模型自身置信度过滤伪标签;这里换成"另一个针对教师隐藏态训出来的小判别器",迁移到 W2S 场景。
评分¶
- 新颖性: ⭐⭐⭐⭐ 把 W2S 框成数据选择问题、并把隐藏状态判别器作为统一形式化,是值得关注的视角转换。
- 实验充分度: ⭐⭐⭐⭐⭐ 三大领域、两族模型、多尺度(1B–14B)、显著性检验、多种 baseline,覆盖较全面。
- 写作质量: ⭐⭐⭐⭐ 形式化清晰、generalization regime 划分严谨;少量实验细节挪到 Appendix 增加阅读跳转。
- 价值: ⭐⭐⭐⭐⭐ 提供了近无损弱到强的工程级方案,对超对齐研究有直接落地意义。