HOTA: Hamiltonian Framework for Optimal Transport Advection¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=Og1klGbvlM
代码: https://github.com/nazarblch/HOTA
领域: 优化 / 最优传输 / 动力学生成建模
关键词: Generalized Schrödinger Bridge, Hamilton-Jacobi-Bellman, 最优传输, Kantorovich 势, 非光滑势函数
一句话总结¶
HOTA 把广义 Schrödinger Bridge 的对偶问题重写成「Kantorovich 势 + HJB 值函数」的联合优化,用 RL 式的 replay buffer + 目标网络 + 自适应梯度平衡稳定训练,做到无需建模中间密度、能处理非光滑势函数,同时严格保证终端分布匹配。
研究背景与动机¶
领域现状:最优传输(OT)已成为引导概率流的天然框架。静态 Monge-Kantorovich OT 只关心边界耦合,产生的是直线路径,完全忽略数据流形的几何结构。Benamou-Brenier 的动力学 OT 把它改写成概率路径上的时间变分问题,能在带曲率、障碍或势场的非平凡几何上操作轨迹,这与随机最优控制(SOC)里的广义 Schrödinger Bridge(GSB)问题紧密相连。
现有痛点:GSB 通常引入一个状态势函数 \(U(x)\) 来编码空间几何和物理约束(如分子模拟里的硬核排斥、steric clash),这些约束天然是不连续/非光滑的。主流解法走 HJB 对偶路线(DeepGSB、NLSB),但有两个硬伤:(1) 优化动态不稳定,高维下梯度方差大、样本效率差;(2) 缺乏严格的终端分布匹配准则,导致耦合不精确。另一条线 GSBM 走交替优化(先固定边际学漂移、再更新边际),但要求势函数处处可微,且交替迭代极其昂贵(收敛慢)。
核心矛盾:既要保留 HJB 的理论严谨性和处理非光滑势的能力,又要解决它训练不稳、终端不精确的老毛病——两者此前难以兼得。
本文目标:在两测度之间、几何由势函数定义的 GSB 问题上,提出一个新的 HJB 框架,既显式求解 GSB,又修复学习稳定性问题,并通过有理论支撑的目标函数确保终端分布精确匹配。
核心 idea:【对偶绑定】 把 GSB 的对偶形式重新组织成 Kantorovich 势与 HJB 值函数的绑定——让 HJB 约束变成一个内建正则项,从而得到一个对梯度优化友好、且天然不依赖密度估计的稳定目标。
方法详解¶
整体框架¶
HOTA 从 GSB 原始问题(最小化轨迹积分代价 \(L=\|v_t\|^2+U(x_t)\) 并匹配终端分布)出发,用 Lagrange 乘子(即 Kantorovich 势 \(g\))松弛终端约束,得到鞍点问题;再借动态规划定义值函数 \(s(t,x)\),证明其满足 HJB 偏微分方程,并把内层对控制的最小化解析地解出 \(v_t^*=-\nabla_x s(t,x_t)\)。最终对偶问题(定理 1)变成:势匹配项当判别器保证终端分布匹配,HJB 残差项保证轨迹最优。训练上把它当成「最优控制 + RL」的混合:用 Euler-Maruyama 模拟轨迹、replay buffer 存历史、EMA 目标网络稳定回归、自适应梯度平衡两个损失。
flowchart TD
A[GSB 原始问题<br/>min 轨迹代价 + 终端匹配 β] --> B[Lagrange 松弛终端约束<br/>引入 Kantorovich 势 g]
B --> C[动态规划: 值函数 s_t,x<br/>满足 HJB PDE]
C --> D[对偶问题 定理1<br/>v* = -∇s]
D --> E1[势匹配损失 L_pot<br/>判别器→终端分布匹配]
D --> E2[HJB 残差损失 L_hjb<br/>→轨迹最优性]
E1 --> F[自适应梯度平衡 + EMA 目标网络 + replay buffer]
E2 --> F
F --> G[Euler-Maruyama 生成 OT 轨迹<br/>无需建模中间密度]
关键设计¶
1. 对偶绑定:把 HJB 约束变成内建正则项。 这是全文的理论支点。通过松弛终端约束引入 Kantorovich 势 \(g\),再用值函数 \(s(t,x)\) 的边界条件 \(s(1,x)=-g(x)\) 把势函数和值函数「焊死」成同一个网络。定理 1 给出的对偶目标 $\(\max_{s(1,\cdot)}\ \mathbb{E}_{x_0\sim\alpha}\Big[\int_0^1 L(t,x_t,-\nabla_x s)\,dt + s(1,x_1)\Big] - \mathbb{E}_{y\sim\beta}\big[s(1,y)\big]\)$ 里,第一项扮演判别器、强制 \(T_\#\alpha=\beta\),第二项(即 HJB 约束)负责轨迹最优。关键的实践技巧是:在 HJB 约束被满足的前提下,积分代价项已经被「记账」并最小化,因此势匹配损失里可以省掉显式的积分项,只需简单的终端势差 \(L_{\text{pot}}=\frac{1}{n}\sum_k s_\theta(1,x_T^k)-\frac{1}{n}\sum_k s_\theta(1,y^k)\)。由此得到的目标完全无需建模 \(t\in(0,1)\) 的中间密度,这正是它绕过密度估计、能吃非光滑势的根本原因。
2. HJB 残差损失的双向对称写法 + 角加速度正则。 值函数要逼近 HJB PDE 的黏性解,损失直接惩罚 PDE 残差: $\(\frac{\partial s_\theta}{\partial t}-\frac{1}{2}\|\nabla_x s\|^2+U(x)+\frac{\sigma^2}{2}\mathrm{tr}\{\nabla^2 s\}+\lambda_a\|a\|\)$ 这里特意写成主网络 \(s_\theta\) 与目标网络 \(s\) 交叉配对的两项之和(一项 \(\nabla_x s\) 用目标、一项用主网络),让优化更像监督回归而非自举追逐自己。其中 \(U(x)\) 直接进残差、不需要对 \(U\) 求导,所以势函数即使几乎处处不可微也能用;\(\mathrm{tr}\{\nabla^2 s\}\) 这一二阶项靠 JAX 自动微分算,开销 <5%。附加的角加速度项 \(a=\frac{d}{dt}\frac{\nabla s_\theta}{\|\nabla s_\theta\|}\) 以系数 \(\lambda_a\) 鼓励轨迹拉直(可选)。
3. RL 式三件套稳定高维训练:replay buffer + EMA 目标网络 + 数据 replay。 因为不建模中间密度,训练点需要落在「流(轨迹)集中的区域」才有效。HOTA 早期用 \(\alpha,\beta\) 之间的线性插值粗略估计流区域,之后改从 replay buffer \(\mathcal{B}\) 采样历史轨迹点;每轮把当前策略 \(v_t=-\nabla s\) 生成的一条轨迹存进 buffer。目标网络 \(s\) 以 EMA 方式更新参数 \(\theta\leftarrow\gamma\theta+(1-\gamma)\theta\),复刻 DQN 把自举目标固定住的思路,把不稳定的 PDE 拟合变成稳定的回归。这套设计正是消融里证明对 feasibility 最关键的部分。
4. 自适应梯度平衡:让两个损失的梯度尺度自动对齐。 \(L_{\text{pot}}\)(势匹配)和 \(L_{\text{hjb}}\)(HJB 残差)量纲、尺度差异巨大,直接相加会被某一项主导。HOTA 用两者梯度范数之比的 EMA 作为缩放因子 $\(\nabla_\theta L_{\text{pot}}+\lambda_{\text{hjb}}\,\mathrm{EMA}\!\Big(\frac{\|\nabla_\theta L_{\text{pot}}\|}{\|\nabla_\theta L_{\text{hjb}}\|}\Big)\nabla_\theta L_{\text{hjb}}\)$ 把 HJB 梯度动态缩放到与势匹配梯度同量级再求和,保证既满足最优性条件、又满足边界约束的稳定收敛。
实验关键数据¶
主实验:低维带几何约束基准(feasibility / optimality,越低越好)¶
6 个数据集,前三个(Stunnel/Vneck/GMM)光滑、后三个(BabyMaze/Slit/Box)几乎不可微:
| 指标 | 方法 | Stunnel | Vneck | GMM | BabyMaze | Slit | Box |
|---|---|---|---|---|---|---|---|
| Feasibility \(W_2\) | NLSB | 30.54 | 0.02 | 67.76 | >1 | 0.013 | 0.024 |
| GSBM | 0.03 | 0.01 | 4.13 | 0.01 | 0.01 | 0.02 | |
| HOTA | 0.006 | 0.002 | 0.19 | 0.004 | 0.0004 | 0.002 | |
| Optimality(积分代价) | GSBM | 460.88 | 155.53 | 229.12 | 6.5 | 4.9 | 3.8 |
| HOTA | 383.25 | 115.09 | 80.44 | 4.87 | 3.06 | 2.84 |
HOTA 在 feasibility 与 optimality 上全面领先;GMM 上 optimality 从 229 降到 80,体现对邻近点轨迹分离的优势。运行效率上比 GSBM 快 50–100×。
高维可扩展性(Sphere 数据集,Normalized \(W_2\) / Optimality)¶
| 维度 | HOTA \(W_2\) | GSBM \(W_2\) | HOTA Opt | GSBM Opt |
|---|---|---|---|---|
| 10 | 0.001 | 0.99 | 3.37 | 20.9 |
| 1000 | 0.051 | 0.21 | 3.17 | 22.1 |
维度升到 1000 仍稳定;另在 1000 维 opinion depolarization 任务上有效。
图像翻译(FID,越低越好)¶
| 任务 | HOTA | 最强 baseline |
|---|---|---|
| CelebA Male→Female (64×64) | 6.28 | EUOT 8.44 |
| CelebA Female→Anime (64×64) | 11.67 | ENOT 13.12 |
消融实验(Stunnel/Vneck/GMM,feasibility)¶
| 变体 | Stunnel | Vneck | GMM |
|---|---|---|---|
| HOTA 完整 | 0.006 | 0.002 | 0.19 |
| w/o buffer | 0.076 | 16.47 | 1.248 |
| w/o 梯度平衡 | 3.60 | 0.026 | 2.64 |
| w/o EMA 目标网络 | 0.018 | 0.004 | 0.65 |
关键发现¶
- replay buffer 对 feasibility 最关键,去掉后 Vneck 直接发散(16.47)。
- 梯度平衡对 Stunnel feasibility 影响最大(去掉后从 0.006 飙到 3.60)。
- 三件套缺一不可,但作用在不同数据集上各有侧重。
亮点与洞察¶
- 理论闭环漂亮:把 Kantorovich 势的边界条件直接当成 HJB 值函数的终端条件,一个网络同时身兼「判别器」和「值函数」,对偶目标自然分解为 feasibility(判别器)+ optimality(HJB)两块,结构清晰。
- 省掉积分项的洞察:在 HJB 约束满足下积分代价已被隐式最小化,因此势匹配损失只需终端势差——这是把无需建模中间密度落地为简洁损失的关键一步。
- 把 RL 工程经验迁移到 OT 求解:replay buffer + 目标网络 EMA 这套 DQN 稳定技巧,被巧妙用来稳定 HJB 残差拟合,是跨领域借力的好范例。
- 非光滑势是真实需求:\(U(x)\) 不求导直接进残差,使分子模拟、黑盒 LLM 势、LiDAR 点云扫描面等不可微约束都能纳入。
局限与展望¶
- 网络设计敏感:对时间的 Fourier 特征编码等设计选择较敏感,值函数要同时支撑最优控制估计和当 Kantorovich 势,需要能聚合丰富时空信息的架构。
- 理论假设较强:当前依赖 \(s\in C^{1,2}\) 等正则性条件,作者计划放宽到更弱的正则性假设。
- 架构有提升空间:未来引入更结构化的神经架构以进一步改善可扩展性与稳定性。
- 仅在单卡 A100 80GB 上验证,更大规模工业场景的表现待考。
相关工作与启发¶
- 动力学 OT 谱系:Benamou-Brenier 动力学 OT → Schrödinger Bridge → GSB(加入状态势 \(U\))。HOTA 属于 HJB 对偶解法这一支,对手是 DeepGSB、NLSB、GSBM。
- 与 GSBM 的关键区别:GSBM 要求势处处可微且交替优化昂贵;HOTA 单目标联合优化、吃非光滑势、快 50–100×。
- 与 SOC 的联系:Adjoint Matching、SOCM 等 SOC 方法存在方差估计不稳定问题,HOTA 用 Kantorovich 势和保证 feasibility。
- 启发:对于「约束以势函数形式给出且可能不可微」的生成/控制任务,把约束塞进 PDE 残差而非要求可微分,是个值得复用的思路;RL 的稳定化技巧(target network、replay)在连续控制/PDE 求解里同样有效。
评分¶
- 新颖性: ⭐⭐⭐⭐ — 对偶绑定(Kantorovich 势=HJB 终端条件)让 HJB 约束变内建正则项,加上把 RL 稳定技巧迁移到 GSB 求解,组合新颖且自洽。
- 实验充分度: ⭐⭐⭐⭐ — 覆盖 6 个低维带几何约束基准、1000 维 Sphere/opinion、两个图像翻译任务,消融拆清三件套贡献;但 baseline 部分结果直接引自原论文、个别任务缺对照。
- 写作质量: ⭐⭐⭐⭐ — 从原始问题到对偶定理推导清晰,feasibility/optimality 两个指标定义明确,算法伪代码完整。
- 价值: ⭐⭐⭐⭐ — 解决非光滑势 + 终端精确匹配 + 高维稳定三个实际痛点,且有 50–100× 加速,对计算生物、机器人、生成建模等 OT 应用有直接价值。