跳转至

Branched Schrödinger Bridge Matching

会议: ICLR 2026
arXiv: 2506.09007
代码: HuggingFace
领域: 图像生成 / 生成模型理论
关键词: Schrödinger Bridge, 分支轨迹, 流匹配, 细胞命运分化, 最优传输

一句话总结

提出 BranchSBM 框架,通过参数化多个时间依赖的速度场和增长过程,将 Schrödinger Bridge Matching 扩展到分支场景,能够建模从单一初始分布到多个目标分布的分叉动态轨迹,在 LiDAR 表面导航和单细胞扰动建模等任务上显著优于单分支方法。

研究背景与动机

预测初始分布和目标分布之间的中间轨迹是生成模型中的核心问题。现有方法如 Flow Matching 和 Schrödinger Bridge Matching(SBM)能够有效学习两个分布之间的映射,但它们模型化的是单一随机路径,本质上只能处理单模态的转换。

核心矛盾:许多真实系统中存在分支动态——即从一个共同起源状态分叉演化到多个不同的终态分布。例如: - 细胞命运分化:均质的祖细胞群体在发育过程中分化为不同的细胞类型 - 药物扰动响应:同一细胞系在药物处理后可能产生多种不同的表型结果 - 路径规划:从一个起点到不同目的地的多路径导航

现有的单路径 SBM 方法无法建模这种分支行为。当目标分布是多模态的时,单分支方法要么发生模式坍缩(只到达最低能量的模态),要么生成的轨迹不能准确地到达各个终态。

本文切入角度:将 SBM 推广到分支场景——学习一组分支的 Schrödinger 桥,每个分支有独立的漂移场和增长率,共同描述群体级别从单一起点到多个终点的分叉动态。

方法详解

整体框架

BranchSBM 要学的是「一个起点分叉到多个终点」的轨迹:给定初始分布 \(\pi_0\)\(K+1\) 个目标分布 \(\{\pi_{1,k}\}_{k=0}^{K}\),单分支 SBM 只能拟合一条路径、面对多模态终态会坍缩。它的核心做法是为每个分支 \(k\) 配一套独立的速度网络 \(u_{t,k}^\theta\)(决定这条分支往哪走)和增长网络 \(g_{t,k}^\phi\)(决定有多少质量流到这条分支),再用分阶段训练把二者解耦学出,最终得到一族从同一起点出发、质量随时间在分支间重分配、分别流向各目标分布的分支轨迹。整个 pipeline 把「分支该往哪走」和「质量怎么分」拆成两套网络、四个阶段逐级学出,如下图所示。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    IN["输入:单一起点 π0<br/>+ K+1 个目标分布"]
    S1["Stage 1 · 神经插值<br/>学能量最优条件路径 φ"]
    S2["Stage 2 · 条件流匹配<br/>各分支速度场 u_t,k(决定往哪走)"]
    S3["Stage 3 · 训练增长网络<br/>g_t,k 把质量在分支间重分配(Σ权重=1)"]
    S4["Stage 4 · 联合微调<br/>速度场 + 增长网络 + 重建损失"]
    OUT["输出:K+1 条分支轨迹<br/>质量分别流向各目标分布"]
    IN --> S1 --> S2 --> S3 --> S4 --> OUT

关键设计

1. 非平衡条件随机最优控制(Unbalanced CondSOC):让质量能在分支间流动

标准 SBM 假设质量守恒——初始的每一份质量都一一对应地走到终态,这正是它处理不了分支的根本原因:一个统一的群体没法在守恒约束下「分裂」成多个目标。BranchSBM 在标准 GSB 问题上引入一个时间依赖的权重 \(w_t(X_t)\),它由增长率 \(g_t(X_t)\) 驱动,使得主分支的质量可以随时间转移到次要分支,从而把「分叉」这件事变成可建模的连续过程。Proposition 1 进一步证明,这个 Unbalanced GSB 问题可以通过条件化端点对来高效求解,把一个群体级的优化问题拆成可采样的条件子问题。

2. 分支广义 Schrödinger 桥问题:把多分支写成多个非平衡子问题之和

有了 Unbalanced CondSOC 这块积木,BranchSBM 把整个分支 GSB 问题形式化为多个 Unbalanced GSB 问题之和:主分支(\(k=0\))初始权重为 1,\(K\) 个次要分支初始权重为 0,整个过程满足质量守恒约束 \(\sum_{k=0}^{K} w_{t,k} = 1\) 对所有时间 \(t\) 成立——即任意时刻所有分支的质量加起来恒为 1,质量只在分支间挪动而不凭空增减。Proposition 2 证明了这个分支 CondSOC 问题可以分解为独立的分支子问题,于是每个分支可以分别训练,避免了直接求解高维耦合问题。

3. 四阶段训练策略:把漂移学习和增长学习解耦,逐步逼近最优

由于直接联合优化漂移场和增长率很困难,BranchSBM 把训练拆成四步逐级推进。Stage 1 是神经插值优化:训练插值网络 \(\varphi_{t,\eta}(\mathbf{x}_0, \mathbf{x}_{1,k})\),在状态代价 \(V_t(X_t)\) 下学能量最优的条件路径,用轨迹损失 \(\mathcal{L}_{\text{traj}}\) 最小化路径的动能和势能。Stage 2 是条件流匹配:训练每个分支的漂移网络 \(u_{t,k}^\theta\) 去匹配 Stage 1 学到的条件速度场,用标准 CFM 损失 \(\mathcal{L}_{\text{flow}}\)。Stage 3 固定漂移网络、单独训练增长网络 \(g_{t,k}^\phi\),优化一个综合损失:分支能量损失 \(\mathcal{L}_{\text{energy}}\) 优化分支间的能量分配,权重匹配损失 \(\mathcal{L}_{\text{match}}\) 确保终态权重匹配目标分布的比例,质量守恒损失 \(\mathcal{L}_{\text{mass}}\) 强制所有分支权重之和守恒。Stage 4 再解冻所有参数联合微调漂移和增长网络,并加入重建损失 \(\mathcal{L}_{\text{recons}}\) 收尾。先单独把漂移学准、再单独把质量分配学准、最后联合打磨,避开了一上来就联合优化的不稳定。

4. 理论保证:从最优性到质量单向流动的完整证明链

BranchSBM 给上述设计配了一组定理来兜底。Proposition 3 证明 Stage 1+2 训练得到的就是 GSB 问题的最优漂移,说明前两阶段的解耦不损失最优性;Proposition 4 通过变分法的直接方法证明了最优增长函数的存在性;Lemma 2 则证明次要分支的最优增长率是非减的——也就是质量只会从主分支流出、不会回流,这与「群体从单一起点逐步分化、不可逆」的直觉一致。

损失函数 / 训练策略

  • Stage 1 使用 Adam 优化器,lr=1e-4
  • Stage 2-4 使用 AdamW 优化器,lr=1e-3,weight decay=1e-5
  • 每个阶段最多训练 100 epochs
  • 模型架构:3层MLP + SELU激活函数
  • 增长网络的次要分支输出额外通过 softplus 确保非负性
  • 状态代价使用数据依赖的 LAND 度量(低维)或 RBF 度量(高维)

实验关键数据

主实验

数据集 指标 BranchSBM 单分支SBM 说明
LiDAR 表面导航 \(\mathcal{W}_1\) / \(\mathcal{W}_2\) 显著更低 分支路径绕山两侧
小鼠造血分化(t₁) \(\mathcal{W}_1\) / \(\mathcal{W}_2\) 显著更低 中间时间点准确预测
小鼠造血分化(t₂) \(\mathcal{W}_1\) / \(\mathcal{W}_2\) 显著更低 终态两个细胞命运准确还原
Clonidine扰动(50PC) MMD / \(\mathcal{W}_1\) / \(\mathcal{W}_2\) 最优 只达cluster0 单分支无法到达所有终态
Clonidine扰动(100PC) MMD 优于50PC单分支 - 高维扩展性验证
Clonidine扰动(150PC) MMD 优于50PC单分支 - 维度可扩展
Trametinib扰动(3分支) MMD / \(\mathcal{W}_1\) / \(\mathcal{W}_2\) 最优 仅达cluster0 验证3分支能力

消融实验

配置 关键指标 说明
Stage 3 only (固定漂移) \(\mathcal{L}_{\text{energy}}\) 较高 增长网络独立训练
Stage 3+4 (联合训练) \(\mathcal{L}_{\text{energy}}\) 更低 联合优化进一步降低能量
\(\mathcal{L}_{\text{match}}\) → 0 终态权重准确匹配
\(\mathcal{L}_{\text{mass}}\) → 0 质量守恒得到满足

关键发现

  • 单分支SBM发生模式坍缩:面对多模态目标时,单分支方法只能到达能量最低的模态,完全忽略其他终态
  • 分支时间可自动学习:在 LiDAR 实验中,模型自动在山脊边缘发起分支——表明框架能从数据中学习最优分支时刻
  • 高维可扩展:在单细胞扰动实验中,从50到150个主成分维度,BranchSBM 都能有效工作
  • 质量转移动态合理:权重曲线显示质量从主分支平滑地转移到次要分支,符合生物学直觉
  • 三分支同样有效:Trametinib 实验验证了框架可扩展到两个以上的分支

亮点与洞察

  • 问题定义新颖且重要:首次形式化定义并求解了分支 Schrödinger 桥问题,填补了生成模型理论的空白
  • 理论扎实:提出了完整的理论框架(Propositions 1-4),包括存在性、唯一性和构造性证明
  • 四阶段训练设计精巧:通过分阶段解耦漂移学习和增长学习,避免了联合优化的困难
  • 应用场景明确:细胞命运分化和扰动响应是计算生物学的核心问题,本文提供了原理性的解决方案
  • 与 Flow Matching 和 OT 的深刻联系:当增长率为零时退化为标准 GSBM,理论上统一了多种方法

局限与展望

  • 分支数需要预先指定:需要先通过聚类确定 \(K\),无法自动发现分支结构
  • 配对需要 OT:端点配对依赖最优传输计划,对大规模数据的计算开销可能较大
  • 仅在低到中维(2-150D)验证:全基因组级别(数万维)的可扩展性未验证
  • MLP 架构简单:更复杂的架构(如 Transformer、GNN)可能带来性能提升
  • 中间时间点数据利用有限:目前主要使用端点数据训练,如果有中间快照数据应该能进一步提升
  • 生物学验证不够深入:没有与其他计算生物学方法(如 CellOT、PRESCIENT)进行全面比较

相关工作与启发

  • Schrödinger Bridge:Schrödinger (1931) 的经典问题,近年来在生成模型中重新焕发生机(De Bortoli et al., 2021; Shi et al., 2023)
  • Flow Matching:Lipman et al. (2023) 的条件流匹配为 BranchSBM 的 Stage 2 提供了理论基础
  • Generalized SBM:Liu et al. (2023) 引入了状态代价,BranchSBM 在此基础上增加了分支机制
  • 非平衡最优传输:Chizat et al. (2018)、Pariset et al. (2023) 研究了质量不守恒的传输问题
  • 单细胞轨迹推断:Schiebinger et al. (2019)、Bunne et al. (2023) 使用 OT 方法建模细胞状态转换
  • 启发:是否可以引入注意力机制让模型自动学习分支结构?是否可以处理分支合并(而不仅仅是分叉)的场景?

评分

  • 新颖性: ⭐⭐⭐⭐⭐
  • 实验充分度: ⭐⭐⭐⭐
  • 写作质量: ⭐⭐⭐⭐⭐
  • 价值: ⭐⭐⭐⭐