跳转至

dParallel: Learnable Parallel Decoding for dLLMs

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=hVOcstAURb
代码: https://github.com/czg1225/dParallel
领域: LLM 推理加速 / 扩散语言模型 / 并行解码
关键词: diffusion LLM, parallel decoding, certainty-forcing distillation, self-distillation, LLaDA, Dream

一句话总结

通过"确定性强制蒸馏"把扩散语言模型(dLLM)原本"逐字串行收敛"的预测确定性改造成"并行同时收敛",让 LLaDA-8B 在 GSM8K 上把解码步数从 256 砍到 30(8.5× 加速)而精度不降。

研究背景与动机

领域现状:扩散语言模型(dLLM,如 LLaDA、Dream)用双向注意力替代自回归的左到右生成,理论上每步能并行预测所有被 mask 的 token,被寄望于带来远低于 AR-LLM 的推理延迟。然而现有开源 dLLM 要维持生成质量,解码步数几乎仍要正比于序列长度(256 长度就要 256 步),并行潜力名存实亡。

现有痛点:已有加速工作分两路——一路(Dual-Cache 等)给 dLLM 加 KV 缓存降低单步耗时,一路(各种采样策略)优化 remasking 减少步数。但只要把并行度推高(一步定多个 token),性能就崩。原因没人讲清楚,大家只在"轨迹对齐"层面打转(teacher forcing / diffusion forcing),并没碰到问题本质。

核心矛盾:dLLM 虽然每步并行预测全部 mask token,但这些预测的"确定性(certainty/confidence)"是逐字左到右串行收敛的——每步只有紧挨已知上下文的少数 token 能达到高置信,其余一直低置信,直到新的确定 token 提供新上下文才轮到下一批。而作者实证发现高确定性是正确生成的必要条件(置信度与准确率强正相关),所以一旦强行提前 commit 低置信 token,就级联出错。这就是并行解码真正的瓶颈:串行的确定性收敛

本文目标:训练 dLLM,让它能在多个位置并行地达到高确定性,从而打破串行瓶颈、把解码步数大幅压缩而不掉点。

核心 idea[确定性即训练信号] 提出 certainty-forcing distillation——让模型沿自己原始的采样轨迹自蒸馏(保证轨迹一致不跑偏),同时直接把"对正确预测 token 的输出熵"压低,逼模型更快、更并行地把确定性顶到峰值。

方法详解

整体框架

方法是一个自蒸馏(self-distillation) 训练流程:用预训练 dLLM 当 teacher 生成目标轨迹,再让一份完全相同的 student 复制体沿这条轨迹学习,但训练目标多了一条"逼高确定性"的项。全程不改变原始生成轨迹,只重塑确定性的收敛节奏(从串行变并行),训练用 LoRA,8 张 24GB A5000 上 10 小时即可完成。

flowchart TD
    A[预训练 dLLM = Teacher] -->|半自回归采样| B[目标轨迹 Y, 分成 N 个 block]
    B --> C[半自回归前向 mask: 上下文块/活跃块/未来块]
    C --> D[Student 前向预测]
    D --> E["LConsistency: 活跃块 mask token 的 CE 损失(对齐轨迹)"]
    D --> F["LCertainty: 仅对预测正确的 token 压低输出熵(逼并行高确定)"]
    E --> G["LCFD = LConsistency + β·LCertainty"]
    F --> G
    G -->|LoRA 更新| D

关键设计

1. 半自回归前向 mask:让训练状态对齐真实采样过程。 Teacher 先用半自回归 remasking(总长 \(L\)、块大小 \(L_b\))生成目标响应 \(Y=(y_1,\dots,y_L)\) 并切成 \(N=L/L_b\) 个连续块。训练时不像标准 dLLM 预训练那样在全序列随机 mask,而是先采一个块索引 \(n\),把序列切成三段构造噪声输入 \(\tilde Y\):第 \(n\) 块之前是上下文块(全保留,当条件)、第 \(n{+}1\) 块是活跃块(以概率 \(q\) 把 token 替换成 [MASK])、其后是未来块(全 mask,因为还没生成)。这样 \(\tilde Y\) 恰好模拟了"已知前 \(n\) 块、正在生成第 \(n{+}1\) 块"的半自回归中间态,使自蒸馏的输入分布与推理时真正会遇到的状态一致——消融显示这一步对最终的高效率+高精度不可或缺。

2. 一致性损失:沿原始轨迹自蒸馏,不让模型跑偏。 学习信号只施加在活跃块 \(B_{n+1}\) 的 mask 位置集合 \(M_a\) 上,用标准交叉熵把 student 拉向 teacher 轨迹的正确 token: $\(\mathcal{L}_{\text{Consistency}} = -\frac{1}{|M_a|}\sum_{i\in M_a}\log p_\theta(y_i\mid \tilde Y).\)$ 它保证 student 的生成轨迹与原模型一致,但单靠它解决不了串行收敛:CE 只关心"对不对",一旦预测对了梯度迅速消失,完全没有动力去把置信度进一步顶高——这正是 consistency distillation 提速有限的原因。

3. 确定性强制损失:直接把"已对"token 的熵压到更低。 这是全文核心。先取出活跃块里 student 已经预测正确的 token 集合 \(M_c=\{i\in M_a\mid \arg\max_v p_\theta(v\mid\tilde Y)=y_i\}\),只对这部分 token 最小化带温度 \(T\) 的输出分布熵: $\(\mathcal{L}_{\text{Certainty}} = \frac{1}{|M_c|}\sum_{i\in M_c}\Big(-\sum_{v\in V} p_\theta(v\mid\tilde Y;T)\log p_\theta(v\mid\tilde Y;T)\Big).\)$ 之所以只对已预测正确的 token 压熵,是为了让 CE 管"方向正确"、熵项管"信心更足",二者分工:前者保证收敛到 teacher 轨迹,后者在已对的位置上把置信度并行地推到峰值,使更多 token 能在同一步跨过 commit 阈值。温度 \(T\)(实验取 0.5)控制强制力度。总目标为 $\(\mathcal{L}_{\text{CFD}} = \mathcal{L}_{\text{Consistency}} + \beta\,\mathcal{L}_{\text{Certainty}}.\)$ \(\beta\) 平衡"贴 teacher 轨迹"与"逼高确定性"。这套组合把原本左到右串行爬升的置信曲线,改造成多位置并行同时冲顶——推理时配合 entropy-threshold 的半自回归 remasking,一步就能确定多个 token。

实验关键数据

主实验(LLaDA-8B-Instruct,序列长 256/块长 32)

Benchmark Method #Steps ↓ Latency ↓ Speedup ↑ Acc ↑
GSM8K-CoT LLaDA-8B(原) 256 18.6s 1.0× 75.7%
Conf-threshold 72 5.2s 3.6× 75.5%
Consistency Distill 64 4.7s 4.0× 69.9%
dParallel(本文) 30 2.2s 8.5× 76.1%
MATH(4-shot) LLaDA-8B 256 50.9s 1.0× 33.5%
dParallel 46 8.9s 5.7× 31.5%
HumanEval LLaDA-8B 256 23.5s 1.0× 38.4%
dParallel 33 2.9s 8.2× 40.2%
MBPP(3-shot) LLaDA-8B 256 50.1s 1.0× 42.4%
dParallel 24 4.8s 10.5× 40.8%

Dream-7B-Instruct 上同样有效:GSM8K 39 步 6.9× 加速(82.1% vs 原 82.9%),MBPP 29 步 8.8× 加速。注意 Dream 因初始化自 AR-LLM,半自回归 mask 会让它退化回 AR,故改用全序列随机 mask,加速幅度略低于 LLaDA。

消融实验(LLaDA,GSM8K-CoT / HumanEval)

Consistency Certainty Semi-AR Mask GSM8K Steps/Speed/Acc HumanEval Steps/Speed/Acc
53 / 4.5× / 73.5% 71 / 3.6× / 36.0%
23 / 10.4× / 57.8% 28 / 9.8× / 30.5%
44 / 5.5× / 73.3% 61 / 4.3× / 32.9%
30 / 8.5× / 76.1% 33 / 8.2× / 40.2%

mask 比例消融:50% 最优(76.3% Acc);只去 certainty loss 退回基线速度,只留 certainty loss 速度飙到 10.4× 但精度崩到 57.8%。

关键发现

  • 置信度=正确性的必要条件:置信度与生成准确率强正相关,低置信强行 commit 必出错。
  • 三个组件缺一不可:一致性损失管"不跑偏"、确定性损失管"并行冲顶"、半自回归 mask 管"训练-推理状态对齐",任意去一个都掉点或掉速。
  • 跨任务泛化:LLaDA 只用数学 prompt 训练,代码任务上的并行解码能力也显著提升。
  • trade-off 曲线碾压:同 9.4× 加速下,本文在 LLaDA-GSM8K 比 conf-threshold 高 16.5% 精度,HumanEval 高 21.3%。

亮点与洞察

  • 抓住了真问题:把"dLLM 并行解码为何不行"从含糊的"轨迹没对齐"精确定位到"确定性串行收敛",并用实证(置信度传播图、收敛轨迹图)把它讲透,这个洞察本身比方法更有价值。
  • 熵正则的巧用:不是泛泛地最小化全序列熵,而是只对已预测正确的 token 压熵——既不破坏正确性又精准提升信心,这个限定是方法 work 的关键。
  • 极低成本:纯自蒸馏,不需任何外部标注数据(target 全是模型自己生成的);LoRA + 8×A5000 训 10 小时,工程上非常可复现。

局限与展望

  • AR-初始化模型受限:Dream 用半自回归 mask 会退化回 AR,只能退而求其次用随机 mask,加速幅度明显低于原生 dLLM LLaDA,说明方法对"模型是否原生扩散"敏感。
  • 数学/代码为主:训练与评测集中在 GSM8K/MATH/HumanEval/MBPP 这类有明确答案、可过滤错误轨迹的任务,开放式生成、长文本上的并行收益尚未验证。
  • 超参敏感:温度 \(T\)、权重 \(\beta\)、mask 比例都需调(50% mask 才最优),缺乏自适应机制。
  • 理论留白:为何"压已对 token 的熵"能从串行收敛跳到并行收敛,目前是实证现象,缺乏更深的理论刻画。

相关工作与启发

  • dLLM 加速两条主线:降单步耗时(Dual-Cache/KV 缓存/token dropping)vs 减步数(改 remasking 采样策略)。本文属第三条路——直接训练改变模型的确定性动力学,与前两者正交,可叠加。
  • 蒸馏类方法:SDTT 用渐进蒸馏减步、Consistency Distillation 学预测剩余 mask token,但都只对齐轨迹,本文指出这对解决串行收敛"不够"。
  • 启发:把"模型内部某种统计量(这里是预测确定性)的演化节奏"作为直接优化对象,而非只盯着输出正确性,可能是加速一类迭代式生成模型(扩散、流匹配)的通用思路;"只对已正确样本做正则"也是一个值得迁移的 trick。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 把并行解码瓶颈精确归因到"确定性串行收敛",并用确定性作直接训练信号,视角新且切中要害。
  • 实验充分度: ⭐⭐⭐⭐ 两个代表性 dLLM × 四个 benchmark × 多基线 + 完整消融(策略/mask 比例)+ trade-off 曲线,扎实;但任务类型偏窄(数学/代码)。
  • 写作质量: ⭐⭐⭐⭐ 问题定位清晰、实证图表有力(置信度传播/收敛对比),方法叙述简洁。
  • 价值: ⭐⭐⭐⭐⭐ 8.5–10.5× 无损加速 + 极低训练成本,直接把 dLLM 推理效率拉到可用区间,为 few-step / 并行 dLLM 立了新 baseline。