跳转至

Dynamics-Predictive Sampling for Active RL Finetuning of Large Reasoning Models

会议: ICLR 2026
arXiv: 2603.10887
代码: github.com/maoyixiu/DPS
领域: LLM Reasoning / RL Finetuning
关键词: 强化学习微调, 提示采样, 隐马尔可夫模型, 大推理模型, 在线贝叶斯推断

一句话总结

将 RL 微调中每个 prompt 的求解进度建模为隐马尔可夫动力系统,通过轻量贝叶斯推断在线预测 prompt 的求解状态,优先采样"部分求解"的 prompt,以不到 DS 30% 的 rollout 量达到同等甚至更优的推理性能。

研究背景与动机

领域现状:RL 微调(以 GRPO 为代表)已成为提升 LLM 推理能力的核心技术路线。DeepSeek-R1、OpenAI o1 等大推理模型(LRM)通过 RL 微调在数学竞赛、代码生成、逻辑推演等任务上取得了突破性进展。然而,RL 微调的效果高度依赖训练数据的选择质量——并非所有 prompt 都对策略优化有同等贡献。

现有痛点:GRPO 使用组内归一化计算优势函数 \(\hat{A}_i^\tau = \frac{r(\tau, y_i^\tau) - \text{mean}}{\text{std}}\)。当模型对某 prompt 的 \(k\) 个响应全部正确或全部错误时,\(\text{std} = 0\),优势函数退化,梯度信号消失。因此,只有"部分求解"(some correct, some incorrect)的 prompt 才能提供有效的优化信号。

核心矛盾:当前最先进的在线 prompt 选择方法 Dynamic Sampling(DS)通过扩大候选批次(通常为最终批次的 3-4 倍)并逐一 rollout 来筛选有效 prompt,虽然能精准找到部分求解的样本,但 rollout 的计算开销极为昂贵——生成长链推理响应的成本常常超过微调本身。另一方法 History Resampling(HR)在 epoch 级别过滤已完全求解的 prompt,但其"吸收态"假设过于刚性,且只排除已解 prompt 而不主动寻找部分求解的。

本文目标 如何在不进行昂贵 rollout 的前提下,在线预测哪些 prompt 当前处于"部分求解"状态,从而以极低的额外开销实现与 DS 相当甚至更优的训练效果?

切入角度:作者将每个 prompt 的求解进度视为一个随时间演化的动力系统——随着模型更新,prompt 的求解状态(未解→部分求解→已解)会发生转移。这种演化可以用 HMM 建模:隐状态是求解程度,观测是间歇性的 rollout 结果。利用历史 rollout 数据进行贝叶斯推断就可以预测未来的求解状态,无需实际 rollout。

核心 idea:将 prompt 求解进度建模为 HMM 动力系统,通过在线贝叶斯推断预测求解状态,在 rollout 前就选出最可能处于"部分求解"状态的 prompt,实现 rollout-free 的自适应采样。

方法详解

整体框架

DPS 的核心 pipeline 在每个训练步 \(t\) 包含两个阶段:(1) 预测与选择——利用上一步推断得到的先验 \(\mu_t^{\tau,\text{prior}}\),按每个 prompt 处于"部分求解"(State 2)的预测概率排序,选取 Top-\(B\) 个 prompt 构成训练批次 \(\mathcal{B}_t\)(2) 推断与更新——对被选中的 prompt 执行 rollout 获取观测,用贝叶斯规则更新状态后验和转移矩阵后验,然后前向传播生成下一步的先验预测。

关键设计

  1. 三状态隐马尔可夫建模:

    • 功能:将每个 prompt 的求解进度形式化为可追踪的动力系统
    • 核心思路:定义隐状态 \(z_t^\tau \in \{1, 2, 3\}\),分别对应完全未解(\(k\) 个响应全错)、部分求解(部分正确部分错误)、完全求解(全部正确)。初始先验为均匀分布 \(\mu_1^{\text{prior}} = [1/3, 1/3, 1/3]\)。状态转移由列随机矩阵 \(\Phi \in \mathbb{R}^{3 \times 3}\) 描述,观测模型为退化发射:若 prompt 被选中 rollout 则状态被精确观测,否则无观测
    • 设计动机:三状态划分既符合 GRPO 的梯度信号理论(只有 State 2 有非零优势函数),又保持模型简洁——消融实验证明更细或更粗的划分都会降低预测精度
  2. 在线贝叶斯推断三步流水线:

    • 功能:在每个训练步实时更新状态估计和转移模型
    • 核心思路:(i) 观测更新——若 prompt 被选中 rollout,用贝叶斯规则将先验 \(\mu_t^{\text{prior}}\) 更新为后验 \(\mu_t^{\text{post}}\);退化发射模型下后验直接坍缩到观测状态。(ii) 转移更新——利用 Dirichlet-Categorical 共轭性,用后验转移伪计数 \(\xi_t(i,j) = \mathbb{P}(z_{t-1}=j, z_t=i \mid y_{1:t})\) 增量更新 Dirichlet 参数 \(\alpha_t\)。(iii) 下一步预测——\(\mu_{t+1}^{\text{prior}} = \Phi_t \mu_t^{\text{post}}\),将后验通过转移矩阵传播为下一步先验
    • 设计动机:经典 HMM 的前向-后向算法需要完整轨迹,无法在线使用。DPS 的在线推断只依赖当前步观测和上一步后验,全部操作为 \(3 \times 3\) 矩阵运算,开销可忽略
  3. 非平稳指数衰减机制:

    • 功能:使转移模型适应模型学习过程中不断变化的求解动态
    • 核心思路:引入衰减因子 \(\lambda \in (0,1)\),更新规则变为 \(\alpha_t = \lambda \cdot \alpha_{t-1} + (1-\lambda) \cdot \alpha_0 + \xi_t\)。较小的 \(\lambda\) 使模型更快遗忘旧统计量、适应新动态。此机制同时隐式提供探索:长期未采样的 prompt 的后验会逐渐衰减至均匀分布,使其在缺少明确高信息 prompt 时被自然重新访问
    • 设计动机:LRM 的学习过程高度非平稳——随着策略更新,prompt 的求解概率持续变化。标准 HMM 的平稳假设在此不成立,而指数衰减是一种轻量且有效的非平稳扩展

Prompt 采样策略

\(\mu_t^{\tau,\text{prior}}(2)\)(预测为部分求解的概率)对全部 prompt 排序,直接选取 Top-\(B\) 个构成训练批次。这是一个纯利用(exploitation)策略,但由于非平稳衰减的存在,系统内建了隐式探索——未被采样的 prompt 预测分布逐渐漂移至均匀,会被自然纳入采样。这种设计避免了显式探索-利用权衡的超参数调优。

训练更新

DPS 与具体的 RL 算法正交。本文全部实验基于 GRPO,在 verl 框架内实现。每步选出 \(B\) 个 prompt 后,对每个 prompt 生成 \(k\) 个响应并计算奖励,用 GRPO 目标更新策略。DPS 的推断开销仅为 \(O(|\mathcal{D}| \times 3^2)\) 的矩阵运算,实测对总训练时间的影响 < 1%。

实验关键数据

主实验:数学推理(MATH 数据集训练,跨基准测试)

方法 AIME24 AMC23 MATH500 Minerva Olympiad Avg↑ Rollouts↓ 运行时间↓
R1-Distill-1.5B (基线) 18.33 51.73 76.64 23.83 35.31 41.17 - -
+US 26.46 63.18 82.78 27.46 43.00 48.57 737k 27h
+HR 28.13 64.61 82.88 27.37 43.15 49.23 737k 28h
+DS (Oracle) 31.88 67.32 84.79 29.18 46.83 52.00 2933k 89h
+DPS (Ours) 32.71 67.77 84.95 29.09 46.11 52.13 737k 32h
R1-Distill-7B (基线) 37.71 68.45 86.94 34.74 46.94 54.95 - -
+US 45.83 73.57 89.06 37.68 50.42 59.31 287k 30h
+HR 46.46 75.98 90.01 37.94 51.50 60.38 287k 36h
+DS (Oracle) 49.79 78.99 90.96 37.89 54.45 62.42 1147k 73h
+DPS (Ours) 51.04 80.35 91.13 37.82 55.32 63.13 287k 39h

关键结论:DPS 在 1.5B 和 7B 两个规模上均超越 DS oracle(Avg +0.13/+0.71),而 rollout 量仅为 DS 的 25.1%(737k vs 2933k)和 25.0%(287k vs 1147k)。运行时间约为 DS 的 36%(32h vs 89h)和 53%(39h vs 73h)。

跨任务泛化:Countdown 规划 & Geometry 视觉几何

任务 / 模型 方法 测试准确率 Rollouts↓
Countdown / Qwen2.5-3B +US 69.87 / 39.42 246k
+HR 70.19 / 42.10 246k
+DS (Oracle) 74.95 / 47.67 1141k
+DPS 74.27 / 47.78 246k
Countdown / Qwen2.5-7B +US 77.84 / 53.27 246k
+HR 78.15 / 54.54 246k
+DS (Oracle) 81.26 / 60.77 1006k
+DPS 81.15 / 59.61 246k

DPS 在 Countdown(数值规划)和 Geometry3k(视觉几何推理,使用 Qwen2.5-VL 多模态模型)上均匹配 DS 性能,rollout 量约为 DS 的 21-24%,证明方法跨任务和跨模态的泛化能力。

预测精度与有效样本比例

  • 整体预测准确率在训练过程中持续保持高位,Class 2(部分求解)的 precision、recall、F1 均稳健
  • 混淆矩阵随训练推进对角线增强、off-diagonal 减少,体现预测能力持续改善
  • 有效样本比例(批次中部分求解 prompt 的占比):DPS 达到 ~90%,远高于 US(~30-50%)和 HR(~40-60%),接近 DS oracle 水平

消融实验

非平稳衰减 \(\lambda\)\(\lambda = 1\)(无衰减,等权所有历史)导致性能和预测精度下降,说明求解动态确实是非平稳的;\(\lambda = 0\)(仅用最近一次观测)同样退化,因为丢弃了有价值的历史信息。中等 \(\lambda\)(如 0.9-0.99)在响应性与历史利用之间取得最佳平衡。

状态划分数量:2 状态(部分求解 vs 其他)将未解和已解混为一类,掩盖了它们截然不同的动态;4+ 状态将有限观测分散到更多类别,导致稀疏估计。3 状态在建模精度和数据效率之间达到最佳。

响应组大小 \(k\)\(k = 4\) 时 DPS 相对 US 优势最大。原因是较小 \(k\) 下,一个 prompt 同时产生正确和错误响应的概率为 \(1 - p^k - (1-p)^k\),该概率随 \(k\) 减小而降低,使得 US 的有效样本比例极低,而 DPS 通过主动预测弥补了这一劣势。

亮点与洞察

  • HMM 建模的精妙适配:三状态 HMM 恰好对应 GRPO 的梯度信号结构(State 1/3 梯度为零,State 2 梯度非零),使得状态预测直接等价于梯度信号预测。退化发射模型(观测即真实状态)使推断极度简化,整个系统仅需维护每个 prompt 的 \(3 \times 1\) 信念向量和 \(3 \times 3\) Dirichlet 参数
  • 隐式探索机制的优雅设计:非平稳衰减 \(\lambda\) 一箭双雕——既适应非平稳动态,又通过后验漂移至均匀分布自动实现探索。实验(Fig. 7)验证较小 \(\lambda\) 产生更均匀的采样频率分布,有效避免采样死锁
  • 预测准确率的自增强效应:DPS 优先采样 State 2 prompt,获得更多 State 2 的观测数据,反过来提升对 State 2 的预测精度。这形成正反馈循环,是方法越用越准的内在机制
  • 与课程学习的隐式联系:DPS 本质上实现了一种数据驱动的自适应课程——训练早期大量 prompt 处于 State 1(未解),DPS 自然采样到刚开始能解的 prompt;训练后期更多 prompt 转向 State 3(已解),DPS 自动聚焦于剩余的挑战性样本。这种课程无需人工设计难度指标

局限与展望

  • 奖励结构假设:当前方法依赖正确性二值奖励来定义三个状态。对于密集奖励(如 process reward)或连续奖励场景,需要重新设计状态划分方案。作者提到可以通过划分累积回报区间来扩展,但未验证
  • Prompt 间独立假设:每个 prompt 维护独立的 HMM,未利用 prompt 之间的结构相似性(如同一知识点、同一难度级别的 prompt 可能有相似的状态转移)。共享转移模型或引入 prompt 嵌入可能进一步提升预测精度
  • Top-\(B\) 贪心选择:纯利用策略可能在极端场景下次优。作者提到基于熵的优先级策略作为未来方向——优先采样状态预测最不确定的 prompt 可能在某些设定下更优
  • 可扩展性未充分分析:虽然推断开销为 \(O(|\mathcal{D}| \times 9)\),但当数据集极大(百万级 prompt)时,每步遍历全部 prompt 计算先验并排序可能成为瓶颈。可考虑基于优先级队列的增量更新
  • 仅验证 GRPO:DPS 框架理论上与 RL 算法正交,但实验仅在 GRPO 上验证。在 PPO、REINFORCE++ 等其他 RL 算法上的适用性有待确认

相关工作与启发

  • vs Dynamic Sampling (DS):DS 是 rollout-intensive 的 oracle 方案——需要对 3-4 倍的候选集进行完整 rollout 才能筛选出有效 prompt。DPS 用预测替代 rollout,在保持 DS 精度的同时将 rollout 开销降低 75%。核心区别在于 DS 是"先生成再过滤",DPS 是"先预测再采样"
  • vs History Resampling (HR):HR 在 epoch 级别标记已完全求解的 prompt 并在后续 epoch 中排除。其局限有二:(1) epoch 级吸收态假设过于刚性——模型更新后某些"已解" prompt 可能退回未解状态;(2) 仅排除已解 prompt 而不主动识别部分求解的 prompt,在训练早中期效果有限
  • vs 离线数据过滤:基于难度估计、领域平衡等的静态过滤方法无法适应模型不断变化的能力边界,而 DPS 通过在线推断持续跟踪模型的求解动态
  • 与课程学习的联系:DPS 可视为一种隐式的、数据驱动的课程学习——区别在于传统课程学习需要预定义的难度指标,而 DPS 直接从 rollout 反馈中学习难度动态

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 将 prompt 采样优雅转化为 HMM 在线状态预测问题,视角新颖且理论扎实
  • 实验充分度: ⭐⭐⭐⭐⭐ 覆盖数学/规划/视觉三类任务、1.5B-7B 多规模模型,消融全面
  • 写作质量: ⭐⭐⭐⭐⭐ 建模清晰,推导详尽,实验呈现规范
  • 价值: ⭐⭐⭐⭐⭐ 以 <30% rollout 量达到 SOTA 采样策略水平,对大规模 RL 微调有直接实用价值