跳转至

Elastic Weight Consolidation Done Right for Continual Learning

会议: CVPR 2026
arXiv: 2603.18596
代码: https://github.com/scarlet0703/EWC-DR
领域: LLM安全
关键词: 持续学习, 灾难性遗忘, 弹性权重巩固, Fisher信息矩阵, 权重正则化

一句话总结

本文从梯度视角系统分析了 EWC 及其变体在权重重要性估计上的根本缺陷(EWC 的梯度消失和 MAS 的冗余保护),并提出了一个极其简单的 Logits Reversal 操作来修正 Fisher 信息矩阵的计算,在无样例类增量学习和多模态持续指令微调任务上大幅超越原始 EWC 及其所有变体。

研究背景与动机

持续学习(Continual Learning)要求模型按顺序学习多个任务,但神经网络在学习新任务时会灾难性地遗忘旧任务知识。解决这一问题的主流方法之一是权重正则化:评估每个参数对旧任务的重要性,训练新任务时惩罚对重要参数的修改。

EWC(Elastic Weight Consolidation)是这类方法的奠基之作,通过 Fisher 信息矩阵(FIM)来估计参数重要性,广泛应用于图像分类、指令微调、目标检测等场景。然而,EWC 在实际实验中一直表现不佳,已有多项研究指出其 FIM 近似不够准确,但没有人从根本上分析过 EWC 性能差的真正原因

本文的核心洞察是:EWC 的问题不仅仅是"FIM 近似不准",而是存在两个结构性缺陷——梯度消失导致重要参数被低估,以及 MAS 等变体引入的冗余保护导致不相关参数被过度约束。作者提出的修复方案——Logits Reversal——只需要在计算 FIM 时对 logits 取反,即可同时解决这两个问题。

方法详解

整体框架

EWC-DR 沿用标准 EWC 的学习流程:训练完任务 \(t-1\) 后,用训练数据计算参数重要性矩阵 \(\Omega^{t-1}\),学习新任务 \(t\) 时添加正则化损失 \(\mathcal{L}_{reg} = \frac{\lambda}{2} \sum_i \Omega_i^{t-1}(\theta_i^{t-1} - \theta_i^t)^2\)。本文的改进仅在于如何计算 \(\Omega\)

关键设计

1. 梯度消失分析:揭示 EWC 为什么会系统性低估参数重要性

EWC 的 FIM 来自交叉熵损失对参数的梯度平方,对 FC 层权重 \(w_k\) 写开就是 \(\Omega_{w_k}^{EWC} = \mathbb{E}[(p_k - y_k)^2 \cdot (\frac{\partial z_k}{\partial w_k})^2]\)。问题出在这个 \((p_k - y_k)^2\) 因子:训练快收敛时模型对训练样本通常已经很自信,正确类别 \(c\) 的预测概率 \(p_c \to 1\),于是 \((p_c - 1) \to 0\);而其它类别 \(p_k \to 0\)\(y_k=0\),差值同样趋近于零。两头一起塌缩,意味着越是训练到位的模型,算出来的 FIM 越接近全零。偏偏 EWC 正是在任务训练结束那一刻去估重要性,结果所有参数的重要性都被压低,正则化项形同虚设,旧任务知识自然守不住。这解释了为何 EWC 在实践中长期表现不佳——不是 FIM 近似得粗,而是它在最不该消失的时刻消失了。

2. 冗余保护分析:指出 MAS 的"补丁"换来了另一种偏差

MAS 想绕开梯度消失,改用 \(\ell_2\) 范数的输出代替交叉熵,重要性变成 \(\Omega_{w_k}^{MAS} = \frac{|z_k|}{\|\mathbf{z}\|_2} \cdot |\frac{\partial z_k}{\partial w_k}|\)。这样确实不再依赖会塌缩的 \((p_k-y_k)\),但带来了新麻烦:logits 是无界的,一个绝对值很大的负 logit(对应极低预测概率的类别)会因为 \(|z_k|\) 大而拿到很高的重要性分。然而这类极端负 logit 对最终 softmax 概率几乎没有贡献,保护它们对避免遗忘毫无意义,只是白白冻住了一批本可以用来学新任务的参数,削弱了模型的可塑性。所以 MAS 是从"什么都不保护"滑到了"保护了不该保护的",两种极端都偏离了真正重要的参数。

3. Logits Reversal:一行取反同时消掉上面两个病灶

本文的修复出奇地简单——计算 FIM 时把 logits 整体取反 \(\tilde{z}_k = -z_k\),再照常走 softmax 和交叉熵。取反后的输出是 \(\tilde{p}_k = \frac{e^{-z_k}}{\sum_j e^{-z_j}}\),重要性相应变为

\[\Omega_{w_k}^{LR} = \mathbb{E}\big[(y_k - \tilde{p}_k)^2 \cdot (\tfrac{\partial \tilde{z}_k}{\partial w_k})^2\big].\]

关键在于 \(\frac{\partial \tilde{p}_k}{\partial z_k} < 0\):原本越自信(\(z_c\) 越大)的正确类别,取反后 \(\tilde{p}_c\) 反而越小,\((y_c - \tilde{p}_c)=(1-\tilde{p}_c)\) 越大,于是正确类别的重要性被放大而不是被压平——梯度消失问题就此解决。同时,错误类别取反后 \(\tilde{p}_k\) 很小,不会像 MAS 那样给极端负 logit 灌进虚高的重要性,冗余保护也随之消失。一次取反,让 FIM 的高亮重新聚焦到真正决定正确预测的参数上,而这只是一行代码的改动。

损失函数 / 训练策略

训练损失保持标准 EWC 形式不变:\(\mathcal{L}_{total} = \mathcal{L}_{CE} + \frac{\lambda}{2} \sum_i \Omega_i^{LR}(\theta_i^{t-1} - \theta_i^t)^2\)。唯一改变是 \(\Omega\) 的计算方式。

实验关键数据

主实验

数据集 设置 指标 EWC Online EWC MAS EWC-DR 提升(vs EWC)
CIFAR-100 Big-start T=5 \(A_{last}\) 14.61 29.70 35.37 50.23 +35.62
CIFAR-100 Big-start T=5 \(A_{avg}\) 32.82 45.65 48.32 63.75 +30.93
ImageNet-Sub Big-start T=5 \(A_{last}\) 11.44 23.56 21.06 66.18 +54.74
ImageNet-Sub Big-start T=5 \(A_{avg}\) 26.57 46.68 42.59 76.00 +49.43
Tiny-ImageNet Big-start T=5 \(A_{last}\) 9.74 27.02 25.53 38.24 +28.50
MCIT (VCR后) 增量精度 \(A_t\) 42.99 52.59 +9.60

消融实验

配置 关键指标 说明
EWC (原始FIM) FC层重要性矩阵几乎全黑 梯度消失导致所有类别重要性极低
MAS (ℓ2范数) GT类+非GT类均高亮 对class 0和class 4产生冗余保护
EWC-DR (LR) 仅GT类(class 2)高亮 选择性且判别性的重要性估计
MCIT: EWC遗忘率 NLVR2任务90.66% 严重灾难性遗忘
MCIT: EWC-DR遗忘率 NLVR2任务27.48% 遗忘显著降低,可塑性保持

关键发现

  • EWC-DR 在所有 18 个 EFCIL 设置中取得最佳结果,最大提升幅度达 \(A_{last}\) +53.18%、\(A_{avg}\) +55.47%
  • 临界差异(CD)分析证实 EWC-DR 的提升具有统计显著性(CD=1.438,显著性0.05)
  • 多模态持续指令微调中,EWC-DR 在不损失新任务学习能力的同时,大幅降低遗忘率

亮点与洞察

  • 极其优雅的分析框架:从梯度角度统一审视 EWC 家族的缺陷,发现了两个之前被忽略的根本问题
  • 修复方案极度简洁:只需一行代码(logits 取反)就能大幅提升性能,体现了"找对问题比设计复杂方案更重要"
  • 重要性矩阵的可视化分析非常直观:EWC 全黑、MAS 过度高亮、EWC-DR 精确聚焦,一目了然

局限与展望

  • 理论分析聚焦于 FC 层权重,对中间层参数的影响通过反向传播间接作用,缺乏直接分析
  • 仅与 EWC 家族(EWC、Online EWC、SI、MAS)比较,缺少与知识蒸馏、架构扩展等其他 CL 方法类别的系统对比
  • Logits Reversal 的理论最优性没有严格证明,可能存在更优的 logit 变换

相关工作与启发

  • 与 Online EWC 的对比表明,在线累积重要性权重并不能根本解决梯度消失问题
  • MAS 虽然避免了梯度消失,但引入了新问题(冗余保护),说明损失函数的选择需要更加审慎
  • 该工作提示我们:经典方法的性能不佳可能不是"方法不好",而是"实现有bug"——从基本原理重新审视可能找到简单而高效的改进

评分

  • 新颖性: ⭐⭐⭐⭐ 分析视角新颖,但解决方案(取反 logits)的技术贡献偏轻量
  • 实验充分度: ⭐⭐⭐⭐⭐ 三个数据集×三种任务划分×两种设置 + MCIT 实验 + 统计检验
  • 写作质量: ⭐⭐⭐⭐⭐ 逻辑清晰,从分析到方法到实验一气呵成,可视化优秀
  • 价值: ⭐⭐⭐⭐ 对 EWC 研究社区有重要参考价值,方法简单易用