跳转至

A2D: Any-Order, Any-Step Safety Alignment for Diffusion Language Models

会议: ICLR 2026
arXiv: 2509.23286
代码: 有
领域: LLM对齐
关键词: diffusion language model, safety alignment, token-level defense, jailbreak, masked diffusion

一句话总结

提出 A2D,一种针对扩散语言模型(dLLM)的 token 级安全对齐方法,通过训练模型在遇到有害内容的 mask 位置输出 [EOS] token 来实现任意解码顺序、任意解码步的安全防御,将 DIJA 模板攻击成功率从 80%+ 降到近零(1.3%/0.0%),并支持早期拒绝实现 19.3x 加速。

研究背景与动机

领域现状:扩散语言模型(如 LLaDA、Dream)通过迭代去mask而非从左到右生成文本,支持任意顺序解码。已有安全对齐方法从 AR 模型继承,依赖响应级拒绝和固定解码顺序假设。

现有痛点:dLLM 的任意顺序解码大幅扩大了攻击面——有害内容可以在任意位置出现。DIJA 攻击通过在 [MASK] token 之间穿插对抗文本绕过早期拒绝,成功率超 80%。per-token KL 分析显示 dLLM 的安全对齐是"浅层"的——仅在前几步有效,后续步骤安全信号快速衰减。

核心矛盾:AR 模型的响应级拒绝假设固定的从左到右解码,但 dLLM 的解码可以是任意顺序、任意步骤的——传统对齐在 dLLM 上根本不适用。

本文目标 如何让 dLLM 在任意解码顺序和任意解码步骤都能可靠拒绝有害内容?

切入角度:将安全对齐从响应级降到 token 级——让模型在任何 mask 位置遇到有害内容时都输出 [EOS] 作为通用抑制信号。

核心 idea:token 级 [EOS] 对齐 + 随机 mask 训练,让 dLLM 在解码的任何位置任何步骤都能拒绝有害续写。

方法详解

整体框架

A2D 要解决的问题是:让一个已经 instruction-tuned 的 dLLM,无论按什么顺序解码、解码到第几步,只要某个 mask 位置本该填入有害内容,就改填 [EOS] 这个终止信号。它不改架构、不加外部分类器,只在标准 masked diffusion 训练目标上动两处。第一处是把训练数据拆成 Harmful 集和 Retain 集:对 Harmful 集(有害 prompt + 有害回复),所有 mask 位置的监督目标从原始 token 换成 [EOS];对 Retain 集(安全回复、以及「有害 prompt + 安全回复」这类样本),保持正常的重建目标——让模型学会"只在有害内容上输出 [EOS]、其余照常生成"。第二处是训练时均匀采样 mask 比例,让模型同时见到「几乎全 mask」的早期解码和「几乎无 mask」的晚期解码,把安全信号铺满整条解码轨迹。训练完成后,有害输入会让 mask 位置的 [EOS] 概率飙高——A2D 进一步把这个概率读作内部安全信号,在解码第一步就检测、超阈值即提前终止。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400, 'subGraphTitleMargin': {'top': 8, 'bottom': 16}}}%%
flowchart TD
    A["已对齐的<br/>instruction-tuned dLLM"] --> S1
    subgraph S1["Token 级 [EOS] 对齐(双数据集监督)"]
        direction TB
        H["Harmful 集<br/>有害 prompt + 有害回复"] -->|"mask 位目标 → [EOS]"| J["合并训练"]
        R["Retain 集<br/>安全回复 / 有害 prompt+安全回复"] -->|"mask 位目标 → 重建原 token"| J
    end
    S1 --> S2["均匀 mask ratio 采样<br/>λ=(1-ε)t+ε, t~U(0,1)"]
    S2 --> M["A2D 对齐后 dLLM<br/>任意位/任意步见有害即输出 [EOS]"]
    M --> E["早期拒绝机制<br/>首步最左 mask 位 [EOS] 概率 > τ 即终止"]
    E --> O["安全终止(最高 19.3× 加速)/ 正常解码"]

关键设计

1. Token 级 [EOS] 对齐:把响应级拒绝降到每个 mask 位置

dLLM 的任意顺序解码让有害内容可以冒在任何位置,传统的「整段响应级拒绝」根本管不住。A2D 改成在 token 级别对齐:训练数据拆成 Harmful 集和 Retain 集,对 Harmful 集采样随机 mask 后把所有被 mask 位置的监督目标统一设为 [EOS],逼模型学会「不管当前看到的是哪部分上下文,只要识别出有害续写,就在这里输出终止信号」;对 Retain 集则照常重建原 token。选 [EOS] 是因为它本就是模型熟悉的 token(用于填充和结束),不必引入新词。这里的关键巧思是 Retain 集特意纳入「有害 prompt + 安全回复」样本——它教模型「面对有害提问也可以正常给出安全回答」,因此模型只在真正有害的续写上触发 [EOS],而不会把所有看起来敏感的良性请求都一律拒掉(XSTest 上 0% 误拒率即来自此)。token 级的监督形式天然兼容 dLLM 任意顺序解码——任何位置、任何上下文片段都能触发拒绝。

2. 均匀 mask ratio 采样:堵住「浅层对齐」只在前几步生效的漏洞

per-token KL 分析显示现有 dLLM 的安全对齐是浅层的——只在解码的前几步有效,后续步骤安全信号迅速衰减。A2D 在训练时对每个样本采样时间步 \(t \sim U(0,1)\),并令 mask 比例 \(\lambda = (1-\epsilon)t + \epsilon\),使其覆盖从「几乎完全 mask」(对应早期解码阶段)到「几乎无 mask」(对应晚期解码阶段)的所有比例。这样安全监督就不再集中在前几步,而是均匀分布到每一个解码阶段,真正做到 any-step 防御;实验里 DIJA 这类「先填好部分有害续写」的攻击成功率被压到近零,正是因为模型在解码后段也仍会拒绝。

3. 早期拒绝机制:把 [EOS] 概率当成内部安全信号提前刹车

经过前两步训练,模型对有害输入会在 mask 位置赋予很高的 [EOS] 概率——这个概率本身就是一个可读出的内部安全信号。A2D 在第一步解码时只检测最左 mask 位置的 [EOS] 概率,一旦超过阈值 \(\tau\) 便立即终止、不生成任何 token。之所以只看最左位而非后续位置,是为了避免把「okay」这类短良性回复误判成拒绝。由于在解码起点就刹车、跳过了后续所有去 mask 步骤,安全终止显著加速:\(\tau=0.9\) 时在 AdvBench 上提速约 6×、仅 1.6% 过度拒绝,\(\tau=0.8\) 时进一步提速到 19.3×(代价是误拒率略升),阈值给出了「速度 vs 良性兼容」的可调权衡。

损失函数 / 训练策略

仍是标准 masked diffusion 的全 mask 位置交叉熵损失,唯一改动是把 Harmful 集样本的监督目标替换成 [EOS]、Retain 集照常重建原 token。训练数据为 BeaverTails 的 30K 样本(划分为 Harmful 集与 Retain 集),整套方法施加在已对齐的 instruction-tuned dLLM 之上,训练 10 个 epoch、batch size 16、学习率 \(5\times10^{-5}\)(AdamW,weight decay 0.1)。

实验关键数据

主实验(攻击成功率 ↓%)

模型 方法 Zeroshot PAIR ReNeLLM Prefilling DIJA Avg
LLaDA Original 14.6 77.5 56.5 69.6 82.9 60.2
VRPO 2.5 32.3 19.2 9.0 45.0 21.6
A2D 2.1 ~低 ~低 ~低 1.3 ~最低
Dream A2D - - - - 0.0 -

能力保持

指标 Original A2D
General (MMLU等) 66.6 66.2
Math (GSM8K等) 41.4 40.6
Coding (HumanEval等) 32.6 35.0

关键发现

  • A2D 将 DIJA 攻击成功率从 82.9% 降到 1.3%(LLaDA)和 0.0%(Dream)——彻底消除了模板攻击
  • 能力保持甚至略有提升(coding 从 32.6→35.0),说明 token 级对齐不损害一般能力
  • XSTest 上 0% 误拒率——完全不过度拒绝良性 prompt
  • 早期拒绝机制实现 19.3x 更快的安全终止
  • 在三种解码策略(左到右/置信度/随机)下都有效——真正的 any-order 防御

亮点与洞察

  • 揭示了 dLLM 的核心安全漏洞:KL 散度分析首次系统证明 dLLM 的安全对齐是浅层的,且比 AR 模型更严重
  • token 级 [EOS] 对齐的简洁性:不引入新架构或外部分类器,仅修改训练目标——最小化的改动实现最大化的防御
  • 内置安全监控能力:[EOS] 概率自然作为实时安全信号,支持解码过程中的持续监控

局限与展望

  • 仅在 BeaverTails 上训练,有害类型的覆盖可能不够全面
  • 应用于已对齐模型之上(而非从头对齐),与原始对齐的交互效果未完全理解
  • 自适应攻击(知道 A2D 的机制后的攻击)的鲁棒性未深入分析
  • 早期拒绝的阈值需要每个模型单独调整

相关工作与启发

  • vs AR 模型的安全对齐: AR 的 RLHF/DPO 假设固定解码顺序,不适用于 dLLM;A2D 原生支持任意顺序
  • vs DIJA 攻击: DIJA 利用 dLLM 的 mask 机制构造模板攻击,A2D 恰好利用同样的 mask 机制做防御——以其人之道还治其人之身
  • vs Circuit Breaker / AlphaSteer: 这些方法针对 AR 模型的激活空间,A2D 直接在 dLLM 的训练目标上做文章

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首次系统研究 dLLM 安全对齐,token 级 [EOS] 方案非常新颖
  • 实验充分度: ⭐⭐⭐⭐⭐ 3 个 dLLM、5 种攻击、多维能力评测、KL 分析
  • 写作质量: ⭐⭐⭐⭐⭐ 漏洞分析→方法设计→实验验证的逻辑链非常清晰
  • 价值: ⭐⭐⭐⭐⭐ 为 dLLM 安全部署铺平了道路