Beyond Structure: Invariant Crystal Property Prediction with Pseudo-Particle Ray Diffraction¶
会议: ICLR2026
OpenReview: OfmurJrzlT
代码: https://github.com/Bin-Cao/PRDNet
领域: 应用于物理科学 / 晶体性质预测 / 材料信息学
关键词: 晶体性质预测, 倒易空间, 衍射表征, 伪粒子, 多模态融合
一句话总结¶
PRDNet 在传统图神经网络之外,引入一个可学习的"伪粒子"去模拟晶体衍射,用神经网络生成的形状因子(form factor)合成倒易空间的衍射图样,把图表示(短程)与衍射表示(长程)做模态级融合,同时严格满足晶体学对称不变性,在 Materials Project、JARVIS-DFT、MatBench 三大基准上刷新 SOTA。
研究背景与动机¶
领域现状:晶体性质(形成能、带隙、模量等)本质由量子力学方程决定,精确解(DFT)对大体系算不动,于是机器学习成了 DFT 的"廉价替身"。这类模型的精度高度依赖原子系统怎么表示,目前主流是把晶体当作图、做消息传递的图神经网络(CGCNN、MEGNet、Matformer、eComFormer、Crystalformer 等),近年又叠加键角、周期向量、等变/不变约束,越做越精细。
现有痛点:晶体在原理上是无限的三维周期体系,而真实空间编码器的感受野有限、又是局部编码,难以刻画长程原子相互作用。这导致一个致命问题——不同的晶体被映射到同一个表示(论文 Figure 2 给了多边图、键角嵌入、周期向量三类方法各自"撞表示"的反例)。表示一旦退化,下游性质预测必然出错。
核心矛盾:长程相互作用要么靠 DFT 那样的超胞/边界条件去硬算(贵),要么在真实空间里强行扩大感受野(仍然有限、且破坏对称不变性)。真实空间这条路天然受感受野的约束。
切入角度:换到倒易空间(衍射空间)。因为晶体有周期性,完整的衍射图样可以从单个真实空间原胞解析推导出来,无需构造大超胞——这天然紧凑又能编码长程信息。每个原子都沿特定晶面对衍射有贡献,理想衍射表示能无损嵌入完整的真实空间信息(Figure 1、Figure 3 用光衍射做类比说明倒易空间如何编码长程相互作用)。
为什么现有衍射类方法不够:EwaldMP、PotNet、ReGNet(ReciNet) 这些工作虽然引入了 Ewald 求和/傅里叶变换,但把它当成"傅里叶式信息融合"塞进逐层消息传递,忽略了形状因子(form factor)本应是由结构和探针唯一决定、不随层间聚合传播的物理量。而且传统 X 射线衍射用的是查表得到的固定形状因子,只依赖散射矢量 \(|Q|\) 和原子种类,无法区分同一元素处在不同局部化学环境的原子——而这恰恰决定材料性质。
核心 idea:用一个数据驱动的"伪粒子"当探针替代真实的 X 射线/电子/中子。它的形状因子由神经网络学习,显式依赖局部化学环境,对元素和环境变化更敏感;据此合成衍射图样,并把它作为模态(而非原子级特征)与图表示融合,同时保证对晶体学对称(旋转/反射/平移,E(3) 群)的完全不变性。
方法详解¶
整体框架¶
PRDNet 要解决的是"图表示抓不住长程、且会把不同晶体撞成同一表示"的问题。它的做法是双模态:一条路用图注意力建模短程原子环境,另一条路用伪粒子衍射建模长程周期信息,最后在模态层融合做性质预测。
具体流转是:晶体 \(M=(A,P,L)\) 先转成图并做嵌入,经过若干层晶体注意力(crystal attention)消息传递得到节点表示 \(h_i^{(L)}\);这些节点表示一方面被全局池化成短程特征 \(g\),另一方面送进一个形状因子层 \(\mathrm{MLP_{form}}\),为每个原子生成依赖局部环境的可学习形状因子 \(f_i^*\);再按一组系统选取的米勒指数 \(\mathcal{H}\) 计算结构因子,得到衍射特征张量 \(F_{\mathrm{concat}}\),映射为长程特征 \(d\);最后把 \(g\) 和 \(d\) 拼接融合 \(z_{\mathrm{fused}}\),输出性质 \(y\)。整条链路被证明对 E(3) 群不变。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["晶体结构<br/>M=(A,P,L)"] --> B["图嵌入<br/>节点+边特征"]
B --> C["晶体注意力结构建模<br/>多头注意力·门控聚合"]
C --> D["伪粒子可学习形状因子<br/>MLP 生成 f*"]
D --> E["米勒指数选择与结构因子<br/>合成衍射图样"]
C -->|全局池化 g| F["多模态融合与 E(3) 不变性"]
E -->|衍射特征 d| F
F --> G["性质预测 y"]
关键设计¶
1. 晶体注意力结构建模:在图侧用多头注意力+门控抓短程环境
这是双模态里的"真实空间"分支,负责把每个原子的局部化学环境编码进节点表示,为后面生成依赖环境的形状因子打底。晶体转成图 \(G=(V,E)\),节点是原子的 one-hot 嵌入 \(h_i^{(0)}\),边特征拼接了径向基、球面基(含键角 \(\theta_{ijk}\))和距离:\(e_{ij}=\mathrm{RBF}(d_{ij})\oplus\mathrm{SBF}(\theta_{ijk})\oplus d_{ij}\)。消息传递层在常规图注意力上做了两点增强:一是把边特征 \(E_{ij}^{(h)}\) 一起投影进 query/key/value,注意力分数 \(\alpha_{ij}^{(h)}=\frac{q_{ij}^{(h)}\odot k_{ij}^{(h)}}{\sqrt{3d_h}}\) 显式吃进边信息;二是门控聚合,用 \(g_{ij}^{(h)}=\sigma(\mathrm{LayerNorm}(\alpha_{ij}^{(h)}))\) 当自适应滤波器调制每条消息 \(m_{ij}^{(h)}=W_{\mathrm{msg}}v_{ij}^{(h)}\odot g_{ij}^{(h)}\),再带残差更新 \(h_i^{(l+1)}=\beta_i\odot h_i^{(l)}+(1-\beta_i)\odot\mathrm{SiLU}(\mathrm{BN}(W_{\mathrm{concat}}m_i^{\mathrm{agg}}))\)。消融里"单头""去残差""去边特征"都会明显掉点,说明这三件套是结构建模的有效组件。
2. 伪粒子可学习形状因子:把"查表固定"的形状因子换成"依赖环境可学习"
这是全文最核心的创新,直接针对"同元素不同环境的原子被混淆"这个痛点。传统 X 射线的形状因子是查《国际晶体学表》得到的固定值,只依赖散射矢量大小和原子种类:\(f_i^{\text{X-ray}}=f_i^{\text{X-ray}}(|Q|,\,f_i^{\text{type}})\),所以对给定 \(Q\) 完全分不开处于不同局部环境的同种原子。PRDNet 设计一个不受真实粒子物理定律约束的伪粒子,把形状因子扩展为 $\(f_i^{\text{Pseudo}}=f_i^{\text{Pseudo}}\big(|Q|,\,G_\theta(\mathcal{G}),\,f_i^{\text{type}}\big),\)$ 其中 \(G_\theta(\mathcal{G})\) 正是上一步图网络学到的局部化学环境编码。实现上由专门的形状因子层把最终节点表示映射成各个 \((h,k,l)\) 方向的散射强度:\(f_i^*(H)=\mathrm{MLP_{form}}(h_i^{(L)})\in\mathbb{R}^{N_{hkl}}\)。关键在于它把形状因子物理上该有的三重依赖(原子种类、局部环境、衍射矢量)一个都不少地保留,而 EwaldMP 忽略了 \(G_\theta(\mathcal{G})\)、PotNet 连 \(G_\theta(\mathcal{G})\) 和 \(|Q|\) 都忽略、ReGNet 也没考虑这两者的依赖——这正是它们倒易表示不自洽的根源。
3. 米勒指数选择与结构因子计算:系统覆盖倒易空间并合成衍射图样
有了形状因子,还要选"在哪些方向衍射"并把它们累加成衍射图样。每个米勒指数三元组 \((h,k,l)\) 对应一条晶面方向、一个唯一散射矢量 \(Q\)。论文先取 $\(\mathcal{H}_0=\{(h,k,l)\in\mathbb{Z}^3:|h|,|k|,|l|\le C_{\max},\ \gcd(|h|,|k|,|l|)=1\},\)$ 用 \(\gcd=1\) 滤掉冗余的高阶反射、只保留基本反射,\(C_{\max}=8\) 控制倒易空间分辨率;再对全排列加正负号取闭包 \(\mathcal{H}=\{\pm\mathrm{perm}(h,k,l)\}\),保证指数集对所有晶体学操作封闭(这是后面不变性证明的前提)。然后对每个指数累加结构因子的实部、虚部: $\(\mathrm{Re}(F_{hkl})=\sum_{i=1}^N f_i^*\cos(2\pi\,\mathbf{h}\cdot r_i^T),\quad \mathrm{Im}(F_{hkl})=\sum_{i=1}^N f_i^*\sin(2\pi\,\mathbf{h}\cdot r_i^T),\)$ 其中 \(r_i\) 是原子 \(i\) 的分数坐标。把实虚部展平拼成衍射特征张量 \(F_{\mathrm{concat}}=\mathrm{flatten}(\mathrm{Re}\oplus\mathrm{Im})\in\mathbb{R}^{2N_{hkl}}\),这就是单个原胞解析得到的、紧凑且无需超胞的长程描述。
4. 多模态融合与 E(3) 不变性:模态级融合,且全程对称不变
论文反复强调衍射是整个结构的全局属性,必须在模态层融合,而不是当成原子级特征去拼。融合用 \(g=\mathrm{GlobalPool}(\{h_i^{(L)}\})\)、\(d=\mathrm{MLP_{diff}}(F_{\mathrm{concat}})\)、\(z_{\mathrm{fused}}=\mathrm{MLP_{fusion}}([g\oplus d])\) 三步。不变性方面:晶体学操作 \(g:r\mapsto R_g r+t_g\)(\(R_g\) 是整数幺模矩阵,\(\det R_g=\pm1\))作用下,坐标和米勒指数同步变换,结构因子只多出一个相位 \(\phi(g,h)=(R_g h)\cdot t_g\), $\(F_{g\cdot h}(\{g\cdot r_i\})=e^{2\pi i\,\phi(g,h)}F_h(\{r_i\}),\)$ 而 \(\phi(g,h)\) 恒为整数,故 \(e^{2\pi i\phi}=1\),衍射表示天然不变。再加上 \(g\) 依赖的几何量 \(d_{ij},\theta_{ijk}\) 在 E(3) 下不变,最终 \(z_{\mathrm{fused}}\) 对旋转、反射、平移完全不变。这相比前作只满足"基本"不变性、却仍可能撞表示,是一个更彻底的保证。
损失函数 / 训练策略¶
回归任务用 MAE、分类任务用准确率评估;基于 PyTorch 实现,在 RTX 3090 上训练,所有对比基线均沿用原开源设置不额外调参。PRDNet 参数量约 20.9M(明显大于多数基线)。
实验关键数据¶
主实验¶
在 Materials Project(122,959 条)上,PRDNet 在多数任务取得最低误差与最高分类准确率:
| 数据集/任务 | 指标 | PRDNet | 之前最好基线 |
|---|---|---|---|
| MP 形成能 | MAE eV/atom ↓ | 0.028 | 0.030 (Crystalframer) |
| MP 带隙 | MAE eV ↓ | 0.151 | 0.153 (eComFormer) |
| MP 体模量 | MAE log(GPa) ↓ | 0.035 | 0.047 (Crystalframer) |
| MP 剪切模量 | MAE log(GPa) ↓ | 0.108 | 0.111 (GATGNN) |
| MP 杨氏模量 | MAE log(GPa) ↓ | 0.104 | 0.106 (Crystalframer) |
| MP 金属/非金属 | Acc% ↑ | 93.3 | 92.7 (Matformer) |
在 JARVIS-DFT 与 MatBench 上同样全面领先或持平,例如 JARVIS 形成能 0.032、带隙(OPT) 0.140、总能量 0.032、带隙(MBJ) 0.267;MatBench 形成能 0.019、剪切模量 0.058、折射率 0.242,均为最优。
消融实验(Materials Project)¶
| 配置 | 形成能 | 带隙 | 体模量 | 金属分类Acc | 说明 |
|---|---|---|---|---|---|
| Default | 0.028 | 0.151 | 0.035 | 93.3 | 完整模型 |
| NoDiff | 0.041 | 0.361 | 0.081 | 81.9 | 去掉衍射模块,全面大幅掉点 |
| SingleHead | 0.040 | 0.318 | 0.067 | 89.3 | 单头注意力 |
| NoRes | 0.038 | 0.297 | 0.077 | 82.7 | 去残差连接 |
| NoEdge | 0.043 | 0.355 | 0.071 | 80.1 | 去边特征 |
关键发现¶
- 衍射模块贡献最大:去掉它(NoDiff)带隙从 0.151 飙到 0.361、分类从 93.3% 跌到 81.9%,印证"长程衍射表示"是性能主来源。
- 多头注意力、残差、边特征三者各自去掉都会掉点,说明短程结构分支也是必要的组件,而非可有可无的陪衬。
- 论文还做了一个有意思的扩展(Table 4):把衍射模块接到其它基线(CGCNN、SchNet、Matformer 等)上,多数任务都能带来提升(标 ↓ 表示误差下降),说明伪粒子衍射是一个可移植的增强模块,而不只对自家结构有效。
亮点与洞察¶
- 把"固定查表的形状因子"变成"可学习、依赖局部环境的伪粒子形状因子",这是最让人"啊哈"的一步:它用一个虚构探针绕开了真实粒子(X 射线/电子/中子)的物理局限,让同元素不同环境的原子也能被区分开。
- 从真实空间转战倒易空间抓长程:利用晶体周期性"单原胞即可解析出完整衍射"的特性,避免构造大超胞,紧凑又无损,是一个很物理的角度。
- 不变性证明很干净:靠米勒指数集对称封闭 + 结构因子相位恒为整数倍 \(2\pi\),直接得到 E(3) 不变,而不是靠数据增强硬学。
- 衍射模块可即插即用地增强其它 GNN,复用价值高——这套"先学环境敏感的形状因子、再合成结构因子做模态融合"的范式,可迁移到任何需要长程周期信息的材料/晶体任务。
局限与展望¶
- 参数量偏大:PRDNet 约 20.9M,远高于不少 <1M 的轻量基线,性价比/部署成本在论文里未充分讨论。
- \(C_{\max}=8\) 的米勒指数截断是个硬超参,决定倒易空间分辨率与衍射特征维度 \(N_{hkl}\),对不同体系是否需要重调、对精度-成本如何权衡,正文交代有限(细节在附录)。
- 对比中的 ReGNet(ReciNet) 因官方代码不可得而由作者自行复现(隐藏维度从 256 调到 304 以对齐参数量),该基线数字应带一点 caveat 看待。
- 伪粒子缺乏直接的物理可解释性——它强于判别,但"学到的形状因子对应什么物理量"还需更多分析(论文给了 CaF₂ 的案例研究在附录)。
相关工作与启发¶
- vs 多边图 / 键角 / 周期向量类(CGCNN、ALIGNN、Matformer、Crystalformer 等):它们都在真实空间里加结构信息,但受有限感受野与局部编码所限,仍会把不同晶体撞成同一表示;PRDNet 改到倒易空间,用衍射的全局性从原理上保证表示唯一。
- vs 倒易空间/长程类(EwaldMP、PotNet、ReGNet):前者把 Ewald 求和/傅里叶当成逐层"信息融合",且简化或丢掉了形状因子的物理依赖(环境 \(G_\theta\)、\(|Q|\));PRDNet 把形状因子当成由结构和探针唯一决定的不变量、在模态层一次性融合,三重依赖一个不少,因而表示更自洽、精度更高。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 用可学习伪粒子+倒易空间衍射做晶体表示,角度独到且物理动机扎实
- 实验充分度: ⭐⭐⭐⭐⭐ 三大基准多任务、系统消融、还验证衍射模块可移植增强其它基线
- 写作质量: ⭐⭐⭐⭐ 物理铺垫与公式完整,但部分关键超参与可解释性放在附录、正文略紧
- 价值: ⭐⭐⭐⭐⭐ 刷新 SOTA 且提供一个可复用的"形状因子→结构因子→模态融合"范式