SAEmnesia: Erasing Concepts in Diffusion Models with Supervised Sparse Autoencoders¶
会议: ICML 2026
arXiv: 2509.21379
代码: https://github.com/EIDOSLAB/SAEmnesia
领域: AI 安全 / 概念遗忘 / 扩散模型可解释性
关键词: 概念擦除, 稀疏自编码器, 监督训练, 特征中心化, 扩散模型
一句话总结¶
通过在稀疏自编码器(SAE)训练阶段加入监督的"概念—潜变量"指派损失,强制每个目标概念集中到单个神经元(feature centralization),从而把扩散模型的概念擦除从"搜多神经元 + 调强度"的二维超参搜索压成"只调一个 multiplier",在 UnlearnCanvas 上比 SOTA 的 SAeUron 平均提升 9.22 个点,超参搜索代价降低 96.67%,并对对抗攻击更鲁棒。
研究背景与动机¶
领域现状:文生图扩散模型(SD 系列)的安全部署需要"概念遗忘"——选择性地抹掉裸露、版权角色、特定物体等不希望生成的概念,同时保留模型其他生成能力。目前主流路线分两类:(i) 微调整模型权重(ESD、UCE、SalUn 等);(ii) 不改权重,借助稀疏自编码器(SAE)在 cross-attention 激活上做机制式干预(Concept Steerers、SAeUron),后者优点是可逆(拔掉 SAE 模型完全恢复)且具备可解释性。
现有痛点:SAE 路线的最强代表 SAeUron 用无监督方式训练 SAE,导致 feature splitting——同一个概念(如"Bears")分散在多个潜变量上。要擦掉它必须:(1) 在数千个潜变量里搜索"哪几个组合是 Bears",文献里实测要枚举 30 种 latent 子集 × 7 种强度 = 210 次评估;(2) 多个潜变量之间还会和相邻概念交叠,干预时容易误伤"Cats"等相关概念。
核心矛盾:无监督 SAE 学出的 monosemanticity(一个神经元只对一个概念敏感)和 one-to-one(一个概念只对应一个神经元)这两个性质并不天然成立,后者长期缺失。没有 one-to-one,机制式干预就必须做组合搜索,可解释性也只是事后归因。
本文目标:在保留 SAE 重建质量的前提下,训练时就把每个待擦除概念绑定到唯一一个 latent 上,让推理阶段的擦除退化为"对一个标量乘负数"。
切入角度:扩散训练时其实已经有现成的监督信号——生成时用的 anchor prompt(如 An image of Bears)天然带概念标签。把这个信号塞回 SAE 训练,比事后用 score function 去对齐要直接得多。
核心 idea:在标准 TopK SAE 损失上加两项监督损失——Concept Assignment 损失把每个概念的激活推到一个指定 latent,Decorrelation 损失把不同宏类(objects vs. styles)的 latent 激活之间的相关性压低,从而把"特征分裂"扼杀在训练阶段。
方法详解¶
整体框架¶
SAEmnesia 的 pipeline 接在冻结的 Stable Diffusion v1.5 的 cross-attention block up.1.1 后面:
- 激活收集:对每个目标概念 \(c\),用 80 条 anchor prompt(如
An image of Bears)让 SD 跑完 50 步 denoising,从up.1.1抽出 cross-attention 特征图 \(\mathbf{F}_t \in \mathbb{R}^{h\times w\times d}\);把每个空间位置的 \(d\) 维向量当作一个 SAE 训练样本,并直接打上该 prompt 对应的 object/style 标签。 - 两阶段训练:(i) 先按常规无监督 TopK SAE 训练(重建损失 + 防 dead latent 的辅助损失)建立稀疏分解;(ii) 再用 SAEmnesia 复合损失 finetune,强化"概念—潜变量"绑定。
- 概念—潜变量指派 \(\Phi\):训练前用 score function(公式见下方关键设计)扫一遍,给每个概念 \(c\) 指派得分最高的 latent 索引 \(i_c = \Phi(c)\),作为后续监督的目标。
- 推理阶段擦除:要删掉概念 \(c\),只对 latent \(i_c\) 的激活乘一个负标量 \(\gamma_c < 0\);后续 SAE 解码回原激活空间,再丢回扩散主干继续 denoising。
整个过程对扩散主干无任何改动,SAE 可热插拔,"忘记—恢复"一行代码切换。
关键设计¶
-
监督的概念—latent 指派与 Concept Assignment 损失:
- 功能:把"哪个 latent 编码哪个概念"这件事从事后归因变成训练时硬约束。
- 核心思路:先用 SAeUron 的 score function \(\text{score}(i,t,c,D) = \frac{\mu(i,t,D_c)}{\sum_j \mu(j,t,D_c)+\delta} - \frac{\mu(i,t,D_{\neg c})}{\sum_j \mu(j,t,D_{\neg c})+\delta}\) 度量 latent \(i\) 在概念 \(c\) 上的"专属性"(在 \(D_c\) 中相对激活高、在 \(D_{\neg c}\) 中相对激活低)。选 score 最大的 latent 作为该概念的指定槽位 \(i_c\)。训练时对每个含概念 \(c\) 的样本施加 \(\mathcal{L}_{\text{CA}} = -\frac{1}{B}\sum_b \frac{1}{|\mathcal{T}^{(b)}|}\sum_{c \in \mathcal{T}^{(b)}} \log \sigma(v^{(b)}_{i_c})\),即对指派 latent 的预激活值 \(v_{i_c}\) 做 BCE,逼它在该概念出现时强烈激活。
- 设计动机:CA 损失只在"概念出现的样本 × 该概念对应的 latent"上施压,是一种局部稀疏监督,不会破坏其他 latent 的无监督表征学习,所以可以叠加在标准 SAE 损失之上而不损害重建质量。
-
跨宏类去相关约束(Decorrelation Constraint):
- 功能:防止不同宏类(object vs. style)的指派 latent 在同一 batch 内"共振"——例如 Bears 的 latent 总是和 Cubism 的 latent 一起亮,干预时会互相牵连。
- 核心思路:把概念集合划分成不相交的宏类 \(\mathcal{C} = \bigcup_m \mathcal{C}_m\)(论文里就两组:objects 和 styles),对每个 latent 在 mini-batch 内的激活向量 \(\mathbf{a}_c = [v_{i_c}^{(1)}, \dots, v_{i_c}^{(B)}]^\top\) 算 Pearson 相关系数 \(\rho\),约束 \(\mathcal{L}_{\text{DC}} = \frac{\sum_{m<m'} \sum_{i\in\mathcal{I}_m, j\in\mathcal{I}_{m'}} \rho(\mathbf{a}_i, \mathbf{a}_j)}{\sum_{m<m'}|\mathcal{I}_m||\mathcal{I}_{m'}|}\),只对跨组的 latent 对惩罚相关性。最终 \(\mathcal{L}_{\text{SAEmnesia}} = \mathcal{L}_{\text{unsupSAE}} + \beta \mathcal{L}_{\text{supSAE}} + \lambda \mathcal{L}_{L_1}\),其中 \(\mathcal{L}_{\text{supSAE}} = \mathcal{L}_{\text{CA}} + \eta \mathcal{L}_{\text{DC}}\)。
- 设计动机:完全去相关会破坏自然语义相关性("Cats" 和 "Dogs" 本就应该相关),所以只在宏类粒度去相关,保留组内自然相似度,同时把"擦物体不要影响风格"这种实际诉求直接编进 loss。这也解释了为什么论文坦承 within-group interference 仍是未解问题。
-
单 latent 阈值化擦除(Inference-time steering):
- 功能:把"擦除"操作压缩到一次标量乘法,且只在该 latent 真的被激活时才生效,避免把无关样本里本就微弱的激活随机乱抹一气。
- 核心思路:对指派 latent \(i_c\) 的激活按公式 \(z_{i_c} = \gamma_c \mu(i_c, t, D_c) z_{i_c}\) 当且仅当 \(z_{i_c} > \mu(i_c, t, D)\)(即超过全体样本的平均激活)时执行,否则保持不变。这里 \(\gamma_c < 0\) 是唯一的可调超参,\(\mu(i_c,t,D_c)\) 是验证集上该概念样本的平均激活,起归一化作用。还可选择只在后 25 步 denoising 启用 SAEmnesia,前期保留预训练先验以减少 artifact。
- 设计动机:阈值化是关键护栏——score 高未必意味着"现在这一步这一空间位置确实在表达该概念",加阈值后只在"确实在画 Bears"时干预,把误伤降到最低。\(\mu(i_c,t,D_c)\) 的归一化让 \(\gamma_c\) 在不同概念间可用统一量纲调,论文里用 \(\gamma_c=-1\) 在所有 20 类物体上都能稳定工作(见 Figure 3)。
损失函数 / 训练策略¶
完整目标为 \(\mathcal{L}_{\text{SAEmnesia}} = \mathcal{L}_{\text{unsupSAE}} + \beta(\mathcal{L}_{\text{CA}} + \eta \mathcal{L}_{\text{DC}}) + \lambda \mathcal{L}_{L_1}\)。训练分两段:先纯无监督预训练 SAE 稳定重建,再开启监督项 finetune。activation 来自 SD v1.5 的 up.1.1 cross-attention,每个 timestep 都参与训练;UnlearnCanvas 上指派 70 个概念(20 objects + 50 styles)。
实验关键数据¶
主实验¶
UnlearnCanvas 物体擦除(百分比,越高越好;UA = Unlearning Accuracy,IRA/CRA = In-/Cross-domain Retain Accuracy):
| 方法族 | 方法 | UA ↑ | IRA ↑ | CRA ↑ | Avg ↑ |
|---|---|---|---|---|---|
| Fine-tune | ESD | 92.15 | 55.78 | 44.23 | 64.05 |
| Fine-tune | SalUn | 86.91 | 96.35 | 99.59 | 94.28 |
| Adapter | SPM | 71.25 | 90.79 | 81.65 | 81.23 |
| SAE 无监督 | SAeUron | 87.16 | 85.57 | 74.14 | 82.29 |
| SAE 监督 | SAEmnesia | 94.65 | 91.39 | 88.48 | 91.51 |
相对 SAE 家族 SOTA SAeUron:UA +7.49、IRA +5.82、CRA +14.34、Avg +9.22。SalUn 在单纯物体上 Avg 更高,但论文 Appendix 表 15 显示一旦把 style 擦除算进来 SAEmnesia 综合 94.85% 反超。
消融 / 关键分析¶
| 配置 / 场景 | 关键指标 | 说明 |
|---|---|---|
| 超参搜索代价 | 7 vs. 210 次评估 | SAeUron 需 \(m=7\) multiplier × \(l=30\) latent 组合;SAEmnesia 只调 \(m\),降 96.67% |
| 顺序擦除 9 个物体 | UA 92.4% vs. 64.0% (baseline) | +28.4 点;保留精度 RA 60.9% vs. 48.4% |
| 白盒攻击 UnlearnDiffAtk | 攻后 UA 57.50% vs. 34.20% | 攻前→攻后掉 40.1 点,SAeUron 掉 49.5 点 |
| 黑盒攻击 Ring-A-Bell | 攻后 UA 97.0% vs. 79.5% | 鲁棒性优势跨威胁模型成立 |
| NSFW 抑制(I2P, SD v1.4) | 9 例检出 vs. SAeUron 18 例 | 仅用 2 个 latent("naked man"/"naked woman") |
| K-NN on top-1 latent | 接近"用全部 latent" | 验证概念信息真的浓缩到单个 latent |
| Feature score 分布(Flowers) | 0.0404 vs. 0.0166 | 监督训练后 score 峰值 2.43× 提升 |
关键发现¶
- CA 损失是擦除性能的主要贡献者:score 峰值 2.43× 是所有下游收益的根源,没有 feature centralization 后面三条路(顺序、对抗、效率)都建立不起来。
- 去相关只在宏类粒度起作用:作者明确说明同一宏类内(如 Dogs vs. Cats)的干扰仍未解决,是限制项里的明确遗留。
- 对抗鲁棒性的"溢出收益":one-to-one 让对抗 prompt 必须精确打中那个 latent,攻击面变窄;这并不是论文显式优化的目标,但天然出现。
- uniform multiplier 也稳定:Figure 3 表明 SAEmnesia 在 \(\gamma_c\) 取全域同值时仍优于 SAeUron 在最优搜索下的表现——说明"为每个概念精调 \(\gamma\)"的需要被大幅削弱。
亮点与洞察¶
- 把"事后归因"前置成"训练时硬约束":SAE 社区长期把 monosemanticity 当成训练后的解释指标,本文证明它可以作为训练目标直接优化,且不破坏重建——这是把机制可解释性从"观察工具"升级为"控制工具"的关键一步。
- "概念—latent" 的 1:1 不只是干净,而是把搜索空间从乘法降到加法:\(m\times l\) → \(m\) 的差异在概念数量上线性放大,9 个物体顺序擦除时 SAEmnesia 比 baseline 高 28.4 个点的本质就是组合爆炸被砍掉了。
- 激活阈值化护栏:
if $z_{i_c} > \mu(i_c, t, D)$这个判断看似工程细节,实则把"score 静态高"和"当前 forward 中确实在表达"解耦,是擦除精准度的隐藏关键——值得借鉴到其他基于 latent steering 的可控生成工作。 - macro-group 去相关的取舍:作者只对 objects vs. styles 这种"用户语义粒度对齐"的宏类去相关,而不是全 pairwise 去相关——这种"按下游需求而非数据分布去相关"的思路可以直接迁移到多任务 SAE、模块化 LoRA 等场景。
局限与展望¶
- 作者承认的局限:(1) 只在 U-Net 上验证,FLUX 等 transformer-based 扩散需架构适配;(2) closed-vocabulary,新增概念需要重新算 score(不重新训练 SAE 时只能用与 SAeUron 一样的事后绑定,少一层一致性保证);(3) within-group interference(如 Dogs vs. Cats)仍未解决;(4) scalability 在 20K+ 概念时 latent 容量和稳定指派都会成为瓶颈。
- 我方补充:训练分两阶段(无监督 + 监督 finetune),意味着实际部署要维护两版 SAE checkpoint,存储/工程成本被低估;论文里也只在
up.1.1一个 block 做了 SAE,跨 block 的 concept 表征如何选/合并完全没讨论。 - 改进思路:(1) 把概念间相似度(CLIP 嵌入余弦)作为 pairwise DC 权重,缓解 within-group interference;(2) 在 transformer-based 扩散上把 SAE 接到 MM-DiT 的 modulation 输出做联合 supervision;(3) 对 closed-vocabulary 限制,可以把 SAE 设计成"基座 latent + 概念专用 latent"两段式,新增概念时只 finetune 后者。
相关工作与启发¶
- vs SAeUron:同样用 SAE 在 cross-attention 激活上做擦除,区别是 SAeUron 完全无监督,依赖事后 score function 找概念,导致 feature splitting;SAEmnesia 把 score 直接搬进训练 loss。SAEmnesia 在 UnlearnCanvas Avg 上 +9.22 点,超参搜索压 96.67%,且擦除可逆性、推理成本、对抗鲁棒性全面占优。
- vs Concept Steerers (Kim & Ghadiyaram):Concept Steerers 在 text embedding 上训 SAE,作用点在 cross-attention 之前;SAEmnesia 在 cross-attention 之后的视觉路径上干预,对绕过 text encoder 的对抗攻击(Ring-A-Bell 等)更鲁棒。
- vs SalUn / ESD / UCE 等微调路线:微调路线要改主干权重,不可逆且容易牵连下游能力;SAEmnesia 完全不改权重,只挂一个可插拔的 SAE,从安全部署角度更可控(出错可拔,可审计单 latent 的语义)。
- vs ScaPre(顺序擦除专用方法):ScaPre 用 spectral trace 正则化处理顺序遗忘的 concept interference;SAEmnesia 不专门设计顺序擦除,仅靠 one-to-one 性质天然支持 9 个物体顺序擦除 UA 92.4% 的可加性,给出"先解决表示再解决调度"的另一条路。
评分¶
- 新颖性: ⭐⭐⭐⭐ 把监督信号引入 SAE 训练实现 feature centralization 是 SAE-based 擦除领域的清晰跨越,但单点创新落在 loss 设计上,并未触及 SAE 架构本身
- 实验充分度: ⭐⭐⭐⭐ UnlearnCanvas + I2P + 白盒/黑盒对抗 + 顺序擦除四线证据互相印证,但只在 SD v1.5/1.4 上验证,新一代架构泛化性留白
- 写作质量: ⭐⭐⭐⭐ 动机—方法—证据链清晰,公式编号严谨,limitations 写得诚实
- 价值: ⭐⭐⭐⭐⭐ 把"机制式擦除"从研究 demo 推到了可工程化部署的门槛(一个 latent、一次乘法、可热插拔、可审计),对生成模型合规化有直接价值