跳转至

DoFlow: Flow-based Generative Models for Interventional and Counterfactual Forecasting

会议: ICLR 2026
arXiv: 2511.02137
代码:
领域: 图像生成
关键词: 因果推断, 连续正则化流, 时间序列预测, 反事实推理, 异常检测

一句话总结

提出DoFlow,一种基于连续正则化流(CNF)的因果生成模型,在因果DAG上统一实现观测、干预和反事实时间序列预测,并可通过显式似然进行异常检测,在合成和真实医疗数据上验证了有效性。

研究背景与动机

时间序列预测是统计学和机器学习的核心问题。传统预测模型(ARIMA、LSTM、Transformer等)是纯观测性的——学习历史相关性并外推。但实际应用中常需回答因果性的"what if"问题:

干预查询:"如果改变控制变量,系统如何演化?"例如水电站中改变涡轮控制信号,功率输出如何变化。观测预测器对固定历史只能给出固定预测,无法模拟不同控制方案。

反事实查询:"如果当时采取不同干预,已观测到的轨迹会如何改变?"例如在医疗中,观察到患者的治疗和结局轨迹后,问在不同用药方案下这个特定患者的结局是否会更好。

核心挑战:现有因果生成模型主要面向静态数据,时间序列的因果反事实预测尚无通用框架。需要一个既具因果结构又具生成能力的模型。

方法详解

整体框架

DoFlow把 \(K\) 维多变量时间序列的每个变量看作因果DAG上的一个节点,并按拓扑序排列。它用结构因果模型(SCM)刻画每个节点的生成机制——节点 \(X_{i,t}\) 由自身历史 \(X_{i,t-}\)、父节点历史 \(X_{\text{pa}(i),t-}\) 和独立外生噪声 \(U_{i,t}\) 共同决定,即 \(X_{i,t} := f_i(X_{i,t-}, X_{\text{pa}(i),t-}, U_{i,t})\)。序列被切成上下文窗口 \(\{1,\dots,\tau\}\) 和预测窗口 \(\{\tau+1,\dots,T\}\),前者作条件、后者作预测目标。核心思路是给每个节点配一个连续正则化流(CNF),把因果推断的「溯因—行动—预测」三步映射成流的「编码—改条件—解码」,从而用一套模型统一吐出观测、干预、反事实三类预测,并顺带用流的显式密度做异常检测。

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
    A["多变量时间序列<br/>按因果DAG拓扑序排列"] --> B["时间条件连续正则化流<br/>RNN聚合自身+父节点历史 H"]
    B --> C["编码:事实轨迹沿 H^F<br/>正向积分→潜变量 z"]
    C -->|观测| D["三种预测模式·观测<br/>采样 z 逆向解码"]
    C -->|干预| E["三种预测模式·干预<br/>干预节点强赋值并写回 H"]
    C -->|反事实| F["三种预测模式·反事实<br/>保留 z,换 H^CF 逆向解码"]
    D --> G["预测轨迹"]
    E --> G
    F --> G
    G --> H["似然异常检测<br/>对数密度低→异常分数"]

关键设计

1. 时间条件连续正则化流:让因果结构和流绑定在一起

如果对所有节点共用一个无条件流,就丢掉了因果DAG里"谁依赖谁"的信息。DoFlow为每个节点 \(i\) 学一个跨时间步共享的CNF,并把因果历史作为条件注入。流由一个Neural ODE定义,在 \(s\in[0,1]\) 上把数据分布连续地搬运到基底分布 \(\mathcal{N}(0,1)\)\(\frac{dx_{i,t}(s)}{ds} = v_i(x_{i,t}(s), s; H_{i,t-1})\)。这里的条件 \(H_{i,t-1}=\text{concat}(h_{i,t-1}, h_{\text{pa}(i),t-1})\) 由RNN聚合自身和父节点的历史隐状态——速度场 \(v_i\) 因此显式依赖因果父节点,DAG结构被编码进了流的动力学里。

2. 编码—解码的可逆映射:一条流同时承担推断噪声和生成预测

CNF天然可逆,正反两个方向各对应因果推断里的一半。正向(编码)把观测值 \(x_{i,t}^F\) 沿事实隐状态 \(H_{i,t-1}^F\) 积分到潜变量 \(z_{i,t}^F = x_{i,t}^F + \int_0^1 v_i(x_{i,t}(s), s; H_{i,t-1}^F)\,ds\),相当于反推出该样本对应的外生噪声;逆向(解码)则从潜变量出发、沿新的隐状态 \(\hat{H}_{i,t-1}\) 反向积分 \(\hat{x}_{i,t} = z_{i,t} - \int_0^1 v_i(x_{i,t}(s), s; \hat{H}_{i,t-1})\,ds\),生成预测值。正是这种"编码出噪声、换掉条件、再解码"的结构,让反事实可以在保留个体噪声的前提下改变干预,而不需要额外的反事实专用模块。

3. 三种预测模式:同一套编解码切换出观测、干预、反事实

观测预测最简单:直接采样 \(z\sim\mathcal{N}(0,1)\),按拓扑序逐节点逆向解码。干预预测在解码时对干预集合 \((i,t)\in\mathcal{I}\) 强制赋值 \(\hat{x}_{i,t}\leftarrow\gamma_{i,t}\),非干预节点照常解码——但由于干预后的值会写回隐状态向下游传播,下游节点自然"感知"到上游被动了手脚,这就是不具因果结构的基线做不到的地方。反事实预测则走完整三步:先用事实隐状态 \(H^F\) 把事实轨迹编码成 \(z_{i,t}^F\)(溯因,锁住个体噪声),再施加干预(行动),最后用反事实隐状态 \(\hat{H}^{CF}\) 把同一组 \(z_{i,t}^F\) 解码出反事实轨迹(预测)。

4. 似然异常检测:免费拿到的显式密度

CNF不仅能采样,还能算出预测轨迹的精确对数密度,因为变换的雅可比可以通过速度场的散度积分得到:\(\log p_{\theta}(\hat{x}_{\tau+1:T}\mid\hat{H}_\tau) = \sum_{t=\tau+1}^{T}\big[\log q(z_t) + \int_0^1 \nabla\cdot v_\theta(x_t(s), s; \hat{H}_{t-1})\,ds\big]\)。当上下文异常时,模型给出的预测轨迹会落在低密度区,于是这个对数似然天然成了异常分数,无需另训判别器。

一个完整示例

以医疗反事实为例:观测到某患者在事实用药方案下的治疗与结局轨迹后,要问"换一种剂量结局会不会更好"。DoFlow先用事实隐状态 \(H^F\) 沿正向积分,把这段事实轨迹逐节点编码成潜变量 \(z_{i,t}^F\),这一步把"这个特定患者"的外生噪声固定下来;接着在反事实剂量上施加干预、并据此重算下游节点的隐状态 \(\hat{H}^{CF}\);最后把同一组 \(z_{i,t}^F\) 沿 \(\hat{H}^{CF}\) 逆向解码,得到这名患者在新剂量下的反事实结局。因为编码与解码共用一条可逆流、且潜变量被保留,结果反映的是同一个体在不同干预下的差异,而非群体平均效应。

损失函数 / 训练策略

训练用条件流匹配(CFM)损失,参考路径取数据点与基底样本之间的直线插值 \(\phi(x_{i,t},z;s)\),回归对应的恒定速度 \(z-x_{i,t}\)\(\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}\big[\frac{1}{K(T-\tau)}\sum_{i,t}\|v_i(\phi(x_{i,t},z;s), s; H_{i,t-1}) - (z - x_{i,t})\|_2^2\big]\)。隐状态用真实观测自回归更新(teacher forcing),保证训练时条件与数据对齐。理论上(Corollary 4.5),在SCM单调且训练精确的假设下,上述编码—解码流程可精确恢复真实的反事实轨迹,为方法提供了可识别性保证。

实验关键数据

主实验

合成数据观测/干预/反事实预测RMSE(多种DAG结构)

方法 Tree-Obs Tree-Int Tree-CF Diamond-Obs Diamond-Int Diamond-CF
DoFlow 0.57 0.54 0.11 0.55 0.57 0.12
GRU 0.65 1.01 NA 0.58 0.94 NA
TFT 0.58 0.97 NA 0.63 1.18 NA
TiDE 0.60 1.15 NA 0.50 1.05 NA

关键观察: - DoFlow在干预预测上大幅领先(RMSE差距~0.5),因为基线不具备因果结构 - 反事实预测是DoFlow独有能力,基线无法实现 - 非线性非加性(NLNA)场景也表现稳健

真实数据:水电站干预预测

DoFlow在水电站数据上成功预测不同涡轮控制方案下游信号的变化,因果结构与物理一致。

真实数据:癌症治疗效果估计

方法 均方根政策误差(RMSE of PEHE)↓
DoFlow 最优
CRN 次优

消融实验

  • 加性 vs 非线性非加性噪声模型:DoFlow在两种设置下均表现良好
  • 不同DAG结构(Chain/Tree/Diamond/FC-Layer):一致性优异
  • 异常检测AUROC:DoFlow在合成和真实水电站数据上有效检测异常

关键发现

  1. DoFlow统一了观测、干预、反事实三种查询,是首个通用时间序列反事实框架
  2. CNF的可逆性是核心:编码→修改条件→解码,天然支持反事实
  3. 隐状态传播干预效应,下游节点自然感知上游干预
  4. 显式似然密度提供异常检测的额外能力

亮点与洞察

  • 统一框架:用一个模型同时支持三种因果查询,架构设计自然优雅
  • CNF的因果对齐:流的可逆性与因果推断的溯因-行动-预测三步法完美契合
  • 理论支持:证明了在单调SCM下的反事实精确恢复性质
  • 显式似然:除预测外免费获得异常检测能力,增加实用价值
  • RNN+CNF的组合:RNN编码时序上下文,CNF处理不确定性和可逆映射

局限与展望

  • 假设因果DAG已知,实际中可能需要因果发现
  • 假设无同时步内因果效应(所有因果影响至少有一步时滞)
  • 反事实恢复理论需要SCM单调性和精确训练假设
  • 每个节点一个独立CNF,节点数多时扩展性待验证
  • 真实场景中反事实真值不可观测,仅能在合成数据上定量评估
  • 未与更复杂的时间序列因果效应估计方法深入对比

相关工作与启发

  • vs 传统因果效应方法:后者关注离散动作的短期期望差异,DoFlow支持连续变量在任意时间的干预
  • vs 静态因果生成模型(Javaloy等):DoFlow扩展到时间序列,捕获跨时间的因果依赖
  • vs 现代预测器(TFT/TiDE/TSMixer):它们是观测性的,无法回答因果问题
  • 医疗应用潜力:个体化治疗方案比较、药物剂量优化、临床决策支持

评分

维度 分数
创新性 ★★★★★
理论深度 ★★★★☆
实验充分性 ★★★★☆
实用价值 ★★★★☆
写作质量 ★★★★☆