跳转至

Breaking the Simplification Bottleneck in Amortized Neural Symbolic Regression

会议: ICML 2026
arXiv: 2602.08885
代码: https://github.com/psaegert/flash-ansr
领域: 可解释性
关键词: 符号回归, 表达式化简, Transformer, 摊销推理, 科学发现

一句话总结

提出 SimpliPy(基于规则的化简引擎,比 SymPy 快 100 倍)和 Flash-ANSR(基于 Transformer 的摊销符号回归框架),在 FastSRB 基准上以 ~58% 的恢复率匹敌甚至超越遗传编程方法 PySR,同时随推理预算增加生成更简洁的表达式。

研究背景与动机

领域现状:符号回归(Symbolic Regression, SR)旨在从观测数据中发现可解释的解析表达式。传统方法以遗传编程(GP)为主(如 PySR),但每个数据集都从头搜索,无法在任务间迁移结构知识。摊销 SR 通过在海量合成数据上预训练 Transformer 来学习后验 \(p(\bm{\tau}|\mathcal{D})\),将计算负担转移到一次性预训练阶段。

现有痛点:摊销 SR 面临三重困境。第一,静态语料方案(如 NeSymReS)使用 SymPy 离线化简生成固定数据集(~100M 表达式),但高昂的化简成本限制了覆盖度和维度(\(D \leq 3\))。第二,部分方法(如 E2E)放弃化简直接在未规范化表达式上训练,导致模型浪费容量学习语法冗余(\(x+0\), \(1 \cdot x\) 等)。第三,将 SymPy 嵌入训练循环的方法(如 NSRwH)引入严重的计算瓶颈,SymPy 的中位化简时间约 100ms/表达式。

核心矛盾:表达式化简的质量与速度之间存在根本矛盾——通用 CAS 系统的面向对象解析和树遍历机制对 SR 训练场景来说过于重量级,但不化简又导致训练目标冗余和推理效率低下。

本文目标:设计一个快速且高质量的化简引擎,打破 CAS 瓶颈,使摊销 SR 能扩展到更大规模、更高维度的训练。

切入角度:作者观察到 SR 训练中遇到的表达式具有有限的结构复杂度,因此可以将化简本身也"摊销化"——离线穷举发现所有短模式的等价规则,运行时仅做快速查表匹配。

核心 idea:用预计算的哈希索引规则集替代通用 CAS,将符号化简从 \(O(100\text{ms})\) 降到 \(O(1\text{ms})\),从而实现在训练循环中同步化简在线生成的表达式。

方法详解

整体框架

Flash-ANSR 的训练 pipeline 分为四阶段:(1) 骨架采样——按长度指数先验采样算子数量,用 Lample & Charton 算法构建前缀骨架;(2) SimpliPy 化简——将冗余表达式归约为标准形式;(3) 去污染——通过符号和数值双重检测剔除与测试集等价的表达式;(4) 数据集渲染——采样常数和数据点,生成 \((X, y)\) 对。推理阶段采用 softmax 采样生成 \(K\) 个候选骨架,经 SimpliPy 去重后用 Levenberg-Marquardt 优化常数,最终按拟合质量和简洁性正则化排序选出最优表达式。

关键设计

  1. SimpliPy 化简引擎:

    • 功能:将代数表达式快速归约为最短标准形式,实现 100 倍于 SymPy 的加速
    • 核心思路:分两阶段工作。离线阶段——按长度分层穷举所有至多 \(L_{\max}=7\) 符号的表达式模式,通过数值等价测试发现化简规则 \(\bm{\tau} \to \bm{\tau}'\),每条规则必须满足严格长度缩减 \(|\bm{\tau}'| < |\bm{\tau}|\) 和变量不增条件。在线阶段——将无变量的 ground 规则用哈希表实现 \(O(1)\) 查找,含变量的 pattern 规则按算子和长度分桶存储为树结构做子树匹配。运行时交替执行模式匹配(ApplyRules)和项消去(CancelTerms)至多 \(K=5\) 轮,最终排序交换律操作数并替换合并后的常量
    • 设计动机:通用 CAS 从第一性原理求解化简,对 SR 训练场景来说大材小用。通过将化简本身摊销化(离线投入 ~100h 计算,换取运行时毫秒级化简),彻底消除训练循环中的 CAS 瓶颈
  2. 可扩展编码器-解码器架构:

    • 功能:将数据集编码为条件信息,自回归生成前缀表示的表达式骨架
    • 核心思路:编码器采用 Set Transformer 处理变长数据集,引入 masked RMSSetNorm 替代 LayerNorm/SetNorm(统计轴数与 SetNorm 相同但参数量减半,且正确处理 padding);输入用 32-bit IEEE-754 多热编码(覆盖 \(10^{-38}\)\(10^{38}\),远超 16-bit 的 \(10^{-4}\)\(10^{4}\))。解码器使用 Pre-RMSNorm + FlashAttention + RoPE 位置编码,推理时用 softmax 采样替代 beam search 以提高候选多样性
    • 设计动机:Pre-Norm 比 Post-Norm 训练更稳定(消融实验中 Post-Norm 训练直接发散);32-bit 编码覆盖物理域数据的真实量级;softmax 采样在 \(c=4096\) 时产生的语法重写仅为 beam search 的 \(1/70\),恢复率高 9.4pp
  3. 严格去污染与评估协议:

    • 功能:防止训练数据泄漏,建立可靠的评估标准
    • 核心思路:去污染时先剪枝所有常数节点得到骨架,然后同时做符号比较(token 相等)和数值比较(在固定网格 \(X_{\text{check}} \in \mathbb{R}^{512 \times D}\) 上求值后四位小数取整再哈希,碰撞即拒绝)。评估采用机器精度恢复标准 \(\text{FVU} \leq 1.19 \times 10^{-7}\),分析推理时间-恢复率的 Pareto 前沿
    • 设计动机:先前几乎所有工作都未做严格去污染,可能导致性能高估;宽松的成功阈值(如 \(R^2 > 0.9\))掩盖了真实失败案例

训练策略

训练目标为交叉熵损失:\(\hat{\theta} = \arg\min_{\theta} \mathbb{E}[-\sum_{t=1}^{L} \log p_{\theta}(\bar{\tau}_t^* | \bar{\tau}_{<t}^*, \mathcal{D})]\),编码器和解码器端到端联合训练。共训练四个规模的模型(3M / 20M / 120M / 1B 参数),最大模型在 512M 在线生成的数据-表达式对上训练。推理时按简洁性正则化排序:\(\hat{\bm{\tau}}^{\star} = \arg\min \log_{10}\text{FVU}(\hat{\bm{\tau}}) + \gamma \cdot |\hat{\bm{\tau}}|\),默认 \(\gamma = 0.05\)

实验关键数据

主实验(FastSRB 基准,115 个表达式)

方法 类型 vNRR↑ (~10s) vNRR↑ (峰值) 表达式长度比↓ 说明
NeSymReS 摊销 SR ~10% ~10% 饱和,无法泛化
E2E 摊销 SR <2.5% <2.5% 几乎完全失败
PySR 遗传编程 ~45% 50.0% 0.94→1.85 复杂度随时间增长
Flash-ANSR 3M 摊销 SR ~25% ~35% 落后于 PySR
Flash-ANSR 120M 摊销 SR ~45% ~58% 1.40→1.27 超越 PySR,简洁性反转

SimpliPy 化简效率对比

化简引擎 中位时间 化简比 超时率(>1s) 长度增加比例
SymPy ~100ms 较好 9% 38%-52%
SimpliPy (\(L_{\max}=4\)) ~1ms 接近 SymPy 0% 0%(严格不增长)
SimpliPy (\(L_{\max} \geq 5\)) 数ms 超越 SymPy 0% 0%

消融实验

配置 vNRR↑ 长度比 说明
Full (SimpliPy, 100M) 最高 最低 完整模型
A-U (无化简) 接近 +40-50% 表达式冗余严重
B1 (Post-Norm) 训练失败 梯度不稳定
B2 (16-bit 编码) 显著下降 显著上升 数值精度不足
Beam Search vs Softmax -9.4pp 重写多 70× beam search 模式坍缩

关键发现

  • 简洁性反转(Parsimony Inversion):PySR 随推理时间增长表达式越来越复杂(长度比 0.94→1.85),而 Flash-ANSR 反向收敛到更简洁的形式(1.40→1.27),这是因为更多采样能找到稀有但简洁的"大海捞针"式正确表达式
  • 数据稀疏性的三阶段相变:在 \(M \approx 8\) 个数据点处出现"复杂度峰值",类似于 Deep Double Descent——太少的点导致简洁的高偏差近似,临界点处模型用过多常数插值,充足数据后才收敛到真实表达式
  • 噪声鲁棒性不足:在噪声水平 \(\eta \geq 10^{-2}\) 时 PySR 明显优于 Flash-ANSR,因为模型仅在无噪声数据上训练,将噪声误解为高频信号

亮点与洞察

  • 化简的摊销化思想:将化简本身视为可预计算的查表问题而非在线求解问题,这种将"一次性重计算"换取"运行时零成本"的思路可迁移到任何需要在训练循环中执行昂贵符号操作的场景
  • softmax 采样优于 beam search:在多模态后验下,beam search 的模式寻求行为导致 70 倍的冗余重写,softmax 采样以更低成本探索更多功能不同的假设——这一发现对所有序列生成任务都有启示
  • 自我发现 scaling law:作者用 Flash-ANSR 本身对自己的 scaling curve 做符号回归,发现其性能渐近遵循 \(\text{vNRR} \propto \log\log T\),而 PySR 有约 53% 的上界——用自己的工具分析自己的行为,方法论上很优雅

局限与展望

  • 噪声鲁棒性差:仅在无噪声数据上训练,噪声构成分布外偏移,未来需在训练中引入噪声增强
  • 化简规则的离线发现成本高\(L_{\max}=7\) 需要 ~100h(32 线程),扩展到更长模式的成本呈指数增长
  • 评估仍限于 FastSRB:115 个表达式的规模有限,在更复杂的真实科学场景中的表现有待验证
  • 改进方向:训练时加入噪声数据、探索更宽的生成分布、尝试替代的编码/解码范式(如扩散模型)

相关工作与启发

  • NeSymReS / E2E:先前摊销 SR 代表工作,分别受限于静态数据集和未化简训练,本文统一解决了两者的瓶颈
  • PySR:当前遗传编程 SOTA,在中等计算预算下被 Flash-ANSR 追平甚至超越
  • 启发:将"化简"视为独立的可摊销组件,而非必须在线求解的子问题,这种解耦思路值得在其他涉及符号操作的 ML 系统中借鉴