Unbiased and Second-Order-Free Training for High-Dimensional PDEs¶
会议: ICML 2026
arXiv: 2605.14643
代码: https://github.com/seojaemin22/Un-EM-BSDE (有)
领域: 科学计算 / 神经 PDE 求解器
关键词: BSDE, 高维 PDE, Euler-Maruyama, 无偏估计, 二阶导数自由
一句话总结¶
本文针对 EM-BSDE 训练 loss 的离散化偏置问题,提出 Un-EM-BSDE:把单步误差用两组独立的 Monte Carlo 子样本平均后做"乘积"形成无偏估计,既消除偏置又不需要 Hessian,在 HJB/BSB/AC 等基准 PDE 上达到 Heun-BSDE / FS-PINNs 的精度但训练时间仅 1.79× EM-BSDE(相比 Heun-BSDE 的 42.91× 与 FS-PINNs 的 32.07×)。
研究背景与动机¶
领域现状:高维 PDE 求解器有两大主流——PINNs 把 PDE 残差写进损失函数,但对高频或多尺度解训练不稳;Deep BSDE 利用 PDE 与随机微分方程(SDE)的连接,把问题转化成沿轨迹的概率表示,避开维度灾难。Deep BSDE 又用 Euler-Maruyama(EM)做时间离散,构造 self-consistency loss \(\ell_{\text{EM}}=\mathbb{E}[|\text{err}^{\text{EM}}_n|^2]\)。
现有痛点:Park & Tu (2025) 证明 EM-BSDE loss 在有限步长 \(\Delta t\) 下是离散化-有偏估计——偏置项 \(\frac{1}{2}\text{Tr}[(\sigma^T(\nabla^2 u_\theta)\sigma)^2]\) 直接污染梯度方向。为消偏,他们提出 Heun-BSDE(用 Stratonovich + Heun 积分),但代价是必须显式计算二阶导(Hessian),训练时间是 EM-BSDE 的 42.91 倍。Xu & Zhang (2025) 的 Shotgun 方法只能把偏置降为 \(1/M\),并不彻底消除。
核心矛盾:消偏(无偏)和高效(无二阶导)这两个目标在 BSDE 训练中似乎不可兼得——Heun-BSDE 牺牲效率换无偏,Shotgun/Multi-Shot EM 牺牲无偏换效率,FS-PINNs 用 forward SDE 采样但仍要 Hessian。
本文目标:(i) 完全消除 EM 离散化偏置;(ii) 不需要任何 \(\nabla^2 u_\theta\) 计算;(iii) 训练时间不能比 EM-BSDE 慢太多;(iv) 对 BZ(fully-coupled FBSDE)、PIDE(带跳跃)这种复杂动力学也能 work。
切入角度:利用统计学里的 sample-splitting 经典原理——如果把同一个二阶矩 \(\mathbb{E}[X^2]\) 用两个独立子样本 \(\mathbb{E}[X_1\cdot X_2]\) 替代,由于 \(X_1, X_2\) 独立,\(\mathbb{E}[X_1 X_2]=\mathbb{E}[X_1]\mathbb{E}[X_2]=(\mathbb{E}[X])^2\),偏置项(来自 \(\text{Var}(X)\))自然消失。
核心 idea:用"两组独立 Shot 子样本的乘积"替代"单组样本的平方",从单步误差形成无偏估计 \(\ell^{M_1, M_2}_{\text{UEM}}=\mathbb{E}[\text{Shot}_{M_1}[\text{err}^{\text{EM}}_n]\cdot\text{Shot}_{M_2}[\text{err}^{\text{EM}}_n]]\)。
方法详解¶
整体框架¶
PDE \(\mathcal{L}[u](t,x)=\phi(t,x,u,\nabla u)\) 通过 Itô 公式转化为 FBSDE 系统 \(dX_t=\mu\,dt+\sigma\,dW_t\),\(dY_t=\phi\,dt+Z_t^T\sigma\,dW_t\),其中 \(Y_t=u(t,X_t)\)、\(Z_t=\nabla u(t,X_t)\)。用 EM 在时间网格 \(t_n=n\Delta t\) 上离散,得到单步前向 \(F_n(x)=x+\mu\Delta t+\sigma\Delta W_n\) 和反向 \(B_n(x;u)=u(t_n,x)+\phi_u\Delta t+\nabla u\cdot\sigma\Delta W_n\)。定义单步误差 \(\text{err}^{\text{EM}}_n(x;u)=\frac{u(t_{n+1}, F_n(x))-B_n(x;u)}{\Delta t}\)。Un-EM-BSDE 在每步采 \(M_1+M_2\) 个独立 Brownian 增量 \(\Delta W_{n,i}\),把它们分成两组分别求平均后做乘积,得到无偏的单步损失,然后沿轨迹累加。
关键设计¶
-
Sample-splitting 无偏估计器:
- 功能:把 \(\ell_{\text{EM}}=\mathbb{E}[X^2]\) 的有偏二阶矩替换为 \(\mathbb{E}[X_1 X_2]\) 的无偏交叉矩
- 核心思路:定义 \(\text{Shot}_M[\xi]=\frac{1}{M}\sum_{m=1}^M \xi_m\);用 \(M_1+M_2\) 个 i.i.d. Brownian 增量计算 \(M_1+M_2\) 个独立单步误差,分成两个不重叠的组求 \(\text{Shot}_{M_1}, \text{Shot}_{M_2}\),最终损失 \(\ell^{M_1,M_2}_{\text{UEM}}=\mathbb{E}[\text{Shot}_{M_1}[\text{err}^{\text{EM}}_n]\cdot\text{Shot}_{M_2}[\text{err}^{\text{EM}}_n]]\)。Lemma 4.1 证明 \(\ell^{M_1,M_2}_{\text{UEM}}=([\mathcal{L}[u_\theta]-\phi_{u_\theta}](t_n,x))^2+O(\Delta t^{1/2})\),即恰好等于连续时间 PDE 残差的平方(除去消失项),完全去掉了 EM 偏置中的 \(\text{Tr}[(\sigma^T\nabla^2 u_\theta\sigma)^2]\) 项
- 设计动机:传统 BSDE 把同一个噪声 \(\Delta W_n\) 同时用于前向和反向,导致 \(\text{err}^{\text{EM}}_n\) 的方差被吸收进 \(\mathbb{E}[X^2]\) 形成偏置;用两组独立噪声把方差和均值平方分离,方差贡献被打散到 \(\mathbb{E}[X_1]\mathbb{E}[X_2]\) 之外
-
避免显式二阶导(second-order-free):
- 功能:保持 EM 单步更新结构不变,从而不需要 \(\nabla^2 u_\theta\)
- 核心思路:Heun-BSDE 之所以慢是因为 Itô-to-Stratonovich 转换会引入二阶空间导校正项,必须算 Hessian;Un-EM-BSDE 始终在 Itô 框架内,单步公式 \(B_n\) 只含 \(u, \nabla u\),整个 pipeline 只需要一阶反向梯度(PyTorch / JAX 一行 grad 搞定)
- 设计动机:在 \(d\) 维 PDE 中,Hessian 是 \(d\times d\) 矩阵,AD 计算成本 \(O(d)\) 倍于一阶梯度,在 \(d=100\) 量级的高维问题里直接决定了能否在 GPU 上跑完——这是 Heun-BSDE 42.91× 时间的根源
-
方差控制 + Shotgun 通用 wrapper:
- 功能:(a) 证明 \(M_1=1, M_2=2\) 的 Un-EM 估计器方差不比 EM-BSDE 大;(b) 把同样的 sample-splitting 思路套到任意单步损失上做通用消偏
- 核心思路:Theorem 4.3 证明在 \(\alpha=2/M-1/(2M_1)-1/(2M_2)\geq 4/(3M+\beta M^4)\)、\(\beta=1/(2M^2)-1/(4M_1 M_2)>0\) 的条件下,\(\mathbb{V}[\hat\ell^{M_1,M_2}_{\text{UEM}}]\leq\mathbb{V}[\hat\ell^M_{\text{SG}}]=\mathbb{V}[\hat\ell^M_{\text{SEM}}]\leq\mathbb{V}[\hat\ell_{\text{EM}}]\)。把同样的乘积构造套到 Shotgun loss 上得到 Un-SG,BSB 硬约束上 RL2 降低 2.67×、训练时间仅增 1.78×
- 设计动机:sample-splitting 容易引入额外方差(cross-moment 比 second moment 噪),方差分析是确保该方法实用的关键;同时通用 wrapper 把单点贡献放大成"任意有偏单步损失都能被无偏化"的一类技术
损失函数 / 训练策略¶
实验默认 \(M_1=M_2=5\)。基线对比:Shotgun 用 \(M=50\),Multi-Shot EM 用 \(M=10\),使内部采样 budget 与 \(M_1+M_2=10\) 对齐。损失既支持 soft constraint(终端条件作为额外损失项 \(L_T\))又支持 hard constraint(trial function 形式内置)。算法伪代码(Algorithm 1)展示 batched 实现:对 batch size \(B\)、时间步 \(N\)、shot 数 \(M_1+M_2\),张量 \(X\in\mathbb{R}^{B\times(N+1)\times(M_1+M_2)\times d}\) 一次性存所有候选状态,并行计算前向轨迹与每条 shot 的单步预测 \(\hat Y[b,n+1,i]\),最后按组聚合做乘积。
实验关键数据¶
主实验¶
5 个基准 PDE 上 RL2 误差(×\(10^{-2}\)),bold 标 best,underline 标 second-best:
| PDE / 约束 | EM-BSDE(有偏) | Shotgun(有偏) | Multi-Shot EM | Heun-BSDE(无偏) | FS-PINNs(无偏) | Un-EM-BSDE(本文) |
|---|---|---|---|---|---|---|
| HJB soft | 0.4055 | 1.1409 | 0.1617 | 0.1424 | 0.0867 | 0.1348 |
| BSB soft | 0.3483 | 39.99 | 0.1046 | 0.1030 | 0.0478 | 0.0814 |
| AC soft | 0.0462 | 0.0951 | 0.0206 | 0.0774 | 0.0325 | 0.0147 |
| BSB hard | 0.3456 | 0.1629 | 0.0739 | 0.0201 | 0.0048 | 0.0120 |
| PIDE hard | 0.0374 | 0.4057 | 0.0245 | 0.1874 | 0.0137 | 0.0226 |
训练时间倍数(Table 1):
| 方法 | Unbiased | 2nd-order-free | 训练时间 |
|---|---|---|---|
| EM-BSDE | ✗ | ✓ | 1× |
| Shotgun | ✗ | ✓ | 0.75× |
| Multi-Shot EM-BSDE | ✗ | ✓ | 1.74× |
| Heun-BSDE | ✓ | ✗ | 42.91× |
| FS-PINNs | ✓ | ✗ | 32.07× |
| Un-EM-BSDE(ours) | ✓ | ✓ | 1.79× |
消融实验¶
| 配置 | 效果 |
|---|---|
| Un-EM-BSDE 完整 | 几乎所有 setting 都是 second-best 或 best |
| 把 sample-splitting wrapper 套到 Shotgun(Un-SG) | BSB hard 上 RL2 降 2.67×,时间增 1.78× |
| Hard constraint vs Soft constraint | Hard 在复杂动力学(BZ、PIDE)下显著更稳,soft 受 loss balancing 影响 |
| BZ(fully-coupled FBSDE) soft | Un-EM 在 5.18 量级,Shotgun 飙到 86.53 |
关键发现¶
- 效率不降反而是杀手锏:在 \(d\) 高维场景下,Heun-BSDE 和 FS-PINNs 因为 Hessian 计算可能"根本跑不完",Un-EM-BSDE 的 1.79× 时间是 sweet spot。
- Wrapper 的通用性比方法本身更值钱:把同样的乘积构造套到 Shotgun 上立刻获得 2.67× 精度提升,说明这是个一类的消偏技术(适用于任何"同噪声前向 + 反向"的单步损失)。
- 复杂动力学(BZ、PIDE)下 hard constraint 更友好:soft constraint 的 loss balancing 问题在 fully-coupled / jump 场景中被放大,hard constraint 由于免去权重调优更稳定,这是个非常实用的工程提示。
- 方差不会爆:理论 Theorem 4.3 + 实验都验证 Un-EM 估计器方差不大于 EM-BSDE,sample-splitting 的"经典 concern"在这里不构成实际问题。
亮点与洞察¶
- 统计学经典招式的精准应用:sample-splitting 在统计推断里是老 trick,但把它精准 plug 进 BSDE 单步损失的位置,让消偏与效率同时实现,体现了对问题本质的深刻理解——偏置项就藏在 \(\text{Var}(X)\) 里,独立采样自动隔离它。
- 通用 wrapper 设计:Sec 5.3 把方法抽象成"任何带 \(\tau\)-参数的有偏单步损失都能套同样构造",这种 framework-level 贡献让论文价值远超单一算法。
- Itô vs Stratonovich 的避免:Heun-BSDE 强制走 Stratonovich 是为了得到无偏,但因此引入 Hessian;Un-EM 直接在 Itô 框架内通过随机化拿到同样的无偏性,跳出了 stochastic calculus 选择的两难。
- 理论与实验配合紧密:Lemma 4.1(无偏)+ Theorem 4.2(一致性)+ Theorem 4.3(方差)三件套都有对应实验验证,没有"理论好看但实验不灵"的常见 ML 论文病。
局限与展望¶
- 当前理论假设 \(\mu, \sigma\) 有界、\(u_\theta\in C^{1,2}\),对实际 fully-coupled FBSDE 和 PIDE 这类无界系数 / 跳跃过程,理论保证只是部分覆盖(论文显式承认)。
- 算法需要每步采 \(M_1+M_2\) 个独立 Brownian 增量(默认 10 个),相比 EM-BSDE 的 1 个,batched 实现要多分配 \(10\times\) 的张量内存,在 \(d\) 很大或 batch 很大时可能成为内存瓶颈。
- 实验只跑到 \(d\sim 100\) 量级,对真正大规模(\(d>1000\))的 PDE 求解器还没有完整 ablation。
- 与现代 SOTA 如 forward-backward 双网络方法(separate networks per step)的对比缺失。
- adaptive time-stepping 在复杂动力学下的扩展被列为 future work,目前固定 \(\Delta t\) 在 stiff / multi-scale PDE 上可能 sub-optimal。
相关工作与启发¶
- vs EM-BSDE (Raissi 2024):base method,Un-EM 用随机化 product 消除其偏置,时间仅多 79%。
- vs Heun-BSDE (Park & Tu 2025):同样无偏,但 Heun 需要 Hessian、时间 42.91×;Un-EM 完全免 Hessian。
- vs Shotgun (Xu & Zhang 2025):Shotgun 把偏置降 \(1/M\) 但不消除;Un-EM 用同样 wrapper 套到 Shotgun 上立刻无偏化。
- vs FS-PINNs (Park & Tu 2025):FS-PINNs 直接最小化沿 SDE 轨迹采样的 PDE 残差平方,无偏但要 Hessian;Un-EM 通过 BSDE-style 单步损失达到相似精度且免 Hessian。
- vs Hu et al. (2025) bias-variance trade-off PINNs:思路同源(独立 sample 形成乘积消偏),本文是这一思想在 BSDE 框架内的特化与扩展。
- 启发:(a) 把 sample-splitting 推广到其他随机损失(如对比学习、scoring rules)也许同样能消除二阶项偏置;(b) "把噪声拆成两组独立样本"的 trick 也可用于消除 RL 中 value estimation 的 bootstrap 偏置。
评分¶
- 新颖性: ⭐⭐⭐⭐ Sample-splitting 在 BSDE 损失内的应用是清晰且非平凡的贡献,但 sample-splitting 本身是经典思想
- 实验充分度: ⭐⭐⭐⭐ 5 个标准 PDE + 2 个复杂扩展(BZ、PIDE)+ wrapper 推广实验,覆盖很全
- 写作质量: ⭐⭐⭐⭐⭐ Table 1 的"unbiased + 2nd-order-free + time"三栏对照表立刻让贡献一目了然,Lemma/Theorem 编号清晰
- 价值: ⭐⭐⭐⭐⭐ Heun-BSDE 42.91× 慢导致它实用价值很有限,Un-EM 把无偏 BSDE 带回 EM 一级的训练成本,是直接可用的进步