Object-Centric World Models for Causality-Aware Reinforcement Learning¶
会议: AAAI 2026
arXiv: 2511.14262
代码: 无
领域: 强化学习
关键词: 以物体为中心的世界模型, 因果注意力, 基于模型的强化学习, Slot Attention, Transformer
一句话总结¶
提出 STICA 框架,通过统一的以物体为中心的 Transformer 架构实现世界模型、策略网络和价值网络,其中世界模型将观测分解为独立物体的隐状态进行 token 级动力学预测,策略和价值网络通过因果注意力机制估计 token 级因果关系实现因果感知决策,在 Safety Gym 和 OCVRL 基准上显著超越 DreamerV3 等 SOTA 方法。
研究背景与动机¶
深度强化学习(RL)在多个领域取得了成功,但仍需要大量的环境交互,在考虑实时操作和物理设备故障的真实世界任务中成本高昂。基于模型的 RL(MBRL)通过学习世界模型来在"想象"环境中优化策略,从而提高样本效率。
然而,现有世界模型面临以下关键挑战:
整体表示的局限性:从 Dreamer 系列到 STORM、IRIS 等最新 Transformer 世界模型,都学习环境的整体(holistic)表示。当环境高维、非平稳且包含多个物体及其复杂交互时,整体表示难以捕捉个体物体之间的重要关系和交互。
人类认知的启示:人类通过将环境分解为离散概念(如物体和事件)来感知环境,实现更高效和因果感知的决策。将这种认知机制融入世界模型有望让 RL 智能体在复杂场景中更有效地运作。
现有物体中心 MBRL 的不足:已有的物体中心 MBRL 方法要么需要随机探索的 episode(OODP、COBRA)、要么依赖监督学习(FOCUS、OC-STORM)或外部数据集预训练(SOLD),且大多不适用于非平稳环境和部分可观测设置(如第一人称视角)。STICA 是首个从观测直接提取物体中心表示、无需随机探索/监督/预训练的 MBRL 智能体。
方法详解¶
整体框架¶
STICA 由三个组件构成: 1. 以物体为中心的世界模型:Slot-based 自编码器 + Transformer 动力学模型 2. 因果策略网络:Transformer + 因果注意力 3. 因果价值网络:Transformer + 因果注意力
训练流程遵循标准 MBRL 范式:从真实环境收集经验 → 训练世界模型 → 在想象轨迹上训练策略和价值网络。
关键设计¶
- Slot-based 自编码器:将观测 \(o_t\) 分解为物体中心表示和背景表示。
编码器:使用 Slot Attention(Locatello et al. 2020)从观测中获得 \(n\) 个 slot \((s_t^1, ..., s_t^n)\),每个 slot 作为 128 维 logits 定义 16 个 8 类别的 categorical 分布,从中采样得到隐状态 \(z_t^i\)。另外定义一个可学习的、时间无关的背景隐状态 \(z_{BG}\),用于表示与时间无关的静态背景信息(如环境布局、智能体身体),使编码器能将动态物体与静态背景分开提取。
解码器:Spatial Broadcast 解码器从隐状态重建个体 RGB 图像 \(\hat{o}_t^k\) 和未归一化掩码 \(m_t^k\),背景 \(\hat{o}_{BG}\) 的掩码填充为 0。通过 softmax 归一化掩码后混合:\(\hat{o}_t = \sum_{k=1}^{n} M_t^k \odot \hat{o}_t^k + M_{BG} \odot \hat{o}_{BG}\)。
损失函数:\(\mathcal{L}_\phi^{ae} = \mathbb{E}_B[\frac{1}{T}\sum_{t=1}^{T}(\mathcal{J}_{rec.}^t + \alpha_1 \mathcal{J}_{ent.}^t + \alpha_2 \mathcal{J}_{cross}^t)]\),包含重建误差、熵正则化(防止分布退化为确定性)和交叉熵项(将提取的隐状态与动力学模型预测对齐)。由于 Slot Attention 提取顺序随机,交叉熵项在计算前需要重排索引以最小化 L1 距离。
- Transformer 动力学模型:以 Transformer-XL 为核心的聚合模型,接收历史奖励 \(r_{1:t-1}\)、隐状态 \((z_{1:t}^1, ..., z_{1:t}^n)\) 和动作 \(a_{1:t}\) 作为 token,通过因果掩码确保不访问未来时间步。创新点在于位置编码仅依赖时间 \(t\),不依赖物体索引 \(1,...,n\),保证输出对隐状态顺序的等变性。
预测器包括: - 隐状态预测器:\(p_\psi^{\hat{z}}(\hat{z}_{t+1}^k | h_t^k)\),categorical 分布 - 奖励预测器:\(p_\psi^{\hat{r}}(\hat{r}_t | h_t')\),正态分布 - 折扣因子预测器:\(p_\psi^{\hat{\gamma}}(\hat{\gamma}_t | h_t')\),Bernoulli 变量
损失函数:\(\mathcal{L}_\psi^{dyn} = \mathbb{E}_B[\frac{1}{T}\sum_{t=1}^{T}(\mathcal{J}_{cross}^{t+1} + \beta_1 \mathcal{J}_{rew.}^t + \beta_2 \mathcal{J}_{dis.}^t)]\)
- 因果注意力(Causal Attention)机制:这是 STICA 的核心创新。
动机:环境包含因果物体(与动作选择和价值估计相关,如目标、障碍物)和非因果物体(无关,如地板)。策略需关注奖励相关物体,价值估计可完全忽略非因果物体。
因果图矩阵 \(G\):定义物体间的因果关系,\(G_{i,j}\) 表示从物体 \(j\) 到物体 \(i\) 的因果关系存在性(1为策略/价值,2为因果物体,3为非因果物体)。
因果分数 \(p_t^k \in [0,1]\):通过 MLP 估计每个隐状态 \(z_t^k\) 代表因果物体的概率。基于此构建权重矩阵 \(W_t\),然后通过 \(W_t G W_t^\top\) 表示 token 间的因果关系。
因果注意力计算:\(\text{CA}_t = \text{Norm}\left(\text{softmax}\left(\frac{Q_t K_t^\top}{\sqrt{d}}\right) \odot W_t G W_t^\top\right) V_t\)
这用因果关系矩阵来缩放标准注意力权重,使策略和价值网络的每一层注意力都考虑因果结构。所有隐状态共享相同位置编码,确保输出对顺序等变。
损失函数 / 训练策略¶
- 策略学习采用 A2C + GAE
- 策略和价值网络各自独立的 Transformer
- 每个 Transformer 处理 \(n\) 个隐状态 token + 1个可学习的目标 token
- 端到端学习,无需预训练或监督
实验关键数据¶
主实验¶
Safety Gym 基准(8个3D任务,第一人称视角,非平稳目标,多物体交互):
| 方法 | PointGoal1 | PointGoal2 | PointBtn1 | PointBtn2 | CarGoal1 | CarGoal2 | CarBtn1 | CarBtn2 | 平均 | 归一化平均 |
|---|---|---|---|---|---|---|---|---|---|---|
| PPO | 5.00 | 4.43 | 4.94 | 3.78 | 3.15 | 1.46 | 0.95 | 1.51 | 3.15 | 1.00 |
| TWM | 9.89 | 8.87 | 2.79 | 0.23 | 16.89 | 17.25 | 3.79 | 7.40 | 8.39 | 3.83 |
| DreamerV3 | 19.13 | 13.64 | 4.01 | 4.16 | 16.32 | 15.30 | 5.20 | 3.87 | 10.20 | 4.06 |
| TD-MPC2 | 8.00 | 6.50 | 4.31 | 4.84 | 2.57 | 1.24 | 1.90 | 0.48 | 3.73 | 1.15 |
| STICA | 13.63 | 13.64 | 11.52 | 5.97 | 17.09 | 18.27 | 9.65 | 9.25 | 11.90 | 5.49 |
OCVRL 基准(成功率):
| 方法 | Obj. Goal | Obj. Interaction | Obj. Reaching |
|---|---|---|---|
| TWM | 0.727 | 0.080 | 0.772 |
| DreamerV3 | 0.677 | 0.156 | 0.697 |
| STICA | 0.737 | 0.333 | 0.867 |
消融实验¶
| 配置 | PointButton1 | CarButton1 | 说明 |
|---|---|---|---|
| STICA(完整) | 最优 | 最优 | 全部组件 |
| STICA w/o CA | 显著降低 | 显著降低 | 移除因果注意力 |
| STICA w/o BR | 轻微降低 | 轻微降低 | 移除背景分离 |
| STICA w/o CA+BR | 中等降低 | 中等降低 | 移除两者 |
| STICA w/o CA+BR+TP+TV | 接近 TWM | 接近 TWM | 进一步移除 Transformer 策略/价值网络 |
| TWM | 基线 | 基线 | 无物体中心表示 |
消融结论的优先级排序: 1. 因果注意力:最大提升(从 "w/o CA" 到完整 STICA) 2. Transformer 策略/价值网络:显著提升 3. 背景分离:适度提升 4. 物体中心世界模型单独使用:仅边际提升
关键发现¶
- Safety Gym 全面领先:STICA 在8个任务中7个取得最佳,归一化平均分 5.49 vs DreamerV3 的 4.06(+35.2%)
- Button 任务优势巨大:PointButton1 上 STICA 得分 11.52,DreamerV3 仅 4.01——因为 Button 任务需要在多个动态障碍物中识别目标按钮,正是物体中心+因果注意力的用武之地
- Obj. Interaction 任务:STICA 成功率 0.333 vs DreamerV3 的 0.156(+113.5%),因为需要智能体直接与物体交互,物体中心世界模型和因果价值网络协同作用
- PointGoal1 例外:该任务物体少且简单,整体表示足够,STICA 无显著优势
- 因果注意力可视化:价值网络几乎只关注奖励相关目标物体;策略网络主要关注目标但也适当关注其他物体——符合直觉
亮点与洞察¶
- 统一框架:世界模型、策略、价值网络全部使用物体中心 Transformer,架构高度统一优雅
- 因果注意力的显式建模:不只是用物体中心表示,还通过因果分数显式刻画"哪些物体对决策有因果影响",且通过注意力权重可视化实现可解释性
- 背景分离的端到端学习:通过可学习的时间无关背景隐状态 \(z_{BG}\),无需预训练即可实现背景与前景分离
- 位置编码设计:动力学模型中仅用时间位置编码(不用物体索引),保证对 slot 顺序的等变性——这对 Slot Attention 的随机提取顺序至关重要
- 首个无需额外监督的物体中心 MBRL:不需要随机探索、标注数据或预训练
局限与展望¶
- 因果注意力中的因果图 \(G\) 结构是预定义的(3类:目标、因果物体、非因果物体),灵活性有限
- 固定的 slot 数量 \(n\) 可能不适应物体数量变化的场景
- 仅在 Safety Gym(3D 第一人称)和 OCVRL(2D/3D 固定视角)上验证,未在更复杂的真实世界环境测试
- Slot Attention 的随机提取顺序需要重排索引来计算损失,增加了计算开销
- 因果分数 \(p_t^k\) 的估计依赖 MLP,在更复杂场景中是否准确有待验证
相关工作与启发¶
- Dreamer 系列(Hafner et al.):RSSM 世界模型的代表,STICA 直接对标 DreamerV3
- STORM(Zhang et al. 2023):Transformer 世界模型,学习整体表示
- TWISTER(Burchi & Timofte 2025):Transformer + 对比预测编码
- OCRL(Yoon et al. 2023):模型无关的物体中心 RL,将 Transformer 用于策略和价值网络但缺乏世界模型
- EIT(Haramati et al. 2024):Entity Interaction Transformer,类似物体中心策略但不含世界模型
- SlotFormer(Wu et al. 2023):物体中心视频预测模型,是 STICA 动力学模型设计的参照
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ (因果注意力 + 物体中心世界模型 + 背景分离的统一框架,设计非常优雅)
- 实验充分度: ⭐⭐⭐⭐ (Safety Gym + OCVRL 双基准测试,消融全面,可视化有说服力,但缺乏真实环境)
- 写作质量: ⭐⭐⭐⭐⭐ (结构清晰,公式严谨,图1的框架图非常直观)
- 价值: ⭐⭐⭐⭐⭐ (在物体中心 MBRL 方向上的重要推进,因果注意力有广泛适用性)