跳转至

Unlearning During Training: Domain-Specific Gradient Ascent for Domain Generalization

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=9ufS5Jl0O0
领域: 域泛化 / 泛化理论 / 机器遗忘
关键词: 域泛化, 机器遗忘, 影响函数, 域特异通道, 梯度上升

一句话总结

本文提出 Identify and Unlearn (IU):一个模型无关的"训练中遗忘"模块,每个 epoch 结束后用影响函数挑出"徒增模型复杂度却几乎不提升泛化"的训练样本,用跨域方差 (IDV) 精确定位捕获域特异特征的通道,再对这些通道在这些样本上做梯度上升 (DSGA),从而在保留域不变特征的前提下抹除域特异依赖,在 7 个基准、15+ 个 DG baseline 上平均涨点最高 3.0%。

研究背景与动机

领域现状:域泛化 (Domain Generalization, DG) 想让模型只在多个有标注的源域上训练,就能泛化到训练时完全没见过的目标域,避免域适应 (DA/UDA) 那种"必须拿到目标域数据"的限制。主流做法分三类:数据增强、表示学习(学域不变特征)、训练策略(meta-learning / 集成 / 课程学习等)。

现有痛点:这些方法本质都是在训练目标里加约束,试图从一开始就阻止模型学到域特异特征。但它们都缺一个"事后纠错"的机制——一旦模型在训练途中已经捕获了某些域特异特征,这些方法没有任何手段把它们再删掉。

核心矛盾:域特异依赖不是一次性产生的,而是在训练的不同阶段动态涌现的。只在训练目标上做文章,等于"只设防、不排查",无法应对中途才冒出来的偏置。这就需要一个持续运行、自适应纠正的过程。

本文目标:设计一个机制,能够(i)识别出"正在引入域特异偏置"的训练样本,(ii)定位承载域特异特征的通道,(iii)只抹掉这些样本在这些通道上的影响,同时保住域不变特征。

切入角度:作者借两条已有原理——影响函数(Koh & Liang, 2017,能估单个样本对参数/验证性能的影响)和"复杂度更低的模型泛化更好"(即在训练表现相当时,让模型变复杂却不帮泛化的样本,大概率是在喂域特异偏置)。把机器遗忘 (Machine Unlearning) 这个原本用于"删数据满足隐私请求"的工具,第一次借来当作提升 DG 泛化的手段。

核心 idea:把"遗忘"塞进训练循环——每个 epoch 后挑出"高复杂度、低泛化贡献"的样本,只在域特异通道上对它们做梯度上升(即反向优化、主动遗忘),选择性地把域特异特征"忘掉"。

方法详解

整体框架

IU 是一个挂在任意 DG baseline 上的后置模块:正常训练完一个 epoch 后,它插进来跑一轮"识别 + 遗忘",再把更新后的模型交回去继续下一个 epoch。一轮 post-epoch 干预由三步串成:先用影响函数算出每个样本的遗忘分数、用 MAD 阈值卡出"遗忘集" \(D_u\);再用跨域方差 IDV 算出每个通道的域特异程度、同样用 MAD 卡出"域特异通道集" \(C_{spc}\);最后只在 \(C_{spc}\) 这些通道、只用 \(D_u\) 这些样本做梯度上升,得到去除了域特异依赖的新模型 \(f^*_\theta\)

整个 IU 不改 baseline 的训练目标、不引入新网络结构,纯粹是"训练间隙的一次外科手术",因此天然 model-agnostic。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["训练完一个 epoch<br/>得到模型 θ + 训练集 D"] --> B["遗忘集选择<br/>影响函数算复杂度/泛化分<br/>→ 遗忘分数 Ux → MAD 卡阈值"]
    B -->|得到遗忘集 Du| C["域特异通道选择<br/>跨域方差 IDV<br/>→ MAD 卡阈值"]
    C -->|得到域特异通道 Cspc| D["域特异梯度上升 DSGA<br/>仅在 Cspc 通道、仅用 Du 样本<br/>反转梯度"]
    D --> E["更新后的模型 f*θ<br/>交回下一个 epoch"]

关键设计

1. 遗忘分数 (Unlearning Score):用影响函数挑出"添乱"的样本

要"事后纠错",第一步得知道纠谁——哪些样本在偷偷往模型里塞域特异偏置。作者把这件事量化成两个基于影响函数的分数。复杂度分数 \(C_x\) 衡量"删掉样本 \(x\) 会让参数变化多大",用参数改变量的 \(\ell_2\) 范数表示:\(C_x = \lVert -H_\theta^{-1}\nabla_\theta L(x,\theta)\rVert_2\),其中 \(H_\theta^{-1}\) 是逆 Hessian。\(C_x\) 越大,说明这个样本对模型复杂度的影响越大。泛化分数 \(G_x\) 衡量样本 \(x\) 对整个验证集 \(D_{val}\) 性能的正向贡献:\(G_x = \sum_{z\in D_{val}} -\nabla_\theta L(z,\theta)^T H_\theta^{-1}\nabla_\theta L(x,\theta)\)

两者合成总的遗忘分数 \(U_x = (G_x)^\alpha / C_x\)\(U_x\),意味着这个样本"几乎不帮泛化、却大幅推高复杂度"——正是该被遗忘的候选。指数 \(\alpha\) 用来缓解作者所谓的 Score Equivalence Bias:没有它时,泛化分和复杂度分差异很大的两个样本可能拿到相同的 \(U_x\),削弱筛选区分度。阈值用 MAD(中位数绝对偏差)卡:\(\tau_U = \tilde{U}_x - k\cdot\text{median}(|U_{x_i}-\tilde{U}_x|)\),分数低于 \(\tau_U\) 的进入遗忘集 \(D_u\)。选 MAD 而非均值绝对偏差,是因为它不假设具体分布、对离群点更鲁棒,全程固定 \(k=2\)

2. 跨域方差 IDV:精确分辨哪些通道是"域特异"的

光挑出样本还不够——如果对这些样本"全盘遗忘",会把宝贵的域不变特征也一起抹掉。所以得先锁定只承载域特异特征的通道,手术刀才下得准。已有方法用 Aggregated Variance (AV),即把所有源域的激活混在一起算总方差,隐含假设"域特异通道在全数据上方差大"。这个假设有两个硬伤:一是对域不平衡极敏感,主导域会绑架统计量;二是它捕的主要是域内离散度而非域间差异,会系统性误判。论文给的反例很形象:一个对毛发边缘敏感的纹理通道,在照片(真实毛发)、卡通(粗描边)、油画(笔触)、素描(轮廓线)里每个域内方差都很高,AV 会给它高分、误判成域特异,但它其实跨域行为一致、是域不变的。

IDV 换了个更对的定义:一个通道是域特异的,当且仅当它的域内方差在各源域之间差异很大。形式上 \(\text{IDV}(c) = \text{Variance}\big(\{v_c^{(d)}\}_{d=1}^N\big)\),其中 \(v_c^{(d)} = \frac{1}{N_d}\sum_i (x_{c}^{(d,i)} - \mu_c^{(d)})^2\) 是通道 \(c\) 在域 \(d\) 内的激活方差。即"先算每个域内的方差,再算这些方差跨域的方差"。它把每个域当独立分析单元、且与域大小无关,所以既 domain-aware 又 domain-size agnostic,天然抗域不平衡,也不会把"全局都吵"的噪声通道误当成真正的域特异通道。同样用 MAD 阈值,IDV 超阈值的通道进入 \(C_{spc}\)

3. 域特异梯度上升 (DSGA):只在该忘的地方反转梯度

有了遗忘集 \(D_u\) 和域特异通道 \(C_{spc}\),最后一步是真正"遗忘"。DSGA 只对域特异通道的参数 \(\theta_c\)、只用 \(D_u\) 里的样本做梯度上升(注意是加号,与常规下降相反):\(\theta_c = \theta_c + \nabla_{\theta_c} L(x,\theta),\ x\in D_u,\ c\in C_{spc}\)。直觉是:对这些"添乱样本"主动升高损失,等于把模型在域特异通道上的预测信心打散。作者还配了理论分析支撑:把表示拆成域不变 \(f_{inv}\) 和域特异 \(f_{spc}\) 两部分、参数对应拆成 \(\theta_{inv}\)\(\theta_{spc}\),证明只更新 \(\theta_{spc}\) 的梯度上升会增大条件熵 \(H(y\mid f_{spc})\),从而降低互信息 \(D_{spc}=I(y;f_{spc})\)(模型对域特异特征的依赖↓);而 \(\theta_{inv}\) 不动,\(D_{inv}\) 基本不变。这正是"选择性遗忘"的关键——只削域特异依赖,保住域不变依赖。

损失函数 / 训练策略

IU 不改 baseline 的主训练损失,只在每个 epoch 后插入一次基于影响函数的样本/通道筛选 + DSGA 更新。一个可选增强是对遗忘分数做 EMA(指数滑动平均)平滑,记作下标 \(\text{IUE}\):跨 epoch 平滑 \(U_x\) 的轨迹,降噪、拉开样本间区分度,让遗忘集选得更稳更准。

实验关键数据

主实验

在 DomainBed 协议下评测 7 个基准(PACS / OfficeHome / VLCS / Terra Incognita / DomainNet / Digits-DG / NICO++),采用 leave-one-domain-out,把 IU/IUE 挂到 15+ 个不同范式的 DG baseline 上。下表节选几个代表 baseline 的平均准确率(%):

基准 ERM ERM\(_{IU}\) ERM\(_{IUE}\) MMD MMD\(_{IUE}\) EFDMix EFDMix\(_{IUE}\)
PACS 83.0 85.7 86.0 83.2 84.9 84.6 86.6
OfficeHome 68.2 69.8 70.0 67.7 70.4 71.2 73.1
VLCS 77.2 80.0 80.6 77.2 80.7 78.3 80.1
Terra 41.7 44.2 44.5 46.6 48.9 49.9 51.5
DomainNet 40.7 42.2 43.1 31.7 34.6 44.2 45.6
Digits 79.4 82.1 82.9 79.9 81.9 82.1 84.3
NICO++ 79.8 81.2 81.5 80.2 83.0 82.6 84.8

三点观察:(1) 不论 baseline 属于哪一范式,挂上 IU 后全部涨点,印证其 model-agnostic 与"事后遗忘域特异特征"的有效性;(2) EMA 平滑(IUE)在 IU 基础上再稳定提升;(3) 连 UDIM、VL2V 这类已在标准基准上接近饱和的强 baseline,IU 也能挤出温和但一致的增益。

消融实验

拆开 IU 的两个组件——遗忘集选择 (USS) 与域特异通道选择 (DSCS):

配置 PACS OfficeHome VLCS Terra DomainNet 说明
ERM 83.0 68.2 77.2 41.7 40.7 基线
ERM\(_{USS}\) 78.9 64.3 72.6 37.6 37.4 只有 USS:对遗忘集全参数反梯度
ERM\(_{DSCS}\) 76.7 62.5 70.4 36.3 34.6 只有 DSCS:对全训练集在域特异通道反梯度
ERM\(_{IU}\) 85.7 69.8 80.0 44.2 42.2 两者结合(完整 IU)

关键发现

  • 两个组件必须配合,单独用都掉点:只做 USS(全参数反梯度)会连域不变知识一起删掉,PACS 从 83.0 掉到 78.9;只做 DSCS(全训练集都在域特异通道反梯度)会过度简化、削弱表示能力,掉得更狠到 76.7。两者结合才把性能从 83.0 抬到 85.7——说明"挑对样本"和"挑对通道"是互补的,缺一不可。
  • IDV 的判别信号呈双峰长尾:绝大多数通道 IDV 值很低(域不变),一小撮形成高值长尾(域特异),这种清晰的双峰让 MAD 阈值能干净地切出域特异通道。
  • EMA 让遗忘分数更可分:跨 epoch 平滑后,不同样本的分数轨迹更平滑、分布更展开,信噪比提升、离群影响样本更突出,因此 IUE 普遍优于 IU。

亮点与洞察

  • 把"机器遗忘"从隐私工具迁成泛化工具:传统 MU 是为删数据满足隐私请求,本文把同一套"主动遗忘"机制重定向到"删域特异特征以提升 OOD 泛化",是一个很漂亮的换用场景——这个迁移思路(遗忘不该被记住的东西)可以推广到去偏、抗捷径学习等任务。
  • IDV 的"方差的方差"定义很巧:用一句"域内方差在域间是否差异大"就绕开了 AV 把域内离散度误当域间差异的坑,且天然抗域不平衡。这个"先组内统计、再组间求方差"的模式可复用到任何需要分辨"组特异 vs 组共性"的特征筛选场景。
  • 训练中持续干预而非一次性后处理:把遗忘做成 per-epoch 的循环,对应了"域特异依赖动态涌现"的观察,比训练完再修一刀更贴合问题本质。

局限与展望

  • 依赖影响函数与逆 Hessian\(C_x\)\(G_x\) 都要算 \(H_\theta^{-1}\),在大模型上精确计算代价高,实际多用近似,近似误差会传导到遗忘集质量(论文正文未详述其近似与开销,⚠️ 以原文附录为准)。
  • 理论分析建立在强假设上:Theorem 1 依赖"表示可干净拆成域不变/域特异、参数也对应解耦"这一理想化前提,真实网络里两者并非如此可分,\(\theta_{inv}\) 完全不受影响只是近似。
  • 超参与阈值敏感性\(\alpha\)、MAD 的 \(k\)、域特异通道阈值都需设定;论文固定 \(k=2\)、把 \(\alpha\) 的敏感性放进附录,主文对调参鲁棒性的覆盖有限。
  • 改进方向:把 IDV 从"通道级"细化到更细粒度的特征子空间,或把遗忘频率自适应到"域特异依赖涌现"的实际节奏(而非每 epoch 固定一次),可能进一步提效。

相关工作与启发

  • vs 表示学习类 DG(特征对齐 / 对抗 / IRM):它们在训练目标里预防模型学域特异特征;IU 是事后纠错,挂在它们之上做增量遗忘,因此能和这类方法叠加涨点而非互斥。
  • vs Aggregated Variance (AV) 通道筛选:AV 用跨样本池化的总方差找域特异通道,对域不平衡敏感、易把域不变的纹理通道误判;IDV 用跨域的"方差的方差",domain-aware 且 domain-size agnostic,定位更准。
  • vs 传统机器遗忘 (MU):传统 MU 目标是"为隐私/安全忘掉特定数据点";本文的目标是"选择性忘掉妨碍泛化的域特异特征",是把 MU 第一次用于增强 DG,问题设定根本不同。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首次把机器遗忘当作 DG 工具,IDV 的"方差的方差"定义和 per-epoch 选择性梯度上升都有原创性。
  • 实验充分度: ⭐⭐⭐⭐⭐ 7 个基准 × 15+ baseline 全面涨点,消融清楚拆出 USS/DSCS 的互补性。
  • 写作质量: ⭐⭐⭐⭐ 动机递进清晰、图例(毛发通道反例)很有说服力,理论部分假设偏强但表述完整。
  • 价值: ⭐⭐⭐⭐ model-agnostic 即插即用,能给已饱和的强 baseline 再挤出增益,实用性高。