Designing Instance-Level Sampling Schedules via REINFORCE with James-Stein Shrinkage¶
会议: CVPR2026
arXiv: 2511.22177
代码: 待确认
领域: 扩散模型 / 图像生成
关键词: 采样调度, REINFORCE, James-Stein 收缩, Dirichlet 策略, 文生图后训练
一句话总结¶
不动模型权重,只为冻结的文生图采样器学一个"按 prompt 和噪声定制的采样时间表"——用单次前向的 Dirichlet 策略一口气吐出整条 schedule,并用 James-Stein 收缩做 REINFORCE 的 reward baseline 来压低梯度方差,使得 SD/Flux 在相同步数下文图对齐更好、5 步就能逼近蒸馏过的 Flux-Schnell。
研究背景与动机¶
领域现状:扩散 / flow-matching 文生图模型的推理质量很大程度取决于采样时间表(sampling schedule)——即在固定的步数预算下,如何把这几步分配到连续的去噪轨迹上。但主流生产级 backbone(SD-XL、SD-3.5、Flux)对所有输入都用同一条全局固定的 schedule。
现有痛点:一条"通用 schedule"不可能对测试时遇到的各式 prompt 都最优——不同 prompt 需要的空间/语义细节不同,不同的噪声种子也会带来不同的初始条件("golden noise"现象就说明了 seed 敏感性)。后训练(post-training)的两条主流路线都在动权重:要么微调 backbone 做对齐,要么蒸馏 backbone 换取少步数效率,成本高且改变了原模型。
核心矛盾:作者主张存在一个被忽视的、正交于改权重的"杠杆"——只重排采样时间线就能榨出预训练采样器额外的生成潜力,且推理几乎零额外开销。难点在于:要按实例学 schedule,自然想到把它当策略来 RL 优化,但 schedule 是高维、开环、一次性的"前瞻计划",REINFORCE 在这种高维空间里梯度方差极大,训不稳。
本文目标:(1) 设计一个单次前向就输出整条 schedule 的策略(避免自回归逐步预测带来的 \(O(L)\) 推理开销);(2) 给这种高维 one-shot 策略梯度配一个可证明更优的方差缩减 baseline。
切入角度:把"逐 context 的 RLOO baseline"和"跨 context 共享 baseline"看成两个极端,用 James-Stein 收缩在二者间做数据驱动的插值——既保留 context 特异性,又借用全局信息稳住估计。
核心 idea:用"冻结采样器 + 单次 Dirichlet 调度策略 + James-Stein 收缩 baseline 的 REINFORCE"把采样调度变成一种模型无关的后训练手段。
方法详解¶
整体框架¶
方法把"为每个 (噪声 \(x_T\), prompt \(c\)) 设计采样时间表"形式化为策略优化问题。一个轻量策略网络 \(\pi_\theta\) 在单次前向中,针对输入条件直接输出整条 schedule \(\tau\)(一组归一化时间步),交给冻结的预训练采样器 \(S_\phi\) 执行得到图像 \(x_0\),再用一个奖励模型(HPSv2 或 GenEval 规则奖励)打分 \(r(x_0(\tau); c)\)。优化目标是最大化期望奖励
由于整条 schedule 是一个高维"动作"、且奖励只在终端给出,直接用 REINFORCE 梯度方差极大。本文的核心贡献因此落在两处:(a) 用 James-Stein 收缩 baseline 把梯度估计的方差/MSE 压到可证明低于 RLOO;(b) 把策略实例化为 Dirichlet 单次调度器,用一个 simplex 上的连续分布优雅地表示"把单位区间切成 \(L{+}1\) 段"的调度动作。训练时全程不加 KL 约束、策略从零初始化,以便干净地凸显框架本身的效果。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["输入:噪声 x_T + 文本 c"] --> B["Dirichlet 单次调度策略<br/>20M 轻量网络,一次前向出整条 τ"]
B --> C["冻结采样器 S_φ<br/>按 τ 去噪生成图像 x_0"]
C --> D["奖励模型<br/>HPSv2 / GenEval 规则奖励 r"]
D --> E["James-Stein 收缩 baseline<br/>b_JS = (1-α)·b_RLOO + α·b_xctx"]
E -->|"(r − b_JS)·∇logπ"| F["REINFORCE 更新 π_θ<br/>采样器权重不动"]
F -.->|下一轮| B
关键设计¶
1. 单次 Dirichlet 调度策略:把整条时间表当成一个 simplex 动作
针对"自回归逐步预测下一个时间步会带来 \(O(L)\) 推理开销"这个痛点,本文不再逐步决策,而是把整条去噪 schedule 当作一个联合动作一次性采样。具体地,令 \(\tau \sim \text{Dirichlet}(\alpha_\theta(x_T, c))\),策略网络输出 \(L{+}1\) 个非负参数 \(\alpha_\theta \in \mathbb{R}^{L+1}_+\)。每个分量 \(\tau_t\) 是一段非负区间,simplex 约束 \(\sum_{t=1}^{L+1}\tau_t = 1\) 保证这些区间恰好划分单位区间 \([0,1]\);再用累加 \(\tilde t_\ell = \sum_{j=1}^{\ell}\tau_j\)、\(t_\ell = 1 - \tilde t_\ell\) 转成严格递减的时间序列 \(1 = t_0 > t_1 > \cdots > t_L > t_{L+1} = 0\)。第 \(L{+}1\) 段 \(\tau_{L+1}\) 充当可学习的"停止边距",让策略动态调节有效采样视野。
这个参数化天然满足 simplex 约束、在合法 schedule 上定义了一个光滑分布、还避免了逐步离散决策;相比 \(O(L)\) 的逐步 RL,单次 Dirichlet 策略把策略梯度开销摊销为常数。策略网络本身很轻(20M 参数,<采样器网络的 1%):噪声 \(x_T\) 经多尺度卷积块提特征,与预训练文本嵌入 \(c\) 做 cross-attention 融合,再经 MLP 投到 \(L{+}1\) 个通道,最后用 softplus + 一个小常数偏移(\(10^{-3}\))保证输出严格为正、Dirichlet 参数合法且梯度数值稳定。
2. 方差-最优 baseline 与 RLOO 的重新诠释:先讲清"理想 baseline 长什么样"
REINFORCE 梯度 \(\nabla_\theta J = \mathbb{E}[(r(\tau) - b)\nabla_\theta \log \pi_\theta(\tau)]\) 对任意与 \(\theta\) 无关的 baseline \(b\) 都无偏,但 \(b\) 的选择直接决定方差。作者先推出方差-最优 baseline(命题 3.1):
它在高维策略下无法精确算,但当策略接近确定性(同一 context 内 \(\nabla_\theta \log \pi_\theta\) 变化很小)时分子分母解耦,近似为该 context 的条件均值奖励 \(b^* \approx \mathbb{E}_{\tau}[r(\tau)\mid x_T, c]\)。这一步把"最优 baseline"和"contextual mean reward"挂上钩,于是常用的 per-context RLOO(对每个 context 抽 \(K_c\) 条 rollout、用留一法平均 \(b_{\text{RLOO}}^{(c,i)} = \frac{1}{K_c-1}\sum_{j\neq i} r^{(c,j)}\))就可被理解为 \(b^*\) 的一个 within-context 蒙特卡洛近似。问题是:当每 context 的 rollout 数很小(本文实际只用 2)或不同 prompt 的奖励量纲差别大时,RLOO 是个很噪的均值估计。另一极端是跨 context baseline \(b_{\text{xctx}}\)(对整个 mini-batch 除当前样本外全平均),方差小但忽略了 context 间系统性的奖励尺度差异,会对某些 prompt 过/欠补偿。
3. James-Stein 收缩 baseline:在 RLOO 与跨 context 之间做可证明更优的插值(核心贡献)
这是全文的核心。作者用随机效应(random effects)视角建模奖励:\(r^{(c,i)} = \mu_c + \varepsilon^{(c,i)}\),\(\varepsilon^{(c,i)}\sim\mathcal{N}(0,\sigma^2)\);\(\mu_c = \mu_0 + \xi^{(c)}\),\(\xi^{(c)}\sim\mathcal{N}(0,\delta^2)\)。其中 \(\sigma^2\) 是 context 内奖励方差,\(\delta^2\) 刻画 context 间异质性(即 prompt 难度差异)。该模型下隐变量均值 \(\mu_c\) 的后验均值会自适应地把经验均值 \(\bar r_c\) 朝全局均值 \(\mu_0\) 收缩,收缩强度 \(\alpha_c^* = \frac{\sigma^2/K_c}{\sigma^2/K_c + \delta^2}\)——rollout 数 \(K_c\) 越小、或 context 越同质(\(\delta^2\) 小),收缩越强。
由于 \((\sigma^2,\delta^2)\) 未知,用经验估计代入就得到 JS reward baseline,它正好是前面两个 baseline 的凸组合:
两个 baseline 都用留一法(排除同一样本 \((c,i)\))算,所以 JS 仍保持 REINFORCE 梯度无偏。\(\hat\alpha_c \to 0\) 时退回 RLOO,\(\hat\alpha_c \to 1\) 时趋向跨 context baseline,是个数据驱动的折中。方差分量用两条 baseline 的统计量估:\(\hat\sigma^2\) 由各奖励围绕 \(b_{\text{RLOO}}\) 的离散度算,\(\hat\delta^2\) 用一个矩估计(MoM)去偏——\(\hat\delta^2 = \max\!\big(0,\ \frac{1}{B-1}\sum_c(\bar r_c - b_{\text{xctx}}^{(c,\cdot)})^2 - \hat\sigma^2/\bar K\big)\),减去 \(\hat\sigma^2/\bar K\) 把 context 内噪声从 context 间估计中剔除,\(\max(0,\cdot)\) 保证非负。每轮迭代重新估这些量,开销可忽略。
理论上(定理 3.2,要求 \(B\geq 3\) 个 context):(i) JS baseline 对 \(\mu_c\) 的 MSE 严格低于无偏的 RLOO baseline;(ii) \(b_{\text{JS}}\) 恰是 \(\mu_c\) 的经验贝叶斯后验均值,因此是 \(b_{\text{RLOO}}\) 与 \(b_{\text{xctx}}\) 的 MSE-最优凸组合。换言之,对任意有限 \(K_c\),JS 都能在不引入偏差的前提下拿到更低方差——这正是它能稳住高维 one-shot 策略梯度、并带来下游生成提升的根本原因。⚠️ 公式转写自 PDF,细节以原文为准。
损失函数 / 训练策略¶
训练就是带 JS baseline 的 REINFORCE(Algorithm 1):每轮抽 \(B\) 个 context,每个 context 抽 \(K_c\) 条 schedule,执行采样器拿奖励 \(r^{(c,i)}\),算出 detach 掉的 \(b_{\text{JS}}^{(c,i)}\),再用 \(\frac{1}{BK_c}\sum (r^{(c,i)} - b_{\text{JS}}^{(c,i)})\nabla_\theta\log\pi_\theta(\tau^{(c,i)}\mid c)\) 平均更新策略。关键超参/设定:rollout 数统一用 2;小模型(SD-XL、SD-3.5M)batch size 32,大模型(SD-3.5L、Flux-Dev)batch size 16;步数预算 \(L\in\{5,10,20,40,80\}\)。刻意不加 KL 约束/正则、策略从零初始化,以保持实验干净可解释。
实验关键数据¶
主实验¶
在 HPD v2(约 100K 训练 prompt + 3200 held-out 测试 prompt)上以 HPSv2 为奖励,对四个 backbone × 五种步数预算对比"默认 schedule / 跨 context RLOO(XCTX)/ RLOO / Ours(JS)",Flux 额外加一个 TPDM 式 PPO 变体。JS 在所有 backbone、所有步数下都拿到最高对齐分,低预算(\(L\leq 20\))增益最大:
| Backbone | 方法 | L=5 | L=10 | L=20 | L=40 | L=80 |
|---|---|---|---|---|---|---|
| SD-XL | Default | 18.25 | 25.47 | 27.69 | 28.52 | 28.55 |
| SD-XL | Ours (JS) | 24.22 | 26.89 | 27.98 | 28.53 | 28.66 |
| SD3.5-L | Default | 24.24 | 28.04 | 29.85 | 30.43 | 30.61 |
| SD3.5-L | Ours (JS) | 26.28 | 28.88 | 29.98 | 30.41 | 30.64 |
| Flux-Dev | Default | 23.73 | 28.06 | 29.88 | 30.84 | 31.04 |
| Flux-Dev | RLOO | 26.48 | 30.41 | 30.77 | 30.92 | 31.10 |
| Flux-Dev | Ours (JS) | 29.21 | 30.86 | 31.12 | 31.23 | 31.36 |
5 步逼近蒸馏模型(HPSv2 on HPDv2,Flux-Dev):JS 在仅 5 步下几乎追平专门蒸馏出来的 Flux-Schnell,说明不蒸馏的 backbone 本身就藏着可观的少步数能力。
| Default | TPDM PPO | Cr. RLOO | RLOO | Ours (JS) | Flux-Schnell |
|---|---|---|---|---|---|
| 23.73 | 15.73 | 26.92 | 26.48 | 29.21 | 29.42 |
消融实验¶
主表里的 baseline 对比本身就是核心消融——同一套架构/优化设定,只换 reward baseline,性能差异全部来自 baseline 选择:
| 配置(Flux-Dev, L=5) | HPSv2 | 说明 |
|---|---|---|
| Default 固定 schedule | 23.73 | 不学调度的下界 |
| TPDM 式 PPO(自回归逐步) | 15.73 | 逐步预测在该设定下反而崩,且 \(O(L)\) 开销 |
| Cross-Context RLOO(全局池化) | 26.92 | 方差小但忽略 context 尺度差异 |
| RLOO(逐 context 留一) | 26.48 | rollout 仅 2 时估计很噪 |
| Ours (JS) 收缩 | 29.21 | 两者凸组合,方差/MSE 可证明更低 |
细粒度对齐上,文本渲染(Flux-Dev,固定步数)JS 把 OCR-Recall 从 49.77 提到 58.58;GenEval(40 步)整体分 SD3.5-M 0.62→0.68、Flux-Dev 0.64→0.70,其中 Counting 提升最猛(Flux 0.58→0.77)。
关键发现¶
- JS 的优势在低预算/高奖励异质场景最明显:步数少时采样方差主导、baseline 估计噪声大,收缩带来的方差缩减收益最大;步数大到离散化误差消失时 JS 仍一致领先,只是差距缩小。
- rollout 只有 2 时 RLOO 最脆弱,正是 JS 借跨 context 信息救场的地方;这与定理 3.2"任意有限 \(K_c\) 都严格优于 RLOO"吻合。
- 大预算下全局偏好分(HPSv2)差距会变窄,但细粒度能力差距仍大:即便 40 步够用,重排 schedule 仍显著改善文本渲染笔画完整度、减少字符丢失,并提升计数等对象级正确性——说明全局分掩盖了细粒度收益。
- 自回归的 TPDM-PPO 在本设定下反而劣于默认(15.73 < 23.73),佐证"单次前向 + 重排固定预算"这条路与"逐步早停求效率"目标不同、且更稳。
亮点与洞察¶
- 把"采样调度"立成一种独立的后训练杠杆:不碰权重、不蒸馏,只重排时间线就能在相同步数下提升对齐、在 5 步逼近蒸馏模型——模型无关、推理几乎零额外开销,和蒸馏正交可叠加。这个 framing 本身很有迁移价值。
- 把 schedule 当 simplex 上的单次 Dirichlet 动作很优雅:天然满足递减/划分约束、提供光滑可微分布、把逐步 RL 的 \(O(L)\) 开销摊成常数,还顺手用第 \(L{+}1\) 段学一个"停止边距"。
- James-Stein 收缩 baseline 是可即插即用的通用方差缩减原语:它把 RLOO 与跨 context baseline 统一为经验贝叶斯后验均值,可证明 MSE 严格优于 RLOO、保持无偏、几乎零额外计算。这套东西不止用于 schedule 学习,作者明确指向 RLHF 等长视野 one-shot 策略梯度都能直接换上。
- "预训练采样器本就有少步数潜力"这个观察很有启发:少步数能力未必非要靠蒸馏重训,合理重排既有预算就能解锁相当一部分。
局限与展望¶
- 作者承认:调度器架构很简单、奖励类型有限(只用了 HPSv2 与 GenEval 规则奖励),更丰富的策略和更广的目标留待将来。
- 近似依赖"策略接近确定性"假设:方差-最优 baseline 退化为 contextual mean 这步建立在 \(\nabla_\theta\log\pi_\theta\) 在 context 内变化很小之上,训练早期策略熵大时该近似的紧致程度存疑。⚠️ 以原文为准。
- 随机效应模型假设奖励高斯、context 间同方差,真实奖励模型(HPSv2/规则奖励)分布可能偏斜或重尾,\(\hat\delta^2\) 的矩估计在小 batch 下也可能不稳(故用 \(\max(0,\cdot)\) 兜底)。
- 评测高度依赖 HPSv2/OCR/GenEval 这些自动指标,缺少大规模人评;且"5 步逼近 Schnell"只在 HPSv2 单一奖励、Flux-Dev 单一 backbone 上给出,跨奖励/跨架构的普适性还需更多验证。
- 展望:自适应早停做动态步数预算、多目标/过程级奖励(组合性、美学、安全)、把 JS 收缩推广到 RLHF 及视频/flow/3D 生成管线。
相关工作与启发¶
- vs TPDM(自回归调度):TPDM 用内部 latent + 当前时间自回归地预测下一步,目标是早停/变长采样求效率,推理需 \(O(L)\) 次策略调用;本文是单次前向输出整条 schedule、目标是固定预算下质量-步数 Pareto 前沿,推理常数开销。同设定下本文显著更强(且 TPDM-PPO 在主表里甚至劣于默认)。
- vs 蒸馏类少步数方法(Progressive/Consistency Distillation、Flux-Schnell):它们重训/蒸馏 backbone 换少步数,成本高且改了模型;本文不重训、只重排既有预算,是正交且互补的路线,5 步即可逼近蒸馏结果。
- vs RLOO / 跨 context baseline:RLOO 是逐 context 留一均值(\(b^*\) 的 within-context MC 近似),跨 context 是全局池化;本文证明二者分别是 JS 收缩的两个极端,JS 的经验贝叶斯凸组合可证明 MSE 更低。
- vs DDPO / DPOK / Diffusion-DPO(RL 对齐扩散):这些把去噪 latent 序列当动作来微调生成器权重;本文把整条 schedule 当一个高维开环计划、且不动权重,并针对由此暴露的高方差问题给出原理化的方差缩减方案。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ "重排采样时间线作后训练杠杆 + James-Stein 收缩 baseline"两个点都新颖且互补,理论与应用兼具。
- 实验充分度: ⭐⭐⭐⭐ 覆盖 4 个 backbone × 5 种步数 + 文本渲染/GenEval 细粒度,baseline 对比扎实;缺人评、奖励类型偏少。
- 写作质量: ⭐⭐⭐⭐ 动机清晰、理论推导(命题/定理)与工程实现衔接好;公式密集,部分需结合附录。
- 价值: ⭐⭐⭐⭐⭐ 模型无关、即插即用、推理零额外开销,且 JS baseline 可外溢到 RLHF 等更广的策略梯度场景。