跳转至

Differentiable Lifting for Topological Neural Networks

会议: ICLR2026
OpenReview: https://openreview.net/forum?id=eC89CbINIw
代码: https://github.com/JorgeLuizFranco/difflifting
领域: 图学习 / 拓扑深度学习
关键词: 拓扑神经网络, 可微提升, 高阶结构, 直通估计, cell complex

一句话总结

提出 ∂lift(DiffLift)——一个端到端可学习的图"提升"(lifting)框架,用 GNN 节点嵌入参数化"候选高阶 cell 的分布",再用伯努利采样 + 直通估计器决定哪些 cell 进入拓扑结构,从而把原本靠先验启发式确定的 hypergraph / simplicial / cell complex 结构变成由下游任务监督学出来,在 12 个数据集、4 种 TNN 上较静态提升最高提升约 45%。

研究背景与动机

领域现状:拓扑神经网络(Topological Neural Networks, TNN)被视为关系学习的新前沿。它把图上只能建模"成对边"的消息传递,推广到 hypergraph、simplicial complex、cell complex 等高阶结构上,让 cycle、clique 这类多节点高阶关系也能参与消息传递,从而增强表达能力。但 TNN 不能直接吃图——必须先把输入图 \(G\) 通过一个提升(lifting)操作 \(\mathrm{lift}: \mathcal{G} \to \mathcal{T}\) 转换成目标拓扑域 \(\mathcal{T}\),TNN 才能在上面跑。

现有痛点:绝大多数提升方法是无监督、与任务无关的启发式规则。比如 cycle lifting 把图里 cycle basis 的每个基本环组成一个 2-cell;clique lifting 把团组成 simplex;k-hop / k-NN / kernel lifting 各有各的固定造结构方式。问题是——到底该提升到哪个域、用哪种提升规则,对下游性能影响极大且高度依赖数据。论文 Figure 1 给出反例:同一个 hypergraph 域下,k-hop 和 k-NN 在 Cora 和 Wisconsin 上的优劣完全相反;提升到 simplicial 还是 hypergraph,在不同数据集上也会出现相反结论。

核心矛盾:提升这一步对 TNN 性能至关重要,但它却被排除在端到端优化之外——结构是"拍脑袋"先验定好的,没有任何任务信号回流去纠正它,于是很容易得到次优的拓扑结构。此前唯一探索过可微提升的工作(DCM,Battiloro et al. 2024)只局限于 cell complex 单一域。

本文目标:设计一个通用的可微提升框架,要同时满足三点——(1) 跨域(hypergraph / simplicial / cell complex 都能用);(2) 端到端可学(梯度能穿过"选不选某个 cell"这个离散决策);(3) 即插即用(能接到任意现成 TNN 后面)。

切入角度:作者把"提升"重新理解为一个概率采样问题——不再用固定规则一锤定音地造结构,而是用节点嵌入去参数化一族候选高阶 cell 上的分布,让模型自己学着"采纳"或"拒绝"每个候选 cell。

核心 idea:用 GNN 嵌入参数化候选 cell 的伯努利接受概率,配合直通估计器让离散采样可导,并按维度从低到高分层迭代地生成结构——把"造拓扑结构"这件事变成可被下游任务监督的可学习模块。

方法详解

整体框架

∂lift 解决的是"如何端到端学出图的高阶提升"。给定一张带属性的输入图 \(G=(V,E,x)\)、目标域 \(\mathcal{T}\) 和最大维度 \(D_{\max}\),整体管线是:先用一个任意 GNN 算出每个节点 \(v\) 的嵌入 \(z_v\);再用这些嵌入按目标域的约束生成候选高阶 cell(每个候选 cell \(C\) 的嵌入由其成员节点嵌入做置换不变聚合得到 \(z_C=\bigoplus_{v\in C} z_v\));然后用一个神经网络 \(\phi\) 给每个候选 cell 算一个接受概率 \(\phi(z_C)\),并从伯努利分布 \(\mathrm{Ber}(\phi(z_C))\) 采样决定是否纳入;最后从被接受的 cell 拼出关系对象,丢给一个现成 TNN 做图级/节点级预测。整套结构(GNN + 采样器 + TNN)端到端联合训练。

对于 cell complex 这类分层域,cell 不是一次性造完的,而是按维度由低到高迭代生成:先学 1 维 cell(边),再用当前复形里的环作为候选去学 2 维 cell,第 \(i\) 维的采样结果会反过来约束第 \(i+1\) 维的候选。对于 hypergraph,因为约定超边维度为 1,迭代一轮即停。整个流程里只有"生成候选 cell"(Step 2)和"接受/拒绝"(Step 3)这两步是依赖具体目标域的,其余骨架对所有域通用。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["输入带属性图 G"] --> B["GNN 节点嵌入<br/>每个节点得 z_v"]
    B --> C["候选 cell 生成<br/>kNN + 自适应尺寸 k_v"]
    C --> D["接受/拒绝采样<br/>Bernoulli + 直通估计"]
    D -->|"D < Dmax:用当前复形的环作下一维候选"| E["分层迭代提升<br/>维度 D 增 1"]
    E --> C
    D -->|"D = Dmax"| F["提升后的拓扑结构"]
    F --> G["现成 TNN<br/>图级/节点级预测"]

关键设计

1. GNN 节点嵌入:把"造结构"建立在可学的语义表示上

提升要不要监督,关键在于"判断两个/多个节点该不该结成高阶 cell"的依据是什么。静态提升用的是图的原始连通性或人工特征(n-hop、cycle),这些信号固定、与任务无关。∂lift 的第一步是用一个任意 GNN(可以是 GIN,也可以是带位置编码的 Graph Transformer GPS)算出每个节点的潜在表示 \(z_v\),后续所有"该不该纳入某个 cell"的决策都建立在这组嵌入之上。这一步的意义在于:嵌入是端到端学出来的,下游任务的梯度会一路回传调整它,从而让"结构生成的依据"本身也跟着任务走。实验里换 GNN backbone 的消融(Table 2)证实了这点——表达力更强的 GPS(受益于位置编码)普遍优于 GIN,在 cell complex + ZINC 上把 MAE 从 0.46 压到 0.17,说明 cell 的质量直接受嵌入质量牵引。

2. 候选 cell 生成:用 kNN + 自适应尺寸框定该考察哪些高阶组合

潜在的高阶 cell 组合是指数级的(\(2^V\)),不可能枚举。∂lift 在嵌入空间里用 k 近邻把候选范围收窄到"语义上相近的节点群"。以 hypergraph 为例,对每个节点 \(v\),取其在嵌入空间里的 \(k_v\) 个最近邻组成候选超边 \(C(v)=\{S\subset V: |S|=k_v,\ w\notin S\Rightarrow \mathrm{dist}(z_w,z_v)\ge \max_{u\in S}\mathrm{dist}(z_u,z_v)\}\)。这里的关键巧思是 cell 的尺寸 \(k_v\) 也是学出来的、且自适应:定义一个 \((k_{\max}-k_{\min}+1)\) 维概率向量 \(\pi_v \propto \exp\circ\mathrm{MLP}(z_v)\),再采样 \(k_v\sim\mathrm{Categorical}(\pi_v)\)。这样不同节点可以形成不同大小的高阶 cell,而不是被一个全局固定的 \(k\) 框死。对 cell complex 的 1 维情形,同样用 kNN 生成候选边;到 \(D\ge 2\) 维时,候选不再是 kNN,而是当前复形 \(K_{D-1}\)\((D-1)\) 维环空间 \(Z_{D-1}(K_{D-1})=\ker(\partial_{D-1})\) 的一组基(即 cycle basis,可用 NetworkX 的算法找基本环),把环作为候选 2-cell。

3. 接受/拒绝采样:用伯努利 + 直通估计器让"选不选 cell"可导

框定候选后,真正决定结构的是"每个候选要不要"。∂lift 用一个置换不变的神经网络 \(\Psi\)(接收 cell 内节点嵌入的 multiset,hypergraph 用 MLP、高维 cell 用 DeepSet)输出接受概率,然后采一个伯努利变量决定纳入与否,例如超边 \(b_v\sim\mathrm{Ber}(\Psi(\{\!\{z_u:u\in C(v)\}\!\}))\),最终结构为 \(V\cup E\cup\{C\in\mathcal{C}: y_C=1,\ y_C\sim\mathrm{Ber}(\phi(z_C))\}\)。难点在于伯努利采样是离散的、不可导,梯度传不回 \(\Psi\) 和 GNN。论文用直通估计器(straight-through estimator, Bengio et al. 2013)绕过:前向用采样的 0/1 硬决策,反向把梯度当作恒等直接穿过采样节点,于是整条链路端到端可训。置换不变是必须的——cell 是无序集合,接受概率不能依赖节点排列顺序。论文还给了一个确定性变体:直接用阈值 \(b_C=\mathbf{1}[\Psi(\cdot)>\gamma]\)(如 \(\gamma=0.5\))代替采样。被接受 cell 的特征则用缩放求和投影 \(x_{C(v)}=\frac{1}{k_v}\sum_{u\in C(v)} x_u\) 得到。

4. 分层迭代提升:按维度由低到高地生成,让高阶 cell 建立在已采纳的低阶结构之上

拓扑对象天然是分层的(simplicial / cell complex 里高维 cell 的边界必须是已存在的低维 cell),不能一股脑独立采样各维度,否则会破坏域的约束。∂lift 用一个迭代过程从 \(D=1\) 开始逐维往上走:第 \(D\) 维采样完得到复形 \(K_D=K_{D-1}\cup\{C\in\mathcal{C}:b_C=1\}\),再把它作为第 \(D+1\) 维候选生成的基础(高维候选直接来自当前复形的环空间),直到 \(D=D_{\max}\) 停止。这保证了生成的结构始终满足目标域的层级/边界约束,也让低维决策能信息性地引导高维决策。实验中受限于现有 cell complex TNN 实现只支持 2 维对象(如 TopoBench),作者只取 \(D_{\max}=2\);对 hypergraph 因约定维度为 1,一轮即停,候选超边的生成与接受是"尴尬并行"的,因此对 hypergraph 域计算特别高效。

损失函数 / 训练策略

没有额外设计的提升专用损失——整个 ∂lift(GNN 嵌入器 + cell 采样器 \(\Psi\) + 下游 TNN)端到端联合优化,直接用下游任务的标准损失(分类用交叉熵 / AUC 对应目标,ZINC 回归用 MAE)。离散采样处的梯度由直通估计器提供。论文指出 cell complex 域下计算候选需要算 cycle basis,对节点数可能有立方级开销,可通过对 \(k_v\) 加正则或把其分布往 0 推来减少候选数量。

实验关键数据

在 12 个数据集、4 种 TNN 上评测,覆盖图分类与节点分类两类任务,每个指标取 3 次独立运行的均值±标准差(ZINC 用 MAE↓,MOLHIV 用 AUC↑,其余用 accuracy↑)。

主实验(图分类,∂lift vs 静态提升)

TNN Lifting NCI1↑ NCI109↑ MOLHIV↑ MUTAG↑ ZINC↓
Cellular CWN Cycle 76.93 76.71 70.15 66.67 0.46
Cellular CWN ∂lift 79.81 80.55 75.37 85.96 0.17
Cellular CXN Cycle 72.02 75.01 69.17 61.40 0.79
Cellular CXN ∂lift 82.08 82.57 74.83 84.21 0.17
Hypergraph UniGCN2 k-hop 72.70 72.01 50.72 61.40 0.66
Hypergraph UniGCN2 ∂lift 77.45 75.30 69.32 89.47 0.56

∂lift 在超过 90% 的 TNN/数据集组合上是最佳提升方法;相对静态提升平均准确率最高提升约 45%。MOLHIV 上 UniGCN2 从 k-hop 的 50.72 跃升到 ∂lift 的 69.32,ZINC 上 CWN 的 MAE 从 0.46 降到 0.17,提升幅度都很显著。

节点分类(∂lift vs DCM 可微提升基线)

TNN Lifting Cora Citeseer Texas Wisconsin Avg
DCM - 80.73 77.90 56.76 73.86 72.31
CWN Cycle 74.80 75.83 63.06 80.39 73.52
CWN ∂lift 80.17 72.83 80.18 77.78 77.74
TopoTune ∂lift 86.82 78.23 72.97 65.36 75.84

∂lift(配 CWN 或 TopoTune)在所有数据集上都优于此前唯一的可微提升基线 DCM,尤其在 Texas、Cora 上优势很大;在异配图(Texas/Wisconsin)上对 k-hop 也有明显优势。

关键发现

  • GNN backbone 直接决定提升质量:表达力强的 GPS(带位置编码)普遍优于 GIN,cell complex + ZINC 上把 MAE 从 0.46 压到 0.17——印证了"结构生成依赖嵌入质量"的设计逻辑。
  • 同配/异配各有所长:相比 k-hop,∂lift 在异配数据集上更强、在同配上略弱(UniGCN2),说明可学提升尤其擅长那些"固定连通性规则不灵"的图。
  • 配 ∂lift 时 TNN 普遍超过纯 GNN,验证了"把提升变得可学、任务相关"确实能把高阶结构的潜力释放出来。

亮点与洞察

  • 把"提升"从启发式预处理升格为可学模块:这是最核心的"啊哈"点——长期以来 TNN 社区默认结构是先验给定的,本文指出这恰恰是性能瓶颈,并给出端到端解法。
  • 直通估计器 + 伯努利采样是让离散结构选择可导的关键工程手段,思路可直接迁移到任何"需要可微地选/删图结构"的场景(如图结构学习、稀疏化、子图选择)。
  • 自适应 cell 尺寸 \(k_v\sim\mathrm{Categorical}(\pi_v)\):让不同节点形成不同大小的高阶 cell,比全局固定 \(k\) 灵活得多,是一个轻量却有效的设计。
  • 通用配方 + 域特化两步的拆分很干净:只有"候选生成"和"接受/拒绝"依赖具体域,骨架通用——这种抽象让框架天然能扩展到 simplicial、combinatorial complex 乃至未来的点云。

局限与展望

  • cell complex 的候选生成开销大:需要算 cycle basis,相对节点数可能是立方级;作者建议靠正则 \(k_v\) 或把其分布往 0 推来缓解,但更高效的候选识别算法仍是开放问题。
  • 受现有实现限制只做到 2 维\(D_{\max}=2\) 是因为当前 cell complex TNN 实现只支持 2 维对象,更高维的潜力未被验证。
  • 理论分析缺位:作者推测 ∂lift 能通过学出跨"桥"的 2-cell 来缓解 oversquashing(在长窄路径连接的稠密子结构间建立捷径),但只给了直觉,未做形式化证明。
  • 展望:扩展到点云(3D 建图)、动态/时序图(拓扑结构随时间演化),以及结合预训练学出可跨任务迁移的拓扑先验。

相关工作与启发

  • vs 静态提升(cycle / k-hop / k-NN / kernel lifting):它们用固定的连通性或特征规则一次性造结构、与任务无关;∂lift 把结构生成变成端到端可学、任务监督,在 90%+ 组合上更优,最高提升 45%。
  • vs DCM(Battiloro et al. 2024,可微 cell complex 模块):DCM 是此前唯一的可微提升,但只支持 cell complex 单域;∂lift 是迄今最通用的可学提升(hypergraph / simplicial / cell complex 通吃),并在节点分类上全面超过 DCM。
  • vs 图结构学习(GSL):GSL 可看作"提升到图域"的特例(只学成对边);∂lift 把它推广到高阶 cell 的学习,实验上在多个 benchmark 超过 GSL 方法。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首次把跨域、通用的可微提升做成即插即用模块,重新定义了 TNN 里"提升"的角色。
  • 实验充分度: ⭐⭐⭐⭐ 12 数据集 × 4 TNN,含 backbone 消融与确定性变体,覆盖图/节点分类两类任务;高维与理论分析尚缺。
  • 写作质量: ⭐⭐⭐⭐ 通用配方 + 域特化两步讲得清晰,Figure 1/2/4 直观;拓扑背景门槛偏高。
  • 价值: ⭐⭐⭐⭐⭐ 直击 TNN 的核心瓶颈,方法可即插即用、思路(STE 选结构)可迁移面广。