Learning to Parallel: Accelerating Diffusion Large Language Models via Learnable Parallel Decoding¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=bFJ8Sdr224
代码: 已开源(项目页 + GitHub,论文 abstract 中给出)
领域: LLM 推理加速 / 扩散语言模型
关键词: diffusion LLM, 并行解码, 可学习 filter, 推理加速, LLaDA, KV-Cache
一句话总结¶
针对扩散语言模型(dLLM)并行解码依赖固定启发式(如置信度阈值)、对不同输入不自适应的痛点,本文用一个极轻量(2 层 MLP、约 2 千参数、6 分钟训练)的可学习 filter 去逼近"一旦预测正确就立即定稿"的 oracle 策略,再配合 End-of-Text 早停,在 LLaDA-8B 上实现最高 22.58× 加速且几乎不掉点,叠加 KV-Cache 后达 57.51×。
研究背景与动机¶
- 领域现状:自回归 LLM 逐 token 解码需要 \(O(n)\) 个串行步,吞吐受限。扩散语言模型(LLaDA、Dream、DiffuLLaMA 等)通过迭代去噪一次性刷新整段序列,理论上可并行出多个 token,并普遍采用"分块从左到右"的半自回归解码来兼顾质量与并行度。
- 现有痛点:要真正吃到并行红利,关键在于"每一步该定稿哪些 token、哪些该重新 mask"。现有方法(confidence 阈值、Fast-dLLM、SlowFast-Sampling、Prophet 等)都是静态、与输入无关的启发式——一刀切的规则在不同任务/不同样本上无法自适应,导致速度-质量权衡次优。
- 核心矛盾:作者实测发现 dLLM 存在严重的冗余重复解码——大量 token 早就被预测正确,却因保守的 remask 策略被反复 mask、反复重算。在 GSM8K 上多数 token 在首次预测正确后还要被解码 10 次以上;按 EGP oracle 统计每个 block 中位数只需 2 步,而 vanilla 要 32 步。此外当生成长度拉到 1024 时,约 89.59% 的算力浪费在
[EoT]之后的 padding token 反复解码上。 - 本文目标:把"一刀切静态规则"换成"逐样本、逐 token 自适应"的并行解码策略,在不动主模型参数、不掉精度的前提下榨干并行潜力。
- 核心 idea:[Oracle 逼近] 先定义一个用 ground-truth 的理想策略 EGP(预测对就立刻定稿),证明它有 15-20× 加速潜力;再训练一个只看置信度信号的轻量 filter 在推理时逼近这个 oracle;[早停] 额外用 EoTP 在
[EoT]出现时砍掉后续 padding。
方法详解¶
整体框架¶
Learn2PD 把"该不该定稿某个 token"建模为一个二分类问题:主扩散模型每步给出各位置的预测 token 和置信度 \(c_i\),把这些置信度喂给一个冻结的轻量 filter \(f_\theta\),由它输出每个位置"无需 remask"的 logit;超过阈值 \(\tau\) 的位置立即定稿、退出 mask 集合。主模型参数全程冻结,只在后训练阶段花极少算力训练 \(f_\theta\)。再叠加 EoTP 早停模块处理长生成场景。
flowchart LR
A[掩码序列 X] --> B[扩散主模型 M<br/>冻结]
B -->|预测 token + 置信度 conf| C[Filter 模型 fθ<br/>2层 MLP]
C -->|logit > τ?| D{定稿判断}
D -->|是| E[定稿该 token<br/>移出 mask 集]
D -->|否| F[保持 [MASK]<br/>下一步重算]
E --> G[EoTP: 检测到 EoT<br/>丢弃后续位置]
G --> B
F --> B
关键设计¶
1. EGP oracle:把"冗余解码"量化成可逼近的上界。 作者先建一个理想策略 Extremely Greedy Parallel:在第 \(k\) 步,当且仅当模型预测 \(M(x_k)_i = y_i\)(\(y_i\) 为参考答案)时才解掩该位置,永不 remask 已正确的 token。它需要 ground truth 因而推理时不可用,但意义在于给出了"并行能快到什么程度"的天花板——实测 15-20× 加速且不掉质量,每 block 中位数仅 2 步 vs vanilla 的 32 步。这把一个模糊的"加速空间"变成了一个明确的、可被监督学习模仿的目标。
2. 可学习 filter \(f_\theta\):用置信度模式逼近 oracle。 关键观察是扩散模型的置信度有可预测的波动规律——置信度本身就携带"模型对该预测是否真正接受"的信息,足以判断某 token 是否已收敛。于是把 EGP 的定稿决策蒸馏为二分类,用 BCE 损失训练 filter: $\(\mathcal{L}_{\text{BCE}} = -\frac{1}{m}\sum_{i=1}^{m}\Big[y_i\log\sigma(z_i) + (1-y_i)\log(1-\sigma(z_i))\Big]\)$ 其中 \(y_i\in\{0,1\}\) 是 EGP 给出的标签(1=可定稿,0=需 remask),\(z_i=f_\theta(\text{conf})\) 是 filter logit,经 \(\sigma\) 后与阈值 \(\tau\) 比较离散化。训练分两阶段:先按 EGP 策略跑一遍收集每步的置信度与定稿标签(4×A6000 约 3 小时),再用这批数据在一张 T4 上训 filter(仅 6 分钟)。令人意外的是,最简单的两层 MLP(block size 32 时仅 2,112 个可训练参数)就已足够——block 级置信度模式信息量充分,无需复杂结构或特征工程;推理时 filter 冻结、无梯度更新,额外开销可忽略。
3. End-of-Text Prediction(EoTP):砍掉 padding 的算力黑洞。 当生成长度设为 1024 而真实答案远短时,多出来的位置全被 [EoT] 填充,且模型会反复解码这些 padding——实测占总算力的 89.59%。EoTP 的做法很直接:每一步一旦某个 block 内解出 [EoT] token,就丢弃其后所有位置,用缩短后的序列作为下一步输入,从而在去噪过程中动态压缩有效长度。它与 Learn2PD 正交,主要在长生成场景带来额外大幅加速(22.58× 中相当一部分来自此模块)。
实验关键数据¶
主实验(LLaDA-8B-Instruct,TPS=tokens/sec,Score=任务精度)¶
| 任务 | 方法 | Gen Len | TPS | 加速 | Score |
|---|---|---|---|---|---|
| GSM8K (5-shot) | LLaDA baseline | 1024 | 0.54 | 1.00× | 77.60 |
| + Learn2PD | 1024 | 6.63 | 12.21× | 77.26 | |
| + Learn2PD + EoTP | 1024 | 12.26 | 22.58× | 79.83 | |
| Math (4-shot) | + Learn2PD + EoTP | 1024 | 12.27 | 7.22× | 34.60 |
| HumanEval (0-shot) | + Learn2PD + EoTP | 1024 | 6.63 | 12.55× | 35.98 |
| MBPP (3-shot) | + Learn2PD + EoTP | 1024 | 9.89 | 17.16× | 11.02 |
长度 256 时通常 3-5× 加速,长度 1024 时 6-22× 加速,精度普遍落在基线 ±1-2 点内,部分任务(GSM8K)甚至略升。
KV-Cache 兼容 + 消融¶
| 配置 | TPS | 加速 | Score |
|---|---|---|---|
| Learn2PD & EoTP | 12.26 | 22.58× | 79.83 |
| + Dual Cache | 31.23 | 57.51× | 74.00 |
| + Prefix Cache | 14.79 | 27.23× | 77.71 |
| Filter 深度 | TPS | 加速 | Score |
|---|---|---|---|
| 单层 | 8.77 | 2.57× | 78.62 |
| 两层 | 14.07 | 4.13× | 78.62 |
| 四层 | 11.41 | 3.35× | 78.85 |
关键发现¶
- 方法与 KV-Cache 完全正交,叠加 Dual Cache 把加速推到 57.51×(精度略降到 74.00),Prefix Cache 则在几乎不掉点下到 27.23×。
- 两层 MLP 是甜点:单层表征不足、四层精度微升但速度反降,2 层在效率/质量间最优。
- 生成长度越长收益越大:128→1024 加速从 3.36× 升到 22.58×,因为长序列里 EoTP 能砍掉的 padding 冗余更多。
亮点与洞察¶
- 把"加速空间"先证明再逼近:EGP oracle 这一步很漂亮——先用 ground truth 量化出 15-20× 的天花板,证明冗余确实存在且巨大,再用可学习 filter 去蒸馏,逻辑闭环、说服力强。
- 极致轻量的后训练:只训 2 千参数、6 分钟 T4 时间、主模型完全冻结,几乎零成本即插即用,工程友好度极高。
- 正交叠加:Learn2PD(解决 block 内冗余 remask)+ EoTP(解决长序列 padding)+ KV-Cache(解决步间重复算)三条独立路径相乘,复合加速到两位数甚至 57×。
局限与展望¶
- 只在 LLaDA-8B 单一主模型上验证:对 Dream、DiffuLLaMA 等其它 dLLM 的泛化性未充分展示。
- filter 监督依赖"参考答案":标签来自 LLaDA 自身标准解码产生的 reference answer,本质是蒸馏自身行为,filter 上限受主模型质量约束;若主模型本身错了,filter 学到的是"自信地定稿一个错 token"。
- 激进定稿的质量风险:叠加 Dual Cache 后精度从 79.83 掉到 74.00,说明在追求极端速度时质量并非完全无损,τ 与缓存策略的组合需谨慎调。
- 训练数据来自 FLAN 66 类共 2640 样本:filter 的跨域鲁棒性(尤其面对训练分布外任务)还需更多检验。
相关工作与启发¶
- 对比静态加速方法:Fast-dLLM(置信度阈值 + 近似 KV-Cache)、SlowFast-Sampling(两阶段采样器)、Prophet(top-2 logit gap 早提交)、dllm-Cache/FreeCache(训练free缓存)——它们都是固定规则,本文最大区别是把"定稿决策"做成可学习、输入自适应的。
- 启发:将一个不可用的 oracle(依赖 ground truth)显式定义出来、量化其上界、再用最廉价的监督模型逼近,是一种通用且高性价比的加速范式,可迁移到投机解码、early-exit 等"该不该停/该不该信"的决策场景。置信度信号作为唯一输入特征竟足够,提示扩散模型的内部不确定性结构有很强的可读性,值得进一步挖掘。
评分¶
- 新颖性: ⭐⭐⭐⭐ — "EGP oracle 量化 + 轻量 filter 逼近"的组合在 dLLM 并行解码里是首个可学习策略,思路清晰且立得住。
- 实验充分度: ⭐⭐⭐ — 四个 benchmark + KV-Cache 兼容 + 深度/长度消融较完整,但只用了 LLaDA 单一主模型,跨 dLLM 泛化欠缺。
- 写作质量: ⭐⭐⭐⭐ — 从"发现冗余→定义 oracle→证明潜力→蒸馏逼近"的叙事链条非常顺,图表支撑充分。
- 价值: ⭐⭐⭐⭐ — 几乎零成本、即插即用、与现有加速正交,对 dLLM 实际部署有直接落地价值。