Machine Unlearning under Retain–Forget Entanglement¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=4WMBSHHJEr
代码: 已开源(论文中提供链接)
领域: 机器遗忘 / AI 安全
关键词: machine unlearning, retain-forget entanglement, augmented Lagrangian, gradient projection, Wasserstein-2 distance
一句话总结¶
针对"遗忘集和保留集语义纠缠"导致的相关样本误伤问题,提出两阶段优化框架:第一阶段用增广拉格朗日法激进遗忘并锁住无关保留样本,第二阶段用 Wasserstein-2 距离正则化的梯度投影修复语义相邻保留样本的精度,同时保住遗忘效果。
研究背景与动机¶
领域现状:机器遗忘(machine unlearning)要求从已训练模型中精准抹除指定数据 \(D_f\) 的影响,同时保留其余数据 \(D_r\) 的性能,应用于隐私合规(GDPR 被遗忘权)、去偏、修复中毒数据等场景。现有工作覆盖随机样本遗忘、类级遗忘、概念级遗忘,并发展出基于梯度、稀疏剪枝、Fisher/影响函数等高效后处理方法。
现有痛点:现有方法几乎都把保留集性能用"整体平均精度"来衡量,掩盖了一个关键事实——遗忘从来不是孤立的。删除某组数据往往会连带损伤与之强相关的另一组数据。例如遗忘某少数群体的有毒言论,会无意中改变模型对该群体非有毒言论的行为;遗忘某个子类,会扰乱同超类下其他相邻子类的预测。这些敏感、相关的子集恰恰是性能最脆弱、后果最严重的部分,却被平均指标淹没。
核心矛盾:遗忘集 \(D_f\) 与"相邻保留集" \(D_r^{adj}\) 共享高度重叠的特征分布——压低 \(D_f\) 的损失会株连 \(D_r^{adj}\),而恢复 \(D_r^{adj}\) 的精度又会反过来削弱遗忘效果(梯度方向冲突)。直接对二者联合优化只会陷入此消彼长的拉锯。
本文目标:在"遗忘集-保留集纠缠"(retain–forget entanglement)这一更贴近真实需求的设定下,把保留集显式拆成相邻子集 \(D_r^{adj}\)(与 \(D_f\) 强相关、易受损)和远端子集 \(D_r^{rem}\)(弱相关),既要彻底遗忘 \(D_f\),又要专门守住敏感的 \(D_r^{adj}\)。
核心 idea:解耦 + 两阶段(divide-and-conquer)——先处理"容易的部分"(遗忘 + 锁远端),再单独修复"难的部分"(相邻子集),并用 Wasserstein-2 分布约束取代传统的均值损失约束,从根上堵住"均值不变但精度反弹"的漏洞。
方法详解¶
整体框架¶
方法把纠缠场景下的遗忘拆成两个串行阶段。第一阶段(增广拉格朗日)只关心"激进遗忘 \(D_f\)"和"锁住远端保留集 \(D_r^{rem}\)",故意回避 \(D_r^{adj}\) 以躲开梯度冲突;这一阶段结束后 \(D_f\) 被遗忘、\(D_r^{rem}\) 完好,但 \(D_r^{adj}\) 因纠缠而塌陷。第二阶段(W-PGD)用梯度投影把 \(D_r^{adj}\) 的精度抬回来,同时用 Wasserstein-2 距离锁住 \(D_f\) 的损失分布(而非仅锁均值),保证修复相邻集时遗忘效果不反弹。
flowchart LR
A[原模型 θ₀] --> B[阶段一: 增广拉格朗日<br/>−Lf 最大化遗忘<br/>约束 Lrem 不变]
B --> C[中间模型 θ̄<br/>Df 已遗忘 / Drem 完好<br/>Dadj 塌陷]
C --> D[阶段二: W-PGD<br/>梯度投影恢复 Dadj<br/>W2 距离锁住 Df 损失分布]
D --> E[遗忘模型 θ<br/>三者均衡]
关键设计¶
1. 增广拉格朗日激进遗忘:自适应平衡遗忘与远端保留。 第一阶段被形式化为约束优化 \(\min_\theta -L_f(\theta)\) s.t. \(L_r^{rem}(\theta)=L_r^{rem}(\theta_0)\)——即最大化遗忘集损失,同时强制远端保留集损失钉死在原始水平。直接用固定权重惩罚项需要反复手调系数,作者改用增广拉格朗日 \(L_{aug}(\theta;\lambda,\mu)=-L_f(\theta)+\lambda(L_r^{rem}(\theta)-L_r^{rem}(\theta_0))+\frac{\mu}{2}(L_r^{rem}(\theta)-L_r^{rem}(\theta_0))^2\)。乘子 \(\lambda\) 从 0 起步,每步先按 \(\theta\leftarrow\theta-\eta\nabla_\theta L_{aug}\) 更新参数,再按约束违反量 \(\lambda\leftarrow\lambda+\mu(L_r^{rem}(\theta)-L_r^{rem}(\theta_0))\) 更新乘子。这样惩罚强度随约束违反程度自动收紧或放松,避免人工 trade-off 调参,训练更稳定。这里刻意不优化 \(D_r^{adj}\) 正是为了规避与遗忘目标的梯度冲突。
2. 揭示经典 PGD 的失效:均值约束的致命漏洞。 阶段二最自然的想法是用多任务学习里常用的线性化投影梯度下降(PGD),把 \(\nabla_\theta L_r^{adj}\) 中与 \(\{\nabla_\theta L_f,\nabla_\theta L_r^{rem}\}\) 张成空间对齐的分量投影掉:\(\theta\leftarrow\theta-\eta(\nabla_\theta L_r^{adj}-\text{Proj}_V\nabla_\theta L_r^{adj})\)。但作者用实验暴露了一个反直觉的失败——\(D_f\) 上的平均损失看似稳定,准确率却一路回升。根因在于 \(D_f\) 与 \(D_r^{adj}\) 的强纠缠:恢复 \(D_r^{adj}\) 会顺带压低 \(D_f\) 中相似样本的损失;为维持均值不变,模型把损失"补"到了不相似样本上,造成两极分化的损失分布——一部分样本损失逼近零(被重新记住)。这说明仅约束均值损失对"低损失样本占比"毫无保证,遗忘形同虚设。
3. Wasserstein-2 距离正则化的 W-PGD:锁住整条损失分布。 为细粒度控制遗忘行为,作者用 W2 距离约束 \(D_f\) 损失分布相对阶段一末模型 \(\bar\theta\) 的整体漂移。一维经验分布的 W2 有排序闭式解:\(W_2(P,Q)=(\frac1N\sum_i(\bar a_i-\bar b_i)^2)^{1/2}\),无需密度估计,远比 KL 散度(需高斯先验或核估计)便宜。定义修正遗忘损失 \(\tilde L_f(\theta)=(1-\alpha)L_f(\theta)+\alpha W_2^2(P_{\bar\theta}^{forget},P_\theta^{forget})\),再把投影空间换成 \(V=\text{span}\{\nabla_\theta\tilde L_f,\nabla_\theta L_r^{rem}\}\)。理论上 Proposition 4.1 保证更新让 \(\tilde L_f\) 和 \(L_r^{rem}\) 的变化是 \(O(\eta^2)\) 二阶小量、而 \(L_r^{adj}\) 严格一阶下降 \(-c\eta\);Proposition 4.2 进一步给出遗忘集准确率上界 \(\text{Acc}_f(\theta)\le\frac{1}{(m-\log n)^2}(\frac{1-\alpha}{\alpha}+\sqrt{\frac{\varepsilon}{\alpha}})^2\),即只要 \(\alpha>0\) 且原最小损失足够大,遗忘集准确率被压在小常数内。实践中取 \(\alpha=0.5\),使 \(D_f\) 损失分布保持均匀、准确率维持在零。
实验关键数据¶
主实验表格(CIFAR-100 / ResNet-18,遗忘"aquarium fish"子类,测试精度)¶
| 方法 | \(D_f\)↓ | \(D_r^{adj}\)↑ | \(D_r^{rem}\)↑ |
|---|---|---|---|
| Original | 90.00 | 80.00 | 85.33 |
| FT | 62.33 | 77.83 | 83.89 |
| Munba | 31.67 | 69.75 | 75.32 |
| SCRUB | 7.00 | 54.75 | 75.42 |
| SalUn | 3.00 | 34.90 | 71.78 |
| DELETE | 0.67 | 2.83 | 82.09 |
| GDR | 8.67 | 22.33 | 79.93 |
| 本文 | 2.33 | 78.17 | 81.10 |
关键对比:DELETE/GDR/SalUn 虽把遗忘集压得很低,但相邻保留集 \(D_r^{adj}\) 精度崩到 2~35%;本文在遗忘到 2.33% 的同时把 \(D_r^{adj}\) 守在 78.17%(接近原始 80%),唯一兼顾遗忘与相邻保留。
其他数据集(测试精度,节选)¶
| 设定 | 方法 | \(D_f\)↓ | \(D_r^{adj}\)↑ | \(D_r^{rem}\)↑ |
|---|---|---|---|---|
| ToxiGen / RoBERTa(去偏纠错) | GDR | 19.83 | 83.92 | 85.52 |
| ToxiGen / RoBERTa | 本文 | 14.29 | 85.86 | 85.23 |
| CelebA / ViT-B | GA/DELETE | 0.00 | 0.00(崩) | ~92 |
| CelebA / ViT-B | 本文 | 25.48 | 75.05 | 92.38 |
| TinyImageNet / ViT(遗忘"dog") | 本文 | 3.11 | 91.27 | 88.88 |
消融实验表格(W2 正则化,CIFAR-100/ResNet18,测试精度)¶
| 配置 | \(D_f\)↓ | \(D_r^{adj}\)↑ | \(D_r^{rem}\)↑ |
|---|---|---|---|
| w/o W2 正则 | 14.33 | 87.00 | 80.55 |
| w/ W2 正则 | 2.33 | 78.17 | 81.10 |
关键发现¶
- W2 正则是遗忘"防反弹"的命门:去掉后遗忘集准确率从 2.33% 反弹到 14.33%(训练集 0%→18.87%),印证了"仅锁均值会被绕过"的分析。代价是 \(D_r^{adj}\) 略降(87→78),属可接受权衡。
- 纠缠越强、基线越崩:CelebA 上相邻集与遗忘集属性高度相似,GA/DELETE 直接把 \(D_r^{adj}\) 打到 0%,而本文守住 75%,纠缠越极端越凸显优势。
- 跨架构/任务一致性:从 ResNet 到 ViT、从视觉分类到 ToxiGen 语言去偏,两阶段框架表现稳定。
亮点与洞察¶
- 问题立意精准:把"平均保留精度掩盖相邻子集塌陷"这一被长期忽视的盲区显式化,将保留集拆成 adjacent/remote 并分别上报,比 LLM 遗忘里常用的 neighbor set 更系统。
- 失败分析有诊断价值:Figure 1 用损失分布两极化清晰解释了"均值不变、准确率反弹"的机制,这是从均值约束转向分布约束的强动机,而非拍脑袋换正则。
- 理论-实践闭环:两个 Proposition 分别保证"相邻集严格下降"和"遗忘集准确率有界",W2 的一维闭式解又让分布约束几乎零额外成本,理论优雅且落地便宜。
局限与展望¶
- 依赖相邻/远端子集的先验划分:方法假设能事先把 \(D_r\) 切成 \(D_r^{adj}\) 和 \(D_r^{rem}\)(靠超类结构或语义分组),但真实场景中"哪些样本与遗忘集纠缠"往往不易界定,划分质量直接影响效果。
- 两阶段串行成本:相比一步式后处理(如 SSD),两阶段优化 + 每步 W2 排序在大规模数据上的计算开销值得进一步评估。
- 遗忘语义偏"压制"而非"重训等价":本文采纳"最大化降低遗忘集性能"视角,更适合去偏/有害内容场景;对隐私场景下"等价于从头重训"的严格定义未直接覆盖。
- CelebA 上遗忘集仍有 25% 残留:纠缠极端时遗忘与相邻保留仍存在张力,未完全消除 trade-off。
相关工作与启发¶
- 约束优化谱系:方法把公平/安全学习里成熟的增广拉格朗日、primal-dual 思想迁移到遗忘场景,提示"遗忘本质是带约束的多目标优化"。
- 与梯度冲突方法对比:GDR、Munba(Nash bargaining)、PGD 都试图调和遗忘-保留的梯度冲突,本文的差异在于分阶段回避冲突 + 分布级约束,而非在单步里硬调和。
- 对分布约束的启发:用 W2 替代 KL 来约束损失分布,因一维闭式解而极轻量,这一技巧可推广到任何"需要锁住某子集输出分布"的持续学习/去偏任务。
评分¶
- 新颖性: ⭐⭐⭐⭐ 把"retain-forget 纠缠"显式建模并用 W2 分布约束破解均值约束漏洞,问题定义和解法都有新意。
- 实验充分度: ⭐⭐⭐⭐ 覆盖 CIFAR-100/TinyImageNet/CelebA/ToxiGen 四数据集、ResNet/ViT/RoBERTa 三架构、9 个基线,并有 W2 消融,较充分;但缺隐私场景(MIA)下的遗忘度量。
- 写作质量: ⭐⭐⭐⭐ 失败案例分析(Figure 1)和理论命题衔接清晰,方法动机层层递进,可读性强。
- 价值: ⭐⭐⭐⭐ 直指现有遗忘评测的盲区(相邻子集塌陷),对去偏、有害内容纠错等真实安全场景有直接落地价值。