跳转至

Overlap-Adaptive Regularization for Conditional Average Treatment Effect Estimation

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=HMMSnGgYOy
代码: https://github.com/Valentyn1997/OAR
领域: 因果推断 / CATE 估计 / 元学习器
关键词: 条件平均处理效应, 重叠权重, 自适应正则化, Neyman 正交性, 元学习器

一句话总结

针对条件平均处理效应(CATE)估计中"低重叠区域"难学的老大难问题,本文提出 Overlap-Adaptive Regularization(OAR):让两阶段元学习器第二阶段模型的正则化强度随重叠权重 \(\nu(x)\) 反比变化(重叠越低、正则越强),并给出可保持 Neyman 正交性的去偏版本 dOAR,在多组(半)合成数据上稳定优于"常数正则化"。

研究背景与动机

领域现状:从观测数据估计 CATE \(\tau(x)=\mathbb{E}[Y[1]-Y[0]\mid X=x]\) 是因果机器学习的核心任务,在个性化医疗里直接用来预测不同病人对某种治疗的反应。当前 SOTA 是两阶段 Neyman 正交元学习器(DR-learner、R-learner、IVW-learner):第一阶段估计 nuisance 函数 \(\eta=(\mu_0,\mu_1,\pi)\),第二阶段把伪结果(pseudo-outcome)\(\phi(Z,\eta)\) 投影到目标模型类 \(\mathcal{G}\) 上。它们的好处是模型无关、且 Neyman 正交性让第二阶段对 nuisance 估计误差一阶不敏感。

现有痛点:元学习器的性能受重叠(overlap)程度制约。重叠用倾向得分定义的 \(\nu(x)=\pi(x)(1-\pi(x))\) 表示——协变量相近的病人是否会接受不同治疗。医学里重叠常被违反:某类病人因诊疗指南几乎只接受一种治疗,于是这些低重叠区域反事实样本稀疏,CATE 极难学。现有两条对策都不理想:(1) Retargeting(重定向)把重叠权重塞进误差项,在低重叠区把误差项截断/降权,但这样目标模型在低重叠区要么行为不可控、要么干脆改去估计了另一个因果量(R-/IVW-learner 在低重叠区估出的是加权平均处理效应 WATE 而非 CATE);(2) 常数正则化(CR)对全空间"一刀切"地压低 CATE 异质性,根本不看重叠程度。

核心矛盾:低重叠区伪结果方差巨大(来自第一阶段外推差或逆倾向得分爆炸),需要更强的正则;而高重叠区样本充足,需要保留模型灵活性。CR 用同一个 \(\lambda\) 无法兼顾——配 DR-learner 时会"在低重叠区欠拟合、高重叠区过拟合同时发生",配 R-/IVW-learner 时正则一大就滑向 WATE。

核心 idea:让正则化强度随重叠自适应——把正则函数 \(\lambda(\nu)\) 设成与逆重叠 \(1/\nu\) 成正比,于是低重叠区正则强(强制模型更简单、更平滑)、高重叠区正则弱(保留灵活性)。这是首个把重叠权重直接写进元学习器正则项(而非误差项)的工作。

方法详解

整体框架

OAR 不改两阶段元学习器的骨架,而是替换第二阶段目标风险里的正则项。回顾 Neyman 正交风险的一般形式:

\[\mathcal{L}(g,\eta)=\underbrace{\mathbb{E}\big[\rho(A,\pi(X))(\phi(Z,\eta)-g(X))^2\big]}_{\text{误差项 }E}+\underbrace{\Lambda(g;P(X))}_{\text{正则项 }\Lambda}\]

其中 \(\rho(A,\pi(X))\ge 0\) 是去偏权重,\(\phi(Z,\eta)\) 是满足 \(\mathbb{E}[\phi(Z,\eta)\mid X=x]=\tau(x)\) 的伪结果。DR-/R-/IVW-learner 只是 \(w,\rho,\phi\) 的不同选择。传统 CR 把正则取成与重叠无关的常数(如 \(\Lambda=\lambda\|\beta\|_2^2\))。OAR 的全部动作就是:把这个常数 \(\lambda\) 换成依赖重叠的函数 \(\lambda(\nu(x))\),并针对不同的第二阶段模型类(参数 / 非参数)给出具体实现,再补一个去偏修正使其重新满足 Neyman 正交性。整篇文章是"一个正则化原则 + 多种落地 + 理论保证"的纯方法论文,没有新的流水线结构。

关键设计

1. 重叠自适应正则函数:让正则强度 \(\propto 1/\nu\)

这是全文地基,针对 CR"一刀切"的痛点。OAR 把正则项写成 \(\Lambda_{\text{OAR}}=\Lambda(g;P(X,A);\lambda(\nu(X)))\),要求正则函数 \(\lambda(\nu)>0\)\(\lambda(\nu)\propto 1/\nu\)。直觉是:\(\nu(x)\to 0\)(低重叠)时 \(\lambda(\nu)\to\infty\),把模型压成简单/平滑的;\(\nu(x)\to 1/4\)(完美重叠,\(\pi=0.5\))时 \(\lambda(\nu)\to 0\),几乎不正则、留足灵活性。作者给出三类候选正则函数:

\[\lambda_m(\nu)=\tfrac{1}{4}\nu^{-1}-1,\quad \lambda_{\log}(\nu)=-\log(4\nu),\quad \lambda_{m2}(\nu)=\tfrac{1}{16}\nu^{-2}-1\]

分别叫乘性(multiplicative)、对数(log)、平方乘性(squared multiplicative),惩罚强度依次递增。这个设计跟 retargeting 的本质区别在于:retargeting 把重叠降权进误差项\(\mathbb{E}[\rho\mid X]=\nu\)),OAR 把重叠升权进正则项;两者一般产生不同的风险极小点,仅当倾向得分恒定时才重合。它也比"表示平衡"(balancing)简单:Proposition 1 证明 OAR 的平均正则量 \(\mathbb{E}[\lambda(\nu(X))]\) 等于或被 \(P(X)\)\(P(X\mid A=a)\) 之间的 \(f\)-散度上界,因此只需估计倾向得分、无需在高维 \(X\) 上估计分布距离。

2. 参数模型的两种落地:OAR 噪声正则与 OAR dropout

针对线性/神经网络这类参数模型 \(\mathcal{G}=\{g(\cdot;\beta,c)\}\),作者把两种经典"注入噪声"的正则技巧改造成重叠自适应版本。OAR 噪声正则给输入加方差正比于逆重叠的高斯噪声 \(\xi\sim\mathcal{N}(0,\sqrt{\lambda(\nu(X))}^2)\),即 \(\sigma^2\propto 1/\nu(x)\),低重叠区噪声更大、正则更强。对线性模型可证其显式形式(Prop 2)为 \(E+\|\beta\|_2^2\,\mathbb{E}[\rho(A,\pi(X))\lambda(\nu(X))]\),恰好等价于一个常数为 \(\mathbb{E}[\rho\cdot\lambda(\nu)]\) 的岭回归。OAR dropout 则按 \(p(\nu)=\lambda(\nu)/(\lambda(\nu)+1)\in(0,1)\) 的概率丢弃,高重叠区 \(p=0\)、低重叠区 \(p\to 1\);其线性显式形式(Prop 3)是 \(E+\beta^\top\mathrm{diag}(\Sigma_{\rho(\cdot,\pi)}\cdot\lambda(\nu))\beta\),是一个随重叠变化的二次型而非纯 \(l_2\)——等价于把每个特征按 \(\tilde X_j=X_j/\sqrt{\mathbb{E}[\rho\cdot\lambda(\nu)\cdot X_j^2]}\) 缩放后再做岭回归。这里有个重要细节:当 OAR 噪声正则跟已经 retarget 的 R-/IVW-learner 组合、用乘性 \(\lambda_m\) 时,低重叠区 \(\mathbb{E}[\rho\cdot\lambda_m(\nu)]\to 1/4\) 反而退化成常数正则;因此要让 retargeted learner 也能自适应,必须改用平方乘性 \(\lambda_{m2}\)(使 \(\mathbb{E}[\rho\cdot\lambda_{m2}(\nu)]\to\infty\))。

3. 去偏 OAR(dOAR):一步偏差修正以保持 Neyman 正交性

原始 OAR 的正则函数依赖估计出的重叠权重 \(\hat\nu(x)=\hat\pi(x)(1-\hat\pi(x))\),当倾向得分 \(\hat\pi\) 估得差时会被一阶误差严重带偏——在观测研究里真值重叠未知,这很要命。作者用一步偏差修正(one-step bias correction)构造 dOAR:在原风险上加一项基于高效影响函数(efficient influence function, IF)的修正 \(C^\diamond\)(噪声版 \(C^{+\xi}\)、dropout 版 \(C^{\circ\xi}\) 见原文 Eq.10–11,推导用了链式法则 + 重参数化 + REINFORCE 技巧)。结论是 dOAR 对 \(\hat\pi\) 的估计误差一阶不敏感,从而与标准 Neyman 正交学习器组合时重新恢复整体的 Neyman 正交性。实现上还对过大的修正项 \(C^\diamond\) 做了截断以稳定训练。

4. 非参数扩展与理论保证:RKHS 范数版 + 相比 CR 的超额风险下界

对非参数模型 \(\mathcal{G}=\mathcal{H}_{K+c}\)(核岭回归 KRR),OAR 把正则项设成加权 RKHS 范数 \(\Lambda_{\text{OAR}}=\|\sqrt{\lambda(\nu)}\,g\|_{\mathcal{H}_K}^2\)\(\sqrt{\lambda(\nu(x))}\) 作为 RKHS 的乘子起到逐点自适应正则的作用,Prop 6 给出该加权 KRR 有良定义的闭式解。理论上 Prop 5 用偏差-方差分解证明:DR-learner + 线性二阶段模型下,超额预测风险 \(\|\hat g-g^*\|_{L_2}^2\) 的方差项与偏差项里,CR 对应 \(\Gamma_{\text{CR}}=\lambda I\)、OAR/dOAR 对应 \(\Gamma_{\text{OAR}}=\mathrm{diag}(\Sigma_{\lambda(\nu)})\);在(i)条件方差假设和(ii)"低重叠-低异质性归纳偏置(LOLH-IB)"下,OAR/dOAR 的方差项 \(\le\) CR、且偏差项不会增加太多,从而整体优于 CR。这给"为什么自适应正则更好"提供了形式化支撑,而非仅靠经验。

损失函数 / 训练策略

两阶段实现:阶段 1 用交叉验证的全连接网络估计 nuisance \(\hat\eta\);阶段 2 用经验版目标风险 \(\hat{\mathcal{L}}\)(参数模型)或 KRR 闭式解(非参数模型)拟合目标网络得到 CATE 估计器。为了和 CR 公平比较,作者把正则函数 \(\tilde\lambda(\nu)\)(或 dropout 概率 \(\tilde p(\nu)\))整体缩放,使其平均正则量与常数 \(\lambda\) 对齐;dOAR 额外截断过大的去偏项以稳住训练。论文实证推荐乘性正则函数 \(\lambda_m\) 作为默认:理论上 Prop 5 表明方差最优的正则形状约为 \(\lambda(\nu)\propto\nu^{-1/3}\)(介于 log 与乘性之间),乘性是更稳健的实用近似;且乘性 OAR + DR-learner 在 KRR 下等价于 CR + R-learner(借用 R-learner 久经检验的有效性)。

实验关键数据

主实验

在四组(半)合成数据上评估,指标为样本外 rPEHE(\(\text{rPEHE}_{\text{out}}\),越低越好),baseline 是同等强度的常数正则化 CR;横向比较 DR-/R-/IVW-learner 三种 Neyman 正交学习器。

数据集 规模 关键现象
IHDP \(n=672+75,\,d_x=25\) 重叠违反严重;每种学习器+正则类型的最优都由某个 OAR/dOAR 版本取得,DR-learner + 大正则时尤其有效
ACIC 2016 \(n=4802,\,d_x=82\),77 个子集 DR-learner 下 dOAR 在过半数据集上显著优于 CR
HC-MNIST \(d_x=784+1\) 高维天然低重叠;DR-/R-/IVW 三种学习器多数情况下 OAR/dOAR 显著优于 CR,验证可扩展性

HC-MNIST(乘性 \(\lambda_m/p_m\),rPEHE_out,越低越好,括号为相对 CR 的变化)节选:

学习器 方法 Noise reg. \(\lambda=0.25\) Dropout \(p=0.3\)
DR CR 0.711 0.727
DR OAR 0.696 (−0.015) 0.713 (−0.014)
DR dOAR 0.684 (−0.027) 0.705 (−0.021)
IVW CR 1.028 1.117
IVW OAR 0.984 (−0.044) 1.061 (−0.056)
IVW dOAR 0.978 (−0.049) 1.110 (−0.006)

(Oracle rPEHE_out = 0.513。)

消融实验

ACIC 2016(DR-learner,报告"OAR/dOAR 显著优于 CR 的数据集占比",\(\alpha=0.1\),越高越好):

正则函数 方法 Noise reg. \(\lambda=0.05\) Dropout \(p=0.3\)
\(\lambda_m\) OAR 31.17% 41.56%
\(\lambda_m\) dOAR 57.14% 70.13%
\(\lambda_{m2}\) OAR 27.27% 16.88%
\(\lambda_{m2}\) dOAR 76.62% 64.94%
\(\lambda_{\log}\) dOAR 7.79% 64.94%

关键发现

  • 去偏(dOAR)几乎总比原始 OAR 更好:ACIC 上 dOAR 的"显著胜出占比"普遍远高于 OAR(如 \(\lambda_m\) dropout 从 41.56% 升到 70.13%),印证一步偏差修正对真值重叠未知的观测数据很关键。
  • DR-learner 是最佳搭档:OAR/dOAR + DR-learner 在所有 benchmark 上稳定好用,因为它在伪结果高方差与正则强度间取到恰当平衡;R-/IVW-learner 因误差项已被重叠降权,再叠 OAR 容易在低重叠区过度正则
  • 乘性正则函数最稳:理论(方差最优 \(\propto\nu^{-1/3}\))+ 与 R-learner 的等价性 + 经验三方面共同支持把 \(\lambda_m\) 设为默认。
  • 正则越大、增益越明显:低重叠+大正则区间正是 CR 最吃亏、OAR 自适应优势最突出的地方。

亮点与洞察

  • 换了个正则化的"位置":现有工作都把重叠权重塞进误差项(retargeting),本文第一个把它放进正则项,并证明二者一般不等价、仅倾向得分恒定时才重合——一个简单但被忽视的设计空间。
  • 把经典 dropout/噪声正则"因果化":复用 Wager et al. (2013)"dropout/噪声正则 ≈ 自适应正则"的老结论,但首次将其与 CATE 估计接通,给每个区域按重叠分配不同的"有效 \(l_2\)"。
  • 去偏让 trick 变成可靠估计器:用高效影响函数做一步修正,使一个看似工程化的正则技巧重新满足 Neyman 正交性——这是把启发式上升为有理论保证方法的范式可迁移到其他"依赖估计量的正则"。
  • 明确的适用边界:OAR 在"低重叠恰好伴随低 CATE 异质性(LOLH-IB)"时最优,作者把这个归纳偏置摆到台面上而非藏起来,诚实且实用(缺反事实时本就该偏好简单模型)。

局限与展望

  • 依赖 LOLH-IB 归纳偏置:当低重叠区其实有高 CATE 异质性时,"低重叠就更强正则"会把真实异质性抹平,OAR 的优势不再成立。
  • 只在(半)合成数据上验证:因为需要反事实真值,全部实验是 IHDP/ACIC/HC-MNIST 等(半)合成集,真实观测数据上的表现未直接检验。
  • 缩放需要对齐 CR:为公平比较把 \(\tilde\lambda(\nu)\) 整体缩放到与常数对齐,实际部署时如何无监督地选 OAR 的整体强度仍是开放问题。
  • 改进方向:把重叠权重同时用于误差项与正则项的联合设计、对 LOLH-IB 不成立区域做检测并退化为局部 CR、以及非参数 OAR 在高维下的可扩展实现(论文已因高维排除 RKHS 版于 ACIC/HC-MNIST)。

相关工作与启发

  • vs Retargeting(R-/IVW-learner、trimming/truncation):它们把重叠降权进误差项以聚焦子人群,但不约束目标模型在子人群之外如何泛化,且低重叠区会改估 WATE;OAR 升权进正则项,能在低重叠区仍给出 ATE(比 WATE 更有意义的因果量),并控制泛化平滑度。两者可叠加但需配 \(\lambda_{m2}\)
  • vs 常数正则化 CR:CR 对全空间一刀切,配 DR-learner 会同时过/欠拟合;OAR 按 \(1/\nu\) 自适应,Prop 5 证明在合理假设下方差项不劣于 CR、偏差项不显著增大。
  • vs 表示平衡(Balancing,如 CFR/TARNet 系):平衡用 \(P(X\mid A=0)\)\(P(X\mid A=1)\) 的分布距离做平均正则;OAR 的平均正则量也能写成分布距离(\(f\)-散度上界),但只需估计倾向得分、不必在高维 \(X\) 上估计分布距离,实现更轻。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首个把重叠权重放进元学习器正则项的工作,设计空间被忽视已久。
  • 实验充分度: ⭐⭐⭐⭐ 覆盖四组数据 + 三种学习器 + 三种正则函数,但全为(半)合成、缺真实观测数据验证。
  • 写作质量: ⭐⭐⭐⭐ 理论与实现交织清晰,命题/显式形式给得扎实,但符号密集、阅读门槛高。
  • 价值: ⭐⭐⭐⭐ 即插即用于任意两阶段元学习器,对个性化医疗等低重叠场景有直接实用价值。