跳转至

A2D: Any-Order, Any-Step Safety Alignment for Diffusion Language Models

会议: ICLR 2026
arXiv: 2509.23286
代码: 有
领域: LLM对齐
关键词: diffusion language model, safety alignment, token-level defense, jailbreak, masked diffusion

一句话总结

提出 A2D,一种针对扩散语言模型(dLLM)的 token 级安全对齐方法,通过训练模型在遇到有害内容的 mask 位置输出 [EOS] token 来实现任意解码顺序、任意解码步的安全防御,将 DIJA 模板攻击成功率从 80%+ 降到近零(1.3%/0.0%),并支持早期拒绝实现 19.3x 加速。

研究背景与动机

领域现状:扩散语言模型(如 LLaDA、Dream)通过迭代去mask而非从左到右生成文本,支持任意顺序解码。已有安全对齐方法从 AR 模型继承,依赖响应级拒绝和固定解码顺序假设。

现有痛点:dLLM 的任意顺序解码大幅扩大了攻击面——有害内容可以在任意位置出现。DIJA 攻击通过在 [MASK] token 之间穿插对抗文本绕过早期拒绝,成功率超 80%。per-token KL 分析显示 dLLM 的安全对齐是"浅层"的——仅在前几步有效,后续步骤安全信号快速衰减。

核心矛盾:AR 模型的响应级拒绝假设固定的从左到右解码,但 dLLM 的解码可以是任意顺序、任意步骤的——传统对齐在 dLLM 上根本不适用。

本文目标 如何让 dLLM 在任意解码顺序和任意解码步骤都能可靠拒绝有害内容?

切入角度:将安全对齐从响应级降到 token 级——让模型在任何 mask 位置遇到有害内容时都输出 [EOS] 作为通用抑制信号。

核心 idea:token 级 [EOS] 对齐 + 随机 mask 训练,让 dLLM 在解码的任何位置任何步骤都能拒绝有害续写。

方法详解

整体框架

A2D 修改了标准 masked diffusion 训练目标:(1) 对有害文本,所有 mask 位置的监督目标从原始 token 改为 [EOS];(2) 对安全文本(包括有害 prompt + 安全回复),保持正常重建目标。通过均匀采样 mask ratio 让模型暴露于早期和晚期解码阶段。

关键设计

  1. Token 级 [EOS] 对齐:

    • 功能:训练 dLLM 在有害续写的任何 mask 位置都预测 [EOS]
    • 核心思路:在有害样本中,采样随机 mask,将所有 mask 位置的目标设为 [EOS]。模型学会在任何部分上下文中都能识别有害内容并输出终止信号
    • 设计动机:[EOS] 是模型已经熟悉的 token(用于填充和结束),不需要引入新词。token 级对齐天然兼容 dLLM 的任意顺序解码
  2. 均匀 mask ratio 采样:

    • 功能:训练时均匀采样 \(\lambda \sim U(0,1)\) 的 mask 比例
    • 核心思路:\(\lambda = (1-\epsilon)t + \epsilon\),使模型暴露于从几乎完全 mask(早期解码)到几乎无 mask(晚期解码)的所有阶段
    • 设计动机:解决"浅层对齐"问题——per-token KL 分析显示现有 dLLM 仅在前几步有安全信号,均匀采样确保所有解码阶段都对齐
  3. 早期拒绝机制:

    • 功能:在第一步解码时检测最左 mask 位置的 [EOS] 概率,超阈值则立即终止
    • 核心思路:A2D 训练后模型对有害输入在 mask 位置赋予高 [EOS] 概率,这个概率可以作为内部安全信号。阈值化后实现无输出的快速拒绝
    • 效果:最高 19.3x 更快的安全终止

损失函数 / 训练策略

标准 masked diffusion 交叉熵损失,唯一修改是有害样本的目标变为 [EOS]。在 BeaverTails 数据集上用 30K 样本训练。应用于已对齐的 instruction-tuned dLLM 之上。

实验关键数据

主实验(攻击成功率 ↓%)

模型 方法 Zeroshot PAIR ReNeLLM Prefilling DIJA Avg
LLaDA Original 14.6 77.5 56.5 69.6 82.9 60.2
VRPO 2.5 32.3 19.2 9.0 45.0 21.6
A2D 2.1 ~低 ~低 ~低 1.3 ~最低
Dream A2D - - - - 0.0 -

能力保持

指标 Original A2D
General (MMLU等) 66.6 66.2
Math (GSM8K等) 41.4 40.6
Coding (HumanEval等) 32.6 35.0

关键发现

  • A2D 将 DIJA 攻击成功率从 82.9% 降到 1.3%(LLaDA)和 0.0%(Dream)——彻底消除了模板攻击
  • 能力保持甚至略有提升(coding 从 32.6→35.0),说明 token 级对齐不损害一般能力
  • XSTest 上 0% 误拒率——完全不过度拒绝良性 prompt
  • 早期拒绝机制实现 19.3x 更快的安全终止
  • 在三种解码策略(左到右/置信度/随机)下都有效——真正的 any-order 防御

亮点与洞察

  • 揭示了 dLLM 的核心安全漏洞:KL 散度分析首次系统证明 dLLM 的安全对齐是浅层的,且比 AR 模型更严重
  • token 级 [EOS] 对齐的简洁性:不引入新架构或外部分类器,仅修改训练目标——最小化的改动实现最大化的防御
  • 内置安全监控能力:[EOS] 概率自然作为实时安全信号,支持解码过程中的持续监控

局限与展望

  • 仅在 BeaverTails 上训练,有害类型的覆盖可能不够全面
  • 应用于已对齐模型之上(而非从头对齐),与原始对齐的交互效果未完全理解
  • 自适应攻击(知道 A2D 的机制后的攻击)的鲁棒性未深入分析
  • 早期拒绝的阈值需要每个模型单独调整

相关工作与启发

  • vs AR 模型的安全对齐: AR 的 RLHF/DPO 假设固定解码顺序,不适用于 dLLM;A2D 原生支持任意顺序
  • vs DIJA 攻击: DIJA 利用 dLLM 的 mask 机制构造模板攻击,A2D 恰好利用同样的 mask 机制做防御——以其人之道还治其人之身
  • vs Circuit Breaker / AlphaSteer: 这些方法针对 AR 模型的激活空间,A2D 直接在 dLLM 的训练目标上做文章

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首次系统研究 dLLM 安全对齐,token 级 [EOS] 方案非常新颖
  • 实验充分度: ⭐⭐⭐⭐⭐ 3 个 dLLM、5 种攻击、多维能力评测、KL 分析
  • 写作质量: ⭐⭐⭐⭐⭐ 漏洞分析→方法设计→实验验证的逻辑链非常清晰
  • 价值: ⭐⭐⭐⭐⭐ 为 dLLM 安全部署铺平了道路