跳转至

Flatness-Aware Stochastic Gradient Langevin Dynamics

会议: ICML 2026
arXiv: 2510.02174
代码: https://github.com/youngsikhwang/Flatness-aware-SGLD (有)
领域: 优化 / 贝叶斯采样 / 平坦最小值
关键词: SGLD, 平坦最小值, Hessian-trace 正则, Gibbs 分布, 随机权重扰动

一句话总结

本文提出 fSGLD:在标准 SGLD 更新里把梯度处的参数 \(\theta\) 换成被高斯扰动过的 \(\theta+\epsilon\),并将扰动尺度 \(\sigma\) 与逆温度 \(\beta\) 通过 \(\sigma=\beta^{-(1+\eta)/4}\) 严格耦合,从而在不增加任何梯度/内存开销的前提下,让算法的不变测度逼近 Hessian-trace 正则化目标 \(v(\theta)=u(\theta)+\tfrac{\sigma^2}{2}\mathrm{tr}(H(\theta))\) 对应的 Gibbs 分布,并给出 Wasserstein-1 与超额风险的非渐近界,在 CIFAR/WebVision/ViT 上取得与 SAM/ASAM 相当或更优、但训练时间近乎减半的效果。

研究背景与动机

领域现状:深度网络泛化与损失曲面的"平坦性"高度相关,主流做法是 SAM 系列(min-max 内层扰动 + 双梯度)与 Entropy-SGD/Entropy-MCMC(引入辅助变量做局部熵平滑),它们都能把训练推向低曲率盆地,但代价不小:SAM 每步两次梯度,Entropy 系列翻倍内存。

现有痛点:这些方法本质都是"局部"的——只用当前点周围一小圈的几何信息,因此在多模态、高度非凸的损失曲面上很难跳出尖锐盆地;理论保证也基本只到局部收敛。另一条线是 Langevin 类全局采样(SGLD),理论上能在足够低温下集中到全局极小,但它的不变测度 \(\pi_\beta^{\text{SGLD}}\propto\exp(-\beta u)\) 完全由目标函数决定,对曲面几何无感,所以它能找到的是"任意一个"全局极小,而不是"平坦的"全局极小。

核心矛盾:现有体系里没有同时具备 (a) 全局探索能力、(b) 对低曲率区域的归纳偏置、(c) 与 SGD 同等计算/内存代价 这三点的算法。Entropy-MCMC 算是最接近的工作,但它需要辅助变量、内存翻倍,理论也只在强凸下成立。

本文目标:设计一个一阶、不增任何额外梯度或内存的 Langevin 算法,使其不变测度集中在"Hessian-trace 正则化目标" \(v(\theta)=u(\theta)+\tfrac{\sigma^2}{2}\mathrm{tr}(H(\theta))\) 的全局极小(即"全局平坦最小值"),并在非凸设定下给出非渐近 Wasserstein 与超额风险界。

切入角度:作者敏锐地注意到,把 SGLD 里的梯度 \(\nabla U(\theta,X)\) 换成在扰动点 \(\theta+\epsilon\) 处求的扰动梯度 \(\nabla U(\theta+\epsilon,X)\),其期望恰好是随机化平滑代理 \(g_\epsilon(\theta)=\mathbb{E}[u(\theta+\epsilon)]\) 的梯度;而 \(g_\epsilon\) 的二阶 Taylor 展开正好等于 \(u(\theta)+\tfrac{\sigma^2}{2}\mathrm{tr}(H(\theta))\) 加上一个高阶残差。换句话说,"扰动梯度 + Langevin 噪声"天然内嵌了 Hessian-trace 正则——只要能控住那个高阶残差。

核心 idea:用一条"\(\sigma\)\(\beta\) 耦合公式" \(\sigma=\beta^{-(1+\eta)/4}\)\(\eta\) 固定在 0.1)同时充当采样温度与扰动尺度的桥梁,使得当 \(\beta\) 升高时残差恰好以可控速率消失,从而让 fSGLD 的不变测度严格逼近"平坦偏置 Gibbs 分布" \(\pi^\star_{\beta,\sigma}\propto\exp(-\beta v(\theta))\)

方法详解

整体框架

fSGLD 几乎和 SGLD 一模一样,唯一区别在于"在哪里算梯度"。算法每一步:

  1. 采一个高斯扰动 \(\epsilon_{k+1}\sim\mathcal{N}(0,\sigma^2 I_d)\) 和一个 Langevin 噪声 \(\xi_{k+1}\sim\mathcal{N}(0,I_d)\)
  2. 扰动点 \(\theta_k+\epsilon_{k+1}\) 处对一个小批量样本 \(X_{k+1}\) 求梯度 \(\nabla_\theta U(\theta_k+\epsilon_{k+1},X_{k+1})\)
  3. 标准 SGLD 形式的更新:

    \(\theta_{k+1}=\theta_k-\lambda\,\nabla_\theta U(\theta_k+\epsilon_{k+1},X_{k+1})+\sqrt{2\lambda\beta^{-1}}\,\xi_{k+1}\)

  4. 关键约束:\(\sigma\) 不是独立超参,而是由 \(\beta\) 通过 \(\sigma=\beta^{-(1+\eta)/4}\)\(\eta=0.1\) 决定,所以对用户暴露的超参数与 SGLD 完全相同(只需调 \(\beta\)\(\lambda\))。

输入是模型参数 \(\theta_0\) 和数据分布;输出是参数链 \(\{\theta_k\}\),可以像 SGLD 一样做后验平均得到贝叶斯预测器,也可以当成普通优化器取末态参数用。

关键设计

  1. 扰动梯度作为 Hessian-trace 的隐式估计器:

    • 功能:用 \(\nabla_\theta U(\theta+\epsilon,X)\) 替换 SGLD 里的 \(\nabla_\theta U(\theta,X)\),在 0 额外梯度的代价下注入二阶曲率信息。
    • 核心思路:扰动梯度的期望 \(\mathbb{E}_{\epsilon,X}[\nabla_\theta U(\theta+\epsilon,X)]=\nabla g_\epsilon(\theta)\),而高斯期望下 Taylor 展开给出 \(g_\epsilon(\theta)=u(\theta)+\tfrac{\sigma^2}{2}\mathrm{tr}(H(\theta))+\mathbb{E}[\mathcal{R}(\theta,\epsilon)]\)。所以一步看似只是"加点权重噪声",实际上把 Hessian-trace 偷偷塞进了优化目标里。
    • 设计动机:SAM 用一次额外的"上升梯度"显式求曲率,Hessian-penalty 方法要近似 Hessian-vector product;fSGLD 借助高斯随机化把这两件事都省掉,使算法保持 SGLD 的 O(d) 内存与单次梯度成本。
  2. \(\sigma\)\(\beta\) 耦合公式 \(\sigma=\beta^{-(1+\eta)/4}\):

    • 功能:用一条解析关系把"扰动尺度"绑死到"采样温度"上,使 Taylor 残差 \(\mathbb{E}[\mathcal{R}(\theta,\epsilon)]=O(\sigma^4 d^2)\) 与 Gibbs 测度的"温度灵敏度" \(\beta\) 之间达到精确平衡。
    • 核心思路:在 Proposition 3.4 中作者证明,当 \(\eta\in(0,1)\)\(W_2(\pi^{\text{fSGLD}}_\beta,\pi^\star_{\beta,\sigma})=O(\beta^{-\eta/4}\sqrt d+\beta^{-\eta/2}d+\beta^{-(1+\eta)/2}d^2)\),可以靠增大 \(\beta\) 任意收敛到 0;同时 \(\sigma=\beta^{-(1+\eta)/4}\) 又保证平坦偏置强度不会随 \(\beta\to\infty\) 太快消失,从而存在有限 \(\beta\) 的"甜区"。
    • 设计动机:如果 \(\sigma\) 是独立超参,Taylor 残差与 \(\beta\) 互不相关,要么残差炸(破坏 Hessian-trace 偏置),要么扰动太小(退化成普通 SGLD),不存在统一可控的非渐近界;耦合公式让两件事同步衰减,把"近似精度 vs 平坦偏置强度"的 trade-off 缩成一条曲线,且把暴露给用户的超参数压回到 SGLD 同等数量。
  3. 平坦偏置 Gibbs 分布作为唯一理论目标:

    • 功能:把"找平坦极小"从一个启发式目标升级成 \(\pi^\star_{\beta,\sigma}\propto\exp(-\beta v(\theta))\) 这一明确的概率测度,并给出 Wasserstein-1 与超额风险的非渐近界。
    • 核心思路:在标准 SGLD 假设(四阶可微 + 数据相关的 Lipschitz + 耗散性)下,Theorem 3.5 证明 \(W_1(\mathcal{L}(\theta_k^{\text{fSGLD}}),\pi^\star_{\beta,\sigma})\le D_1 e^{-\dot c\lambda k/2}+(D_2+D_3)\sqrt\lambda+\underline{D}\),三项分别对应过阻尼 Langevin 的指数混合、Euler–Maruyama 离散误差 \(O(\lambda^{1/2})\) 以及不变测度偏差;Theorem 3.8 进一步给出 \(\mathbb{E}[v(\theta_k)]-\inf v\le D_1^\diamond e^{-\dot c\lambda k/4}+D_2^\diamond\lambda^{1/4}+D_3^\diamond\) 的超额风险界。
    • 设计动机:以前的 Langevin 全局收敛理论都瞄准原目标 \(u\) 的极小,第一次把目标换成 \(v\),论证算法的偏置不是"经验上跑出来的平坦",而是被理论刻画的"平坦最小值的全局采样";离散化误差速率与最优的标准 SGLD 分析(Zhang et al., 2023)一致,说明加入平坦偏置没有损失收敛速率。

损失函数 / 训练策略

作者没有显式改损失函数——优化的"有效目标"由算法动力学隐式定义为 \(v(\theta)=u(\theta)+\tfrac{\sigma^2}{2}\mathrm{tr}(H(\theta))\)。训练时只需把 SGLD 实现里"梯度处的参数"加一次高斯扰动即可;\(\eta=0.1\) 全程固定,\(\beta\) 与步长 \(\lambda\) 按各 benchmark 标准 SGLD 调度。理论上要求 \(\beta\)\(\lambda\)、迭代数 \(k\) 满足 (63)–(65) 给出的下/上界以保证 \(W_1\) 误差 \(\le\bar\delta\)

实验关键数据

主实验

ResNet-18 上的贝叶斯图像分类(贝叶斯模型平均,结果取 3 个随机种子均值±std;除 fSGLD 与 ASAM 外其他基线引自 Entropy-MCMC 原文):

数据集 指标 fSGLD 之前 SOTA 提升
CIFAR-10 ACC % ↑ 95.73 Entropy-MCMC 95.69 +0.04
CIFAR-10 NLL ↓ 0.144 ASAM 0.150 -0.006(≈ 4% 相对)
CIFAR-100 ACC % ↑ 78.53 Entropy-MCMC 79.16 -0.63(第三)
CIFAR-100 NLL ↓ 0.810 ASAM 0.814 -0.004
CIFAR-10→SVHN OOD AUROC % 98.91 Entropy-SGD 98.71 +0.20
CIFAR-100→SVHN OOD AUPR % 88.01 ASAM 87.93 +0.08

ResNet-34/50 在带噪声标签的 CIFAR-N 和 WebVision 上从头训练(5 个种子均值;s/epoch 在 CIFAR-10N 上测得):

Model Optimizer CIFAR-10N CIFAR-100N WV-1 WV-5 s/epoch
ResNet-34 SGD 89.31 58.47 71.87 89.33 22.0
ResNet-34 SAM 91.53 59.18 73.49 90.32 41.3
ResNet-34 ASAM 91.73 60.79 73.46 90.14 41.4
ResNet-34 fSGLD 91.37 61.51 73.95 90.03 23.7
ResNet-50 SAM 90.88 59.01 72.52 89.53 60.7
ResNet-50 ASAM 91.25 60.47 71.92 88.48 60.9
ResNet-50 fSGLD 90.86 61.26 73.54 90.34 34.1

ViT-B/16 微调:fSGLD 在 CIFAR-100N 上 75.67,超过 ASAM 的 74.86,而单 epoch 用时 345.8s(SAM 656.7s、ASAM 662.5s),近乎减半。

消融实验

配置 关键指标 说明
耦合 \(\sigma=\beta^{-(1+\eta)/4}\)\(\eta\in(0,1)\) 性能稳定在峰值 推荐 \(\eta=0.1\)
固定 \(\beta=10^8\)\(\sigma\) 隐含 \(\eta\notin(0,1)\) 时显著掉点 验证扰动尺度不能脱离温度独立设置
固定 \(\sigma=10^{-3}\)\(\beta\) 同上 反向验证:温度同样不能脱离扰动独立设置
Hessian 谱比较 (ResNet-34 / CIFAR-10N) fSGLD 的 \(\lambda_{\text{top}}\)\(\mathrm{tr}(H)\) 都明显小于 SGD/SGLD 直接证实 fSGLD 收敛到更平坦的极小

关键发现

  • 与 SAM/ASAM 相比:fSGLD 在 CIFAR-100N、WebVision Top-1 这种更"难"(高噪 + 多类)的任务上反超,且训练时间约为 SAM/ASAM 的一半——证明"用扰动梯度替代显式二阶"既省又好。
  • 与 Entropy-MCMC 相比:fSGLD 不需要辅助变量,内存减半,性能在 CIFAR-10 上反超,CIFAR-100 上略低 0.6%(但 NLL 更优)。
  • \(\eta\)\((0,1)\) 内对性能几乎不敏感(图 1 在很宽范围内保持峰值平台),说明耦合公式既必要又稳健,工程上调一个 \(\beta\) 就够。
  • Hessian 谱实验给出"算法机制 → 几何效果"的闭环验证:理论说 fSGLD 隐式正则 \(\mathrm{tr}(H)\),实验上 \(\mathrm{tr}(H)\) 真的明显变小。

亮点与洞察

  • "随机化平滑 = 隐式 Hessian-trace 正则" 这个等价被用得很干净:作者没有引入任何辅助变量、Hessian-vector 估计或双梯度,把 SAM/Hessian-penalty 想要的东西全部塞进 SGLD 的一次扰动里。
  • 把超参数耦合做成理论结论而不是工程 trick\(\sigma=\beta^{-(1+\eta)/4}\) 不是经验调出来的,而是由 Wasserstein 界和 Taylor 残差量级反推出来的最优耦合速率;正因如此,作者敢声明"fSGLD 暴露给用户的超参数和 SGLD 一样多"。
  • 可迁移设计:任何基于 Langevin/扩散的优化器(如训练扩散生成模型、Bayesian fine-tuning)都可以套用"扰动点处求梯度 + 温度耦合扰动"这一招来无痛获得平坦偏置;作者也在结论里点名扩散生成是下一步方向。
  • 理论范式的小升级:从"对原目标 \(u\) 的 Wasserstein 收敛"切到"对平坦目标 \(v\) 的 Wasserstein 收敛",第一次给出"采样收敛到平坦极小"的非渐近、全局结果——之前这条路只有局部 PAC-Bayes 界。

局限与展望

  • 作者承认的局限:常数 \(D_1,D_3\) 对维度 \(d\) 与温度 \(\beta\) 是指数依赖(继承自 Eberle 等的耦合论证),这是当前 SGLD 理论的天花板;以及现有分析需要 \(u\) 满足 Assumption 3.2 的全局 Lipschitz,semiconvex 情形留待后续。
  • 自己发现的局限:理论上的 \(\beta\)\(\lambda\)\(k\) 选取(公式 63–65)涉及与 \(d^2\) 同阶的常数,工程上无法直接套用,实际还是按 SGLD 经验调 \(\beta\);实验全部集中在 ResNet/ViT 图像分类,没有验证 NLP、检测、扩散生成等更高维的任务,能否扩展到现代 LLM/扩散模型的训练规模还是开放问题。
  • 改进思路:(i) 把 \(\eta\) 做成 schedule(前期大、后期小)以兼顾探索与精度;(ii) 与 preconditioned/replica-exchange SGLD 结合,缓解高维下的指数常数;(iii) 实证扩展到训练扩散生成模型,验证作者预言的"更平坦 → 更多样/高质量样本"。

相关工作与启发

  • vs SAM/ASAM:SAM 用 min-max 在邻域内取最坏点做梯度,需要双梯度;fSGLD 用高斯期望取邻域平均,只需单梯度,且天然有全局采样属性(Langevin 噪声),不会困在局部尖谷。实验上 fSGLD 在高噪/大类任务反超,且训练时间减半。
  • vs Entropy-SGD / Entropy-MCMC:两者都引入辅助变量近似局部熵 / 平坦后验,内存翻倍且 Entropy-MCMC 的理论只在强凸下成立。fSGLD 没有辅助变量、内存与 SGLD 一致,且在一般非凸+耗散假设下给出非渐近 Wasserstein 界。
  • vs 标准 SGLD(Welling-Teh / Raginsky / Zhang 2023):标准 SGLD 的 Gibbs 测度对几何无感,只能集中到 \(u\) 的全局极小;fSGLD 把目标换成 \(v=u+\tfrac{\sigma^2}{2}\mathrm{tr}(H)\),提供了第一个"采样到平坦极小"的非渐近全局结果,且离散化误差速率 \(O(\lambda^{1/2})\) 与最优 SGLD 分析持平。
  • vs Random Weight Perturbation (RWP, Ahn 2024 等):RWP 通常把扰动尺度当独立超参,缺乏全局收敛保证;fSGLD 可视为"SGLD + 强制耦合的 RWP",并把 RWP 的几何作用首次纳入 Langevin 全局非渐近分析框架。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 第一次把"扰动梯度 + 耦合温度"做成证明可行的平坦偏置 SGLD,理论与工程同时给出干净答案。
  • 实验充分度: ⭐⭐⭐⭐ 贝叶斯分类/不确定性/OOD/有噪标签/ViT 微调全覆盖,且做了 \(\beta\)-\(\sigma\) 解耦消融与 Hessian 谱可视化;缺点是只在视觉分类,没有 NLP 与生成任务。
  • 写作质量: ⭐⭐⭐⭐ 概念递进清晰(动机→randomized smoothing→耦合→非渐近界→实验),公式略密但每步都给了直观说明;对相关工作交代到位。
  • 价值: ⭐⭐⭐⭐⭐ 一阶、单梯度、无额外内存就能拿到 SAM/ASAM 量级或更优的平坦偏置,且训练时间减半,可作为通用 SGD 替换并搭载到任意 Bayesian 流程上,性价比极高。