Beyond Masked and Unmasked: Discrete Diffusion Models via Partial Masking¶
会议: NeurIPS 2025
arXiv: 2505.18495
代码: 无
领域: 图像生成 / 离散扩散
关键词: 离散扩散模型, 掩码扩散, 部分掩码, 子token, 文本生成
一句话总结¶
Prime(Partial masking scheme)通过将每个token用base-b子token序列表示并在子token级别独立掩码,为掩码扩散模型引入中间状态,实现细粒度去噪过程,在OpenWebText上以15.36困惑度首次让MDM在不使用自回归公式的情况下超越ARM(17.54)。
研究背景与动机¶
领域现状 掩码扩散模型(MDM)是离散数据生成的有力模型,通过逐步揭示掩码token来生成样本。每个token只有两种状态:掩码或未掩码。
现有痛点 二值表示导致严重的计算浪费——在逆扩散过程中,大量步骤序列不发生任何变化(idle steps),模型在重复处理完全相同的输入。实验表明37%的步骤是无效的。
核心矛盾 MDM的二值状态限制了模型利用率:要么完全掩码(无信息),要么完全揭示(最终确定),缺少中间过渡状态来实现渐进的信息释放。
本文目标 重新定义扩散过程,将idle步转化为有信息量的更新,提升模型在生成过程中的利用率。
切入角度 用base-b编码将每个token拆分为子token序列,在子token级别独立掩码,自然产生中间状态。
核心 idea 通过子token的部分掩码从"二值掩码/未掩码"扩展为"多级中间状态",使四选一预测可分解为多步二选一决策。
方法详解¶
整体框架¶
MDM-Prime包含三步:(1) 用可逆函数 \(f\) 将每个token \(x_0^i \in \mathcal{X}\) 映射为长度 \(\ell\) 的子token序列 \(\mathbf{y}_0^i \in \mathcal{Y}^\ell\)(base-\(b\)编码,\(b = \lceil \sqrt[\ell]{C} \rceil\));(2) 在子token级别独立执行掩码扩散前向过程;(3) 逆扩散过程中逐步揭示子token,实现从完全掩码到中间状态再到完全揭示的细粒度转换。
关键设计¶
-
部分掩码方案(Prime):
- 功能:为离散扩散引入中间状态
- 核心思路:将token \(x_0^i\) 编码为子token序列 \(\mathbf{y}_0^i = f(x_0^i)\),子token独立掩码产生中间状态。例如4类token用2-bit编码,中间状态为"m0"或"1m",提供部分信息。中间状态数为 \((b+1)^\ell - (C+1)\),始终为正
- 设计动机:中间状态使模型能基于部分已知的token信息做更精确的预测,减少idle步。理论证明 \(\ell\) 增大时idle步单调递减
-
联合概率参数化:
- 功能:建模子token间的依赖并防止生成无效样本
- 核心思路:直接参数化联合分布 \(p_\theta(\mathbf{y}_0^i|\mathbf{y}_t)\),只对有效的base-\(b\)编码(\(\mathbf{y}_0^i \in f(\mathcal{X})\))分配概率权重,将 \(|\mathcal{V}(\mathbf{y}_t^i)|\) 外的logit显式置零。同时满足carry-over约束:已揭示的子token保持不变
- 设计动机:独立参数化 \(\prod_j p_\theta(y_0^{i,j}|\mathbf{y}_t)\) 不仅引入错误独立性假设(导致采样分布退化),还可能生成无效的子token组合(如GPT-2 50257词表映射时)
-
子token嵌入编码器:
- 功能:高效处理子token输入
- 核心思路:为每个子token创建独立的 \(D/\ell\) 维嵌入查表,拼接 \(\ell\) 个嵌入得到 \(D\) 维token嵌入。查表大小仅需 \((b+1) \times D/\ell\),远小于完整的 \(|\tilde{\mathcal{Y}}^\ell|\) 维查表
- 设计动机:子token空间 \(\tilde{\mathcal{Y}}^\ell\) 可能远大于原token空间,直接建查表不可行;拼接策略保持与标准MDM架构兼容
损失函数 / 训练策略¶
变分上界损失:\(\mathcal{L}_{vb}(\mathbf{y}_0;\theta) = \int_0^1 \frac{\alpha'_t}{1-\alpha_t} \mathbb{E}_{q(\mathbf{y}_t|\mathbf{y}_0)}[\sum_i \log p_\theta(\mathbf{y}_0^i|\mathbf{y}_t)] dt\),即加权交叉熵损失,理论保证为负对数似然的上界。
实验关键数据¶
主实验——文本生成(OpenWebText困惑度PPL)¶
| 方法 | PPL ↓ | Idle步比例 |
|---|---|---|
| ARM(自回归)* | 17.54 | - |
| MDLM | ≤22.98 | 36.77% |
| EDLM-coAR* | ≤17.58 | - |
| MDLM-Prime (ℓ=2) | ≤17.90 | 13.52% |
| MDLM-Prime (ℓ=4) | ≤15.62 | 1.83% |
| MDLM-Prime (ℓ=6) | ≤15.36 | 0.25% |
主实验——图像生成¶
| 方法 | CIFAR-10 FID ↓ | ImageNet-32 FID ↓ |
|---|---|---|
| 连续扩散SOTA | ~2.5-3.5 | ~6-8 |
| MDM-Prime | 3.26 | 6.98 |
消融实验¶
| 配置 | OWT PPL | 说明 |
|---|---|---|
| 独立参数化 | 退化 | 子token独立假设导致分布扭曲 |
| 联合参数化 | 15.36 | 捕捉子token依赖 |
| 无carry-over | 更高 | carry-over对零样本泛化很重要 |
| ℓ=2→8 | 17.90→15.48 | ℓ≥4时性能收敛 |
关键发现¶
- 首次让MDM在不依赖自回归公式的情况下超越ARM(15.36 vs 17.54)
- idle步比例与PPL高度相关——从36.77%(MDLM)降至0.25%(Prime ℓ=6)时PPL从22.98降到15.36
- 在图像生成上与连续扩散方法相当(CIFAR-10 FID 3.26)
- ℓ≥4时性能收敛,推荐选择ℓ=4或6
亮点与洞察¶
- "二值→多级中间状态"的核心idea直觉简单但效果惊人——仅修改嵌入层即可将MDLM提升7个PPL点
- idle步分析为理解MDM性能瓶颈提供了新视角
- 联合参数化+carry-over的设计既保证了理论正确性又实现了高效实现
局限与展望¶
- 子token编码增加了序列长度(\(L \times \ell\)),增加Transformer的计算量
- 当前仅在130M参数模型上验证,更大规模LLM上的表现待确认
- base-b编码是手工设计的,可能存在更优的token分解策略
相关工作与启发¶
- vs MDLM: Prime是MDLM的直接增强,仅修改嵌入层,架构完全兼容
- vs SEDD: SEDD用吸收状态+得分匹配,Prime用部分掩码+变分上界,两种互补视角
- vs BD3-LM: BD3混合自回归公式使MDM更强,但Prime证明无需AR也可超越ARM
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 部分掩码idea简洁有力,首次让MDM超越ARM是里程碑式结果
- 实验充分度: ⭐⭐⭐⭐ 文本+图像跨模态验证,七个零样本基准,消融充分
- 写作质量: ⭐⭐⭐⭐ 理论推导严谨,图示清晰
- 价值: ⭐⭐⭐⭐⭐ 对离散扩散模型领域有重要推动