Esoteric Language Models: A Family of Any-Order Diffusion LLMs¶
会议: ICML 2026
arXiv: 2506.01928
代码: https://s-sahoo.com/Eso-LMs (有)
领域: LLM 预训练 / 离散扩散语言模型
关键词: Masked Diffusion LM, Any-Order AR, KV Cache, 因果注意力, 混合训练
一句话总结¶
Eso-LMs 把 AR 与 Masked Diffusion 在 loss、注意力、采样三个层面深度融合:用一个 causal-on-shuffled-sequence 的去噪 Transformer 同时支持并行扩散和左到右 AR,从而首次让 MDM 在扩散阶段也能用上精确 KV cache,在 OWT 长上下文上比 MDLM 快 14–65×、比 BD3-LM 快 3–4×,并在 speed–quality Pareto 前沿上取得 SOTA。
研究背景与动机¶
领域现状:语言模型正从纯 AR 走向"AR + 离散扩散"两条腿。AR 模型质量最好但只能逐 token 解码;以 MDLM 为代表的 Masked Diffusion LM (MDM) 支持并行生成、可控生成,规模化到 8B 后在 math/code/science 上已逼近 LLaMA。
现有痛点:MDM 落地有两大致命短板。其一,推理慢——虽然名义上"并行解码",但去噪 Transformer 使用 双向注意力,每一步都要 over full sequence 重新算一次 Q/K/V,无法 KV cache,长序列下比 AR 还慢。其二,likelihood 没办法精确算——NELBO 只是上界,做 GRPO 这类 RL 微调时连一个可用的 policy log-prob 都拿不到。BD3-LM 把序列切成 block,block 之间 AR、block 内 MDM,只能 cache block 之间,block 内仍要 full forward,而且 block 一小(≤16)就出现严重的"并行解码冲突",低 NFE 下样本质量塌掉。
核心矛盾:AR 的"因果注意力"是 KV cache 的前提,MDM 的"双向注意力"是并行去噪的前提——两者架构上互斥;任何想兼得二者优势的方案都要回答"用什么样的注意力既支持随机顺序去噪、又支持 KV 复用"。
本文目标:(1) 设计一个共享的去噪 Transformer,同时承载并行扩散和左到右 AR 两种生成模式;(2) 在扩散阶段也支持精确 KV cache(不是近似);(3) 给出 MDM 第一个可计算的精确 likelihood 公式,让 RL 类目标可用。
切入角度:作者抓住 Ou et al. (2025) 揭示的等价关系——MDM 的 NELBO 等价于在所有排列 \(\sigma\) 上平均的"Any-Order AR" loss \(L_\text{AO} = -\mathbb{E}_\sigma \sum_\ell \log p_\theta(x^{\sigma(\ell)} \mid x^{\sigma(<\ell)})\)。既然 MDM 本质就是 AO-AR,那就直接当 AR 来训:把 zt 里的 clean token 洗牌到前面、masked token 排后面,再用普通的因果注意力,就既是 MDM 又是 AR。
核心 idea:用"clean-tokens-first + causal attention on shuffled sequence"的去噪 Transformer 同时实现并行扩散和 AR;再加一项 AR loss + 一个特殊的稀疏注意力掩码,让 AR 阶段能 reuse 扩散阶段建好的随机顺序 KV,从而构成一个"先 MDM 并行铺一层、再 AR 填空"的两阶段采样器。
方法详解¶
整体框架¶
Eso-LMs 把生成过程分解为 \(p_\theta(x) = \sum_{z_0} p^\text{AR}_\theta(x \mid z_0)\, p^\text{MDM}_\theta(z_0)\) 两段:MDM 组件先并行去噪出一个部分 masked 的中间序列 \(z_0\)(平均 \(\alpha_0\) 比例的位置是 clean);AR 组件再把 \(z_0\) 中残留的 mask 从左到右补齐。\(\alpha_0\) 是一个连续超参——\(\alpha_0=1\) 退化成纯 MDM,\(\alpha_0=0\) 退化成纯 AR,中间值给出 perplexity 上 AR 和 MDM 的平滑插值。整个流程只用一个共享去噪 Transformer \(x_\theta\),靠不同的注意力掩码区分阶段。Variational bound 写出来正好是一项 AR 交叉熵 + 一项 MDM NELBO,训练时按比例 \(\kappa\)(默认 0.5)把 batch 一半喂 AR loss、一半喂 MDM loss。
关键设计¶
-
扩散阶段的 "clean-tokens-first + 因果注意力" 去噪 Transformer:
- 功能:把传统 MDM 的双向去噪 Transformer 改造成因果版本,但保持"任意位置随机顺序去噪"的能力,从而解锁扩散阶段的精确 KV cache。
- 核心思路:给 \(z_t \sim q_t(\cdot \mid x)\),作者把 clean token 连同其原始 positional embedding 一起洗牌到序列前面、mask token 排在后面,然后用标准左到右因果注意力训练去噪。这样一来:(i) 因为 clean token 之间是因果可见的,与采样时"前面步骤已经解出的 clean token"恰好对应——KV cache 可以一直被复用;(ii) 因为 mask token 只看到左侧的 clean token,不会看到未来还要去噪的 mask,符合采样时的因果性约束。每一步采样时 forward pass 只过"已 clean 的 token + 当前要去噪的 mask",而不是 full sequence——长序列下省的不是常数因子。
- 设计动机:MDM 不能 KV cache 的根本原因是双向注意力让"已经预测出的 token"依赖"未来要解码的 token",把这条边切掉就行。作者发现 Any-Order AR 视角下随机顺序的 MDM 其实只是 AR 的一种排列,因此只需把序列"按生成顺序"重排成因果序列即可,无需放弃并行——一次 forward 仍然能同时去噪一批 mask。这是把 MDM 的推理瓶颈从 \(O(L \cdot L)\) 降到 \(O(L)\) 的关键。
-
顺序阶段的 \(z_0 \oplus x\) 拼接 + 稀疏注意力掩码:
- 功能:让 AR 阶段(填补 \(z_0\) 残留的 mask)能够复用扩散阶段建好的、按随机顺序排列的 KV cache,而不是从头跑一遍。
- 核心思路:训练时把 clean+mask 的 \(z_0\) 与完整 \(x\) 拼成 \(z_0 \oplus x\) 长度 \(2L\) 的序列喂进同一个 Transformer,并设计一个 \(2L \times 2L\) 的结构化稀疏注意力 bias \(A\)(依赖一个排列 \(\sigma\)):(i) clean token 在 \(\sigma\) 下排在 mask 之前;(ii) mask token 保持自然顺序;(iii) 每个待 AR 预测的 mask 位置 \(i\) 只能 attend 到其左侧的真实 token \(x_{<i}\)。Transformer 在 \(x\) 侧的输出被丢掉,只用 \(z_0\) 侧 mask 位置的 logits 算 AR 交叉熵。由于 clean token 在扩散阶段是按 \(\sigma\) 顺序生成并 cache 的,AR 采样时直接复用这份 KV、再因果地一个个解 mask 即可。完整实现用 FlexAttention 写出来不到一屏代码(Fig. 9)。
- 设计动机:纯 AR 训练要求每个被预测的 token 都有"完整 clean 左 context",但 \(z_0\) 里夹杂的 mask 没有这种 context;常规做法只能放弃 cache 重新 forward。作者用拼接 + 稀疏 bias 这种"伪左 context"骗过去,等价地让 AR 学会"基于一个非自然顺序的 KV 序列"做条件预测——这是把 cache 在两阶段间无缝衔接的工程关键。代价是序列长度翻倍,但只有一半 batch 走 AR 训练,整体训练只比 MDLM 慢约 1.37×。
-
MDM 的首个精确 likelihood 估计 + 单次前向 NELBO:
- 功能:给出 MDM (以 Eso-LMs \(\alpha_0=1\) 为代理) 的首个 (渐近) 精确 likelihood 公式,并把 NELBO 的 Monte Carlo 估计从 \(L\) 次 forward 降到 1 次,让 GRPO 等 RL 算法终于可以在 MDM 上跑起来。
- 核心思路:基于 \(L_\text{AO}\) 等价性,作者证明了一个 importance-weighted 上界(Theorem 3.1):\(L^K_\text{AO} = -\mathbb{E}_{\sigma_{1:K}}\left[\log \tfrac{1}{K} + \log\sum_{k=1}^K \exp\sum_\ell \log p_\theta(x^{\sigma_k(\ell)} \mid x^{\sigma_k(<\ell)})\right]\),并证明 \(-\log p_\theta(x) \le L^K_\text{AO} \le L_\text{MDM}\)、\(L^K_\text{AO}\) 关于 \(K\) 单调递减、\(K\to\infty\) 时收敛到真 likelihood。更妙的是:一次排列 \(\sigma\) 就刻画了完整的扩散轨迹上 \(L\) 个 latent,所以 \(L_\text{AO}\) 在 Eso-LMs 上只需要一次 forward就能算完(MDLM 因为双向注意力做不到)。表 2 实测:MDLM 用 10 个 MC 样本算 \(L_\text{MDM}\) 标准差 0.56,Eso-LMs 用 1 个 \(\sigma\) 算 \(L_\text{AO}\) 标准差只有 0.03。
- 设计动机:MDM 想做 RL 微调(如 GRPO)必须能算 policy 的 \(\log p\)。MDM 原本的 NELBO 估计每个数据点要 \(L\) 次 forward,长序列下根本不现实;exact likelihood 更是缺失。Eso-LMs 的因果架构刚好让两件事同时成立。后续工作 Wang et al. (2025b) 已把这个估计器用作 GRPO 的 likelihood,并在 0.1B 和 8B 规模上都打过 Black et al. (2024) 与 Zhao et al. (2025)。
损失函数 / 训练策略¶
总目标是 (7) 式的变分上界:\(-\log p_\theta(x) \le \mathbb{E}_{z_0}[\text{AR loss}] + \mathbb{E}_{q_t,t}[\text{MDM loss}]\)。给定 batch,按 \(\kappa\) 拆分:\(\kappa=0.5\) 走扩散 loss、\(1-\kappa\) 走 AR loss(\(\alpha_0=1\) 时 \(\kappa=1\))。AR loss 里用替换算子 \(\odot\) 把 \(z_0\) 的前 \(\ell-1\) 位换成真实 \(x_{<\ell}\),保证被预测的 mask 有干净左 context。噪声调度采用线性 \(\alpha_t = \alpha_0(1-t)\)。\(\alpha_0=1\) 时把 MDM loss 的系数 \(\alpha'_t/(1-\alpha_t)\) 替换为 \(-1\),经验上降低训练方差、加快收敛。
实验关键数据¶
主实验¶
LM1B(\(L=128\), 1M steps)和 OWT(\(L=1024\), 250K steps)的测试 perplexity,AR/MDM 插值非常平滑:
| 方法 | LM1B PPL (NELBO) | LM1B PPL (Exact) | OWT PPL (NELBO) | OWT PPL (Exact) |
|---|---|---|---|---|
| AR Transformer | – | 21.86 | – | 17.78 |
| MDLM | 31.78 | 26.82 | 25.19 | – |
| BD3-LM (\(L'=4\)) | 28.23 | – | 20.96 | – |
| Eso-LM (\(\alpha_0=1\)) | 36.12 | 31.65 | 30.06 | 29.31 |
| Eso-LM (\(\alpha_0=0.5\)) | 32.53 | 28.07 | 27.94 | 26.61 |
| Eso-LM (\(\alpha_0=0.125\)) | 26.29 | 23.02 | 21.92 | 20.53 |
| Eso-LM (\(\alpha_0=0)\) | – | 21.86 | – | 17.78 |
长上下文采样延迟(OWT, \(T \gg L\),与 AR 相同 NFE 级别):
| 上下文 \(L\) | vs MDLM 加速 | vs BD3-LM (\(L'{=}16\)) 加速 | vs BD3-LM (\(L'{=}4\)) 加速 |
|---|---|---|---|
| 2048 | ~14× | 显著 | 显著 |
| 8192 | ~65× | ~3.2× | ~3.8× |
| 10240 (微调后) | 与 BD3-LM 同质量下 ~5× | – | – |
消融实验¶
| 配置 | 关键现象 | 说明 |
|---|---|---|
| Eso-LM (\(\alpha_0=1\), full) | LM1B NELBO 36.12 | 比 MDLM 差约 4 点 |
| Eso-LM (A):仅把 mask 上的 attention 改成因果,clean 仍双向 | \(\alpha_0=1\) 时与 MDLM 持平 | 说明 perplexity gap 主要来自"clean 之间也变因果"——这是为换 KV cache 付出的代价 |
| \(\kappa\) 扫描 (Table 4) | \(\kappa=0.5\) 最优 | AR/MDM loss 各占一半训练样本最佳 |
| MC 估计 NELBO (Table 2) | \(L_\text{AO}\) 单样本 σ=0.03 vs \(L_\text{MDM}\) 10 样本 σ=0.56 | 单次前向更准更省 |
| Block sampler vs 原 ancestral | 低 NFE 下 MAUVE 显著提升 | 只并行解远距离 mask,避免邻近冲突 |
关键发现¶
- Speed–quality Pareto 前沿(Fig. 4,MAUVE vs 采样耗时)上 Eso-LMs 全线压制 MDLM 与 BD3-LM;BD3-LM 在低 NFE 区间样本质量崩塌,Eso-LM 不崩。
- \(\alpha_0=1\) 训练已足够:作者实测一个 \(\alpha_0^\text{train}=1\) 的模型靠在采样阶段调 \(\alpha_0^\text{eval}\) 就能覆盖整个 Pareto 前沿,省得为每个工作点单训一个模型(Remark 2)。
- \(\alpha_0\) 越小越接近 AR,"exact PPL 与 NELBO PPL 的 gap" 也越小——侧面验证了 IW bound 与 NELBO 在不同插值点上的紧致性差异。
亮点与洞察¶
- "Any-Order AR ≡ MDM" 这条等价关系之前已知,作者第一次把它架构层落地:靠"洗牌 + 因果"两步就把 MDM 改造成可 KV cache 的 AR 变体,没引入任何额外参数。这是非常工程化但效果极强的 insight。
- 拼接 \(z_0 \oplus x\) + 稀疏 mask 这种"训练时拼双倍序列、推理时不拼"的设计很值得借鉴——它把"AR 需要左 context"和"MDM 给出随机顺序 KV"的矛盾通过训练阶段单独造一个上下文承担掉,推理时直接复用扩散阶段的 cache。
- Exact likelihood + 单次前向 NELBO 不只是理论好看:它直接把 MDM 接入了 GRPO 这套主流 RL 工具链,已被后续 8B 规模工作 (Wang et al. 2025b) 实证更优。这是 Eso-LMs 比表面 perplexity 数字更深远的影响。
- Remark "perplexity 在有限 NFE 下不反映质量" 是对 diffusion LM 评测范式的一次反思——\(\alpha_0=1\) Eso-LMs PPL 比 MDLM 差,但任何固定时间预算下样本质量都更好。提醒做扩散 LM 的人不要只刷 PPL。
局限与展望¶
- 作者承认:\(\alpha_0<1\) 时训练比 MDLM 慢约 1.37×(序列长度 doubled),但仍快于 BD3-LM;\(\alpha_0=1\) 下 NELBO 比 MDLM 差约 4 点(消融定位到"clean 之间也变因果"是主因);KV 复用有一步延迟,相同 NFE 下比 AR 略慢。
- 我看到的额外局限:(i) 实验全部停在 LM1B/OWT 这种 pretraining 学术规模 (~9K H200 GPU hours),没有 instruction tuning / 下游任务,scaling 主要靠引用 Sahoo et al. 2026 的 1.7B 结果背书;(ii) "\(\alpha_0^\text{train}=1\) 足够"是在 OWT 单一分布上验证的,不同领域是否仍成立未知;(iii) sequential phase 的 \(2L\) 拼接对内存友好度有限,long-context fine-tune 时 fp16/bf16 下的稳定性值得继续考察。
- 改进思路:把 Eso-LMs (A) 的"clean 间双向、mask 处因果"再发展一下,或许能找回 PPL 又不失 cache;另一条路是把 IW bound 直接接入 RLHF/RLAIF pipeline,做 MDM 版的 DPO/GRPO。
相关工作与启发¶
- vs MDLM (Sahoo et al., 2024a): 同为 MDM,但 MDLM 用双向注意力的 DiT 去噪、无法 cache;Eso-LMs 改为 causal-on-shuffled-sequence,长序列推理快一到两个数量级,代价是 \(\alpha_0=1\) 下 NELBO 略差。
- vs BD3-LMs (Arriola et al., 2025): 都做 AR–MDM 插值,但 BD3-LM 通过 block 大小插值、cache 仅在 block 之间,且小 block 在低 NFE 下质量崩塌;Eso-LMs 通过 \(\alpha_0\) 在 token 级别插值、cache 贯穿全程,Pareto 前沿全面更优。
- vs Pannatier et al. (2024) / Xue et al. (2025): 这些是 Eso-LMs 在 \(\alpha_0=1\) 的特例;Xue 额外引入 AdaLN 注入位置信息,Eso-LMs 完全靠 attention mask 实现、不加参数。
- vs 并发 KV cache 工作 (Hu 2025, Wu 2025, Ma 2025): 它们都是近似 cache(block 内仍要 full forward 或频繁刷新),长序列下退化严重;Eso-LMs 是精确 cache。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 把 Any-Order AR 视角真正在架构层落地,给出首个精确 likelihood 公式和扩散阶段精确 KV cache,两条都是 MDM 社区悬了很久的问题。
- 实验充分度: ⭐⭐⭐⭐ LM1B/OWT + 长上下文 + 消融 + Pareto 前沿都覆盖到了,唯独缺真实下游任务和大模型指令微调验证。
- 写作质量: ⭐⭐⭐⭐⭐ 公式、图示(Fig. 1 feature 对比、Fig. 2 unified KV cache 示意、Fig. 3 训练/注意力示意)非常清晰,把一套不直观的设计讲明白了。
- 价值: ⭐⭐⭐⭐⭐ 对 diffusion LM 是关键工程级解锁——长上下文 14–65× 提速 + 单次前向 NELBO 直接让 GRPO 在 MDM 上变得可行,已经被后续 8B 规模工作复用。