跳转至

RebRL: Reinforcing Discrete Visual Diffusion Models with Rebalanced Timestep Credits

会议: CVPR 2026
论文: CVF Open Access
代码: 无(论文未提及开源)
领域: 扩散模型 / 图像生成 / 强化学习
关键词: 离散扩散模型, 强化学习, GRPO, 信用分配, 文本到图像生成

一句话总结

针对离散扩散模型(DDM)用 GRPO 做强化学习时被忽视的「时间步信用分配严重失衡」问题,本文从策略梯度推导出失衡的数学根源,并提出即插即用的 RebRL——通过时间步级与 token 级两层重平衡因子拉平累积梯度,在 GenEval 上达到 SOTA,人类偏好分最高提升 3.40,同时训练步数减少约 40%。

研究背景与动机

领域现状:离散扩散模型(Discrete Diffusion Models, DDM)因为能在一次前向里并行预测多个离散 token、且天然兼容语言模型,正成为视觉生成的新热点。把 GRPO 这类 RL 方法接到 DDM 上做后训练(对齐人类偏好、提升组合生成能力)也开始流行,代表工作有 diffu-GRPO、UniGRPO、MaskGRPO。

现有痛点:DDM 没有自回归模型那种「逐 token 的对数似然分解」,重要性比率(importance ratio)只能在整条序列层面算,拿不到逐 token 的估计。为了让模型从不同推理时间步学习,现有做法是对完全去噪后的 token 序列施加不同的 mask 比例来模拟多步推理。本文实验发现,这套「模拟去掩码」流程会造成严重的信用分配失衡:越晚被解码(mask 比例越低)的 token,会在策略梯度求和式里出现越多次,累积到的梯度尺度越大。

核心矛盾:作者用 ∆HPS(沿推理轨迹的人类偏好分变化量)度量「探索潜力」,发现早期时间步 ∆HPS 大、探索潜力高、决定全局结构,本该获得更大梯度;但 GRPO 目标只在「当前仍被 mask 的 token」上评估,导致早期高价值 token 反而拿到最小的梯度。梯度走势与重要性走势恰好相反——这正是失衡的本质。

本文目标:在不引入额外存储/计算开销、不加超参的前提下,把跨时间步、跨 token 的累积梯度尺度拉平,从而改善探索-利用权衡、加速收敛。

切入角度:先对 DDM 的策略梯度做一次原理性推导,把「为什么会失衡」用闭式的累积尺度 \(w(\Delta t_j)\) 写出来,再据此设计重平衡因子——这样修正是「有解析依据」的,而非拍脑袋的启发式。

核心 idea:用「时间步级因子 \(\lambda(t_j)\) + token 级频率倒数 \(1/F\)」两层重平衡,替换 UniGRPO 那种对所有时间步一视同仁的均匀策略。

方法详解

整体框架

RebRL 不改 GRPO 的整体流程,而是在 GRPO 目标的梯度里插入两个重平衡因子。整条管线仍是标准的 GRPO 后训练循环:对每个 prompt \(c\) 采样 \(G\) 个 rollout、算组内相对优势 \(A_i\)、在 \(\mu\) 个内层迭代里用不同 mask 比例 \(t_j\) 模拟生成的各个时间步。RebRL 的全部改动集中在「算损失时怎么给每个时间步、每个 token 加权」这一步,因此它是即插即用、几乎零额外参数与算力、且无新增超参(重平衡因子都来自模型自身统计量)。

DDM 把生成看作连续时间的吸收态 Markov 过程:clean 数据记为 \(x_0\),随 \(t\to 1\) 逐步加噪。GRPO 适配到 DDM 时,重要性比率按整条序列、逐步(per-step)计算,奖励项写作对 \(\mu\) 个 reverse step 的求和(式 5)。本文的策略梯度近似形式(式 9)揭示出关键问题:由于指示函数 \(\delta(o^{t_j}_k, m)\) 的存在,一个 token 只要还被 mask 就持续贡献梯度,于是越晚解码的 token 在求和式里出现的项越多。对线性调度 \(t_j=j/\mu\),某 token 在 mask 比例 \(t_j\) 处被揭示时,其贡献被累积尺度加权:

\[w(\Delta t_j) = \sum_{j'=j}^{\mu} \frac{1}{t_{j'}} = \mu \sum_{j'=j}^{\mu} \frac{1}{j'}\]

\(w(\Delta t_j)\) 随揭示时刻越晚而单调增大,这就是失衡的解析来源。RebRL 据此分两层把它压平(这是纯粹的梯度重加权改进,机制靠公式说清即可,不需要 pipeline 框架图)。

关键设计

1. 策略梯度失衡的解析诊断:把「为什么晚 token 梯度大」写成闭式

这是后续两个设计的地基。作者先把 DDM 在 GRPO 下的策略梯度近似成可计算形式(式 9),用 DDM 自身的 ELBO 损失项 \(\ell_{\pi_\theta}\) 当作不可解的对数似然 \(\log\pi_\theta(o|o^{t_j},c)\) 的代理。由于每个 token 只在「仍被 mask」时进入求和,晚解码 token 比早解码 token 多出现在更多的 \(j\) 项里,累积尺度 \(w(\Delta t_j)=\mu\sum_{j'=j}^{\mu}1/j'\) 随解码时刻变晚单调递增、且步进增量随 reverse index \(j\) 减小而变大,使失衡进一步加剧。这个闭式诊断的价值在于:它把「现有 mask 策略缺乏理论依据」这件事变成了一个可被精确修正的目标——只要乘上一个能抵消 \(w(\Delta t_j)\) 增长的因子,就能把梯度尺度拉平。

2. 时间步级重平衡 \(\lambda(t_j)\):先把晚时间步的整体梯度压下去

针对「晚时间步累积梯度过大」这个一阶问题,作者在策略梯度里乘一个随 \(t_j\) 单调递增的时间步因子 \(\lambda(t_j)\)(式 11),给大 mask 比例(即生成早期)的 token 更高权重。给了两个简单选择:一阶因子 \(\lambda(t_j)=t_j\) 与二阶因子 \(\lambda(t_j)=t_j^2\),对应的重平衡后累积尺度为

\[w^{(1)}(\Delta t_j) = 1-\frac{j}{\mu}, \qquad w^{(2)}(\Delta t_j) = \frac{\sum_{j'=j}^{\mu} j'}{\sum_{j'=1}^{\mu} j'}\]

这样晚时间步的整体尺度被显著抑制,曲线明显比无重平衡更平。但它有个根本局限:\(\lambda(t_j)\)对一条序列里所有被 mask 的 token 一视同仁地施加的,晚解码的 token 仍会继续累积梯度——只是缓解、没有根治。

3. Token 级重平衡 \(1/F\):按掩码频率给每个 token 反比加权,彻底拉平

要根治就得打破「同一时间步内对所有 token 一刀切」的处理。作者注意到:一个 batch 内每个 token 参与梯度更新的次数本就不同,于是构造一张掩码频率图 \(F\),统计每个 token 在 rollout 阶段被 mask 的总次数 \(F(k)=\sum_{j=1}^{\mu}\delta(o^{t_j}_k, m)\),再把每个 token 的损失乘以预先算好的权重 \(1/F(k)\)(算法 1 第 17 行 \(J_{\text{rebalanced}}\leftarrow 1/F \odot J\))。被 mask 次数越多(即越晚解码、累积梯度越大)的 token,权重越小,恰好把不同 token 的累积梯度贡献等化。这个因子完全来自模型自身的掩码统计、不引入任何新超参,却把失衡从「时间步级缓解」推进到「token 级根治」,让每个 logit 都被充分利用、同时保留对整条序列的全局视角。

损失函数 / 训练策略

基座为多模态 DDM MMaDA-8B-Base,从公开预训练 checkpoint 初始化。学习率 \(3\times10^{-6}\),6×A100 训练(另 2 卡跑 reward server),全局 batch 72、每 prompt 采 \(G=9\) 个 rollout,推理 12 步、CFG scale 3.5,共训练 1800 个全局步、每步 \(\mu=8\) 个内层 GRPO 迭代。整体目标沿用 GRPO 的「奖励 − KL 惩罚」权衡(式 4),RebRL 仅在内层把 \(\lambda(t_j)\)\(1/F\) 依次乘进目标 \(J\)

实验关键数据

主实验

GenEval 组合生成基准(MMaDA 为基座,对比各 DDM 强化学习方法):

方法 Counting Position Attr. Binding Overall↑
MMaDA*(基座) 0.60 0.61 0.67 0.74
w/ diffu-GRPO 0.72 0.78 0.74 0.82
w/ UniGRPO 0.77 0.75 0.75 0.82
w/ MaskGRPO 0.69 0.77 0.72 0.84
w/ RebRL(本文) 0.81 0.83 0.76 0.86

RebRL 把基座 0.74 拉到 0.86,且增益集中在决定全局结构的属性上——Counting 0.81(次优 UniGRPO 0.77)、Position 0.83(次优 diffu-GRPO 0.78),印证「拉平梯度后策略能更有效地从高 mask 比例的早期样本学习」这一核心主张。

人类偏好对齐(HPS 系列打分,⚠️ HPSv3 为越大越好的偏好分,具体量纲以原文为准):

方法 HPSv3↑ ImageReward↑ DeQA↑
MMaDA 9.61 0.99 3.98
w/ diffu-GRPO 11.99 1.20 4.15
w/ UniGRPO 12.09 1.19 4.16
w/ MaskGRPO 12.35 1.21 4.19
w/ RebRL(本文) 13.01 1.28 4.22

RebRL 的 HPSv3 13.01 显著超过所有 DDM-RL 方法,相对基座提升约 3.40 分。

消融实验

四种重平衡策略对比(HPSv3 reward 与 KL 散度趋势,定性结论):

配置 reward 收敛 训练稳定性(KL) 说明
w/o Re-balance(UniGRPO) 最慢、终值最低 KL 高且快速上升,比其他高一个数量级 累积梯度失衡基线
1-Order(\(\lambda=t_j\) 明显加快 KL 低且稳定 时间步级缓解
2-Order(\(\lambda=t_j^2\) 明显加快 KL 低且稳定 时间步级缓解
Token-level(\(1/F\) 最快、终值最高 KL 低且稳定 token 级根治

关键发现

  • token 级重平衡贡献最大:它在 token 层面彻底解决失衡,reward 收敛最快、终值最高;时间步级只能缓解。
  • 训练稳定性差异显著:无重平衡基线的 KL 损失最终比三种重平衡策略高一个数量级,说明失衡梯度会让策略剧烈偏离参考策略、易 reward hacking。
  • diffu-GRPO 的全掩码策略是另一类权衡:它对整条 rollout 全部 mask,天然避开累积失衡,早期(<400 步)收敛极快甚至超过 RebRL;但「全掩码一步预测」与多步迭代推理严重不一致,导致训练-推理偏差,后期饱和、收敛到更低上界。RebRL 既解决失衡又忠于多步生成过程,最终稳定且更优。
  • 效率:相比 UniGRPO,RebRL 在达到同等性能时训练步数减少约 40%。

亮点与洞察

  • 把启发式 mask 策略变成可证明的修正目标:先用闭式累积尺度 \(w(\Delta t_j)\) 把失衡量化,再设计恰好抵消它的因子——这种「先诊断后开方」的路径比直接拍一个 reweighting 更有说服力,也解释了为何它无需调超参。
  • \(1/F\) 频率倒数加权非常优雅:重平衡权重完全来自 rollout 阶段的掩码统计,自适应、零额外参数、零额外超参,却把失衡从时间步级推进到 token 级,这个 trick 可迁移到其他基于 mask 的生成模型 RL。
  • 「早期时间步价值更高」用 ∆HPS 实证:用人类偏好分变化量度量探索潜力,直观展示「梯度走势与重要性走势相反」,为信用分配研究提供了可借鉴的诊断工具。

局限与展望

  • 作者承认:当前重平衡因子都是经验式\(t_j\)\(t_j^2\)\(1/F\)),如何升级为可学习的范式仍待探索。
  • 自己发现:实验只在 MMaDA-8B 单一基座 + GenEval/HPS 上验证,跨基座、跨任务(如视频离散扩散)的普适性未充分检验;⚠️ 二阶因子 \(w^{(2)}\) 与一阶 \(w^{(1)}\) 的优劣边界没有给出明确选择准则。
  • 改进思路:把 \(\lambda(t_j)\)\(1/F\) 参数化为可学习的网络输出,或结合 ∆HPS 信号在线自适应调整权重,可能进一步贴合真实探索潜力。

相关工作与启发

  • vs UniGRPO / MaskGRPO: 它们都用变化 mask 比例模拟生成,但对所有时间步均匀加权,造成累积梯度失衡;RebRL 在同一框架内插入两层重平衡因子,几乎零成本地把梯度拉平,是它们的即插即用补丁。
  • vs diffu-GRPO: 它用全掩码单步目标规避失衡,但带来训练-推理偏差、上界更低;RebRL 保持多步生成的忠实性同时解决失衡,后期性能更优。
  • vs 连续 flow-matching 的信用分配工作: 连续模型里也观察到「均匀信用分配无法刻画各生成步重要性差异」;RebRL 把这一洞察专门化到离散、mask-based 的 DDM 上,并给出 token 级解法。

评分

  • 新颖性: ⭐⭐⭐⭐ 把 DDM-RL 的信用失衡问题首次解析化并给出零超参解法,角度新颖但属于 GRPO 框架内的改进
  • 实验充分度: ⭐⭐⭐ GenEval + HPS 充分,但仅单一基座、缺跨任务验证
  • 写作质量: ⭐⭐⭐⭐ 从诊断到方法逻辑清晰,公式与可视化(梯度尺度、∆HPS)支撑到位
  • 价值: ⭐⭐⭐⭐ 即插即用、训练提速约 40%,对 mask-based 生成模型的 RL 后训练有较强实用价值