ASMIL: Attention-Stabilized Multiple Instance Learning for Whole-Slide Imaging¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=CYmjrbQRyM
代码: https://github.com/Linfeng-Ye/ASMIL
领域: 医学图像 / 计算病理 / 多示例学习
关键词: 全切片图像、多示例学习、注意力稳定、锚点模型、EMA、归一化 Sigmoid
一句话总结¶
本文首次识别出注意力 MIL 在全切片图像(WSI)训练中的"注意力动态不稳定"失败模式,提出 ASMIL:用 EMA 锚点模型蒸馏稳定注意力、用归一化 sigmoid 抑制注意力过度集中、用 token 随机丢弃缓解过拟合,三招合一在多个病理数据集上把 F1 提升最高 6.49%。
研究背景与动机¶
领域现状:全切片图像是吉像素级的病理大图,肿瘤区往往只占切片极小一部分,且像素级标注不现实,因此只能用切片级弱标签做多示例学习(MIL)。基于注意力的 MIL(ABMIL、TransMIL 等)把每个 tile 当成一个示例、用注意力加权聚合成 bag-level 表示,既是 WSI 分型的事实标准,其注意力热图还被当作临床可解释性证据。
现有痛点:注意力 MIL 已知有两个老毛病——(PII)注意力过度集中,模型把权重几乎全压在少数几个 tile 上,损害泛化与可解释性;(PIII)过拟合,因为 WSI 数据集每类往往只有几百张切片、tile 高度冗余。本文在四种代表性注意力 MIL、两个公开数据集上做实验,发现了第三个、且此前从未被报道的失败模式。
核心矛盾:(PI)注意力动态不稳定——对同一张 WSI,注意力分布在不同 epoch 间剧烈震荡而非收敛到稳定模式。作者用相邻 epoch 注意力分布间的 Jensen-Shannon 散度量化,发现 TransMIL 等的 JSD 持续大幅波动,对应更高的交叉熵和更差的性能。注意力不收敛意味着模型每个 epoch 关注的组织区域都在变,既掉点又破坏可解释性。
本文目标:用一个统一框架同时解决 PI、PII、PIII 三个问题。
核心 idea:引入一个与在线模型同构、但用 EMA 更新(不走反传)的锚点模型作为稳定参考——在线注意力去对齐锚点注意力即可获得稳定性;锚点分支用归一化 sigmoid 取代 softmax 来天然抑制过度集中;再叠加 token 随机丢弃做正则。
方法详解¶
整体框架¶
每张 WSI 切成 tile、经冻结的预训练编码器变成 vision token,与可训练的 FEAT token 一起送入在线编码器和锚点编码器两个分支。在线分支用 softmax 得注意力 α,锚点分支用归一化 sigmoid 得 α^nsf,两者间的 KL 散度作为稳定化损失 \(L_{AS}\) 把在线注意力拉向锚点。锚点用 stop-gradient + EMA 从在线模型更新,不走反传。训练时对 FEAT token 随机丢弃一部分,剩余 token 加 [CLS] token 送入第二个 transformer 得到 bag 表示并分类。总损失 \(L = L_{CE} + \beta L_{AS}\),推理时只用在线模型、丢弃锚点,因此不增加推理开销。
flowchart LR
A[WSI tiles] --> B[冻结编码器]
B --> C[Vision tokens + FEAT tokens]
C --> D[在线编码器 softmax → α]
C --> E[锚点编码器 NSF → α_nsf]
E -. KL散度 L_AS .-> D
D -- EMA + sg --> E
D --> F[随机丢弃 FEAT token]
F --> G[第二个 Transformer + CLS]
G --> H[分类器 → ŷ → L_CE]
关键设计¶
1. EMA 锚点模型稳定注意力:用"数据相关的功能正则"替代标量惩罚。 针对 PI,作者复制在线模型的注意力模块作为锚点,参数按 \(\theta'_t \leftarrow m\theta'_{t-1} + (1-m)\theta_t\) 做指数滑动平均更新,输入与在线模型相同但只读不反传。由于 EMA 天然平滑了参数的高频抖动,锚点给出的注意力分布比在线分支更稳更一致,于是把在线注意力 α 往锚点分布拉就能传递这种稳定性。作者特别论证为什么要用锚点而非熵/ℓ2/温度这类标量惩罚:标量惩罚是内容无关的、只作用于当前 batch,编码不了示例间的关系结构;而 EMA 锚点给出的是条件于整个 bag 的、数据相关的目标分布,让在线注意力贴近它相当于做了一种能捕捉 inter-instance 关系的功能正则,这是标量正则做不到的。
2. 锚点分支用归一化 sigmoid(NSF)抑制过度集中:可证明的"选择性展平"。 针对 PII,作者把过度集中归因于 softmax 的指数特性——少数 token 的分数经指数放大后会吞掉其余 token 的权重。NSF 定义为 \(\alpha^{nsf}_i(z) = \sigma(z_i) / \sum_j \sigma(z_j)\),其中 \(\sigma\) 是 sigmoid。由于 sigmoid 对大正值饱和到 1、对负值压向 0,NSF 能把真正有信息的"高分" token 拉平到接近相等、同时把"低分" token 压到很小。论文用 Theorem 1 给出严格界:对高分集合内任意两 token,其 NSF 权重比 \(\le 1 + e^{-\tau}\)(随阈值 τ 增大趋于 1),低分 token 权重 \(\le e^{-\tau}/h\);并证明 softmax 用单一温度无法同时满足"压制低分"和"展平高分"两个目标(两个温度约束在 \(\frac{\gamma}{\log\kappa} > \frac{2\tau}{\log(h/\varepsilon)}\) 时不可兼得)。关键细节:NSF 不能直接用在在线模型上——会导致梯度消失而掉点(在线模型靠 softmax 学习),所以 NSF 只放在锚点分支当稳定先验,通过 KL 间接引导在线模型。该 KL 对在线分数 \(z_i\) 的梯度极简洁:\(\frac{\partial KL(\alpha^{nsf}\|\alpha)}{\partial z_i} = \alpha_i - \alpha^{nsf}_i\),即梯度下降直接把在线注意力推向锚点分布。
3. FEAT token 随机丢弃缓解过拟合:专为 ASMIL 设计的 token 级正则。 针对 PIII,作者把 N 个可训练 FEAT token(\(N \ll M\),M 为 tile 数,相当于通过 token reduction 做信息聚合)在训练时按 Bernoulli 掩码以比例 \(B \in [0,1)\) 随机丢弃,保留集大小 \(\tilde{N} \sim \text{Binomial}(N, 1-B)\);推理时不丢(B=0)保留全部内容。这种随机移除阻止 FEAT token 之间的协同适应、防止模型过度依赖某几个 token,消融显示 B≈0.5 附近性能峰值最稳。值得注意的是,由于 ASMIL 的注意力对齐假设在线/锚点 token 间一一对应,像 MIL-Dropout 那类通用实例丢弃方法无法直接套用,所以作者专门设计了这套作用在可训练 token 上的丢弃方案。
实验关键数据¶
数据集:CAMELYON-16 / CAMELYON-17 / BRACS 三个公开 WSI 分型数据集;两种特征 backbone(ImageNet 预训练 ResNet-18、域内 SSL 预训练 ViT-S);对比 11 个注意力 MIL 基线。
主实验表格(ViT-S SSL backbone,F1 / AUC)¶
| 方法 | CAM-16 F1 | CAM-17 F1 | BRACS F1 | BRACS AUC |
|---|---|---|---|---|
| ABMIL (ICML18) | 0.914 | 0.522 | 0.680 | 0.866 |
| TransMIL (NeurIPS21) | 0.922 | 0.554 | 0.631 | 0.841 |
| DTFD-MIL (CVPR22) | 0.948 | 0.627 | 0.612 | 0.870 |
| ACMIL (ECCV24) | 0.954 | 0.562 | 0.722 | 0.888 |
| AEM (MICCAI25) | 0.947 | 0.647 | 0.742 | 0.905 |
| HDMIL (CVPR25) | 0.958 | 0.571 | 0.717 | 0.874 |
| ASMIL (Ours) | 0.965 | 0.689 | 0.781 | 0.914 |
ASMIL 在 ViT-SSL backbone 下全数据集 SOTA:BRACS F1 比次优高 3.9 个点、CAMELYON-17 F1 提升 6.49%(肿瘤稀疏、弱监督最难的场景增益最明显);ResNet-18 特征下也与最强基线持平或更优。
消融实验表格¶
组件消融(BRACS):
| Anchor | NSF | rd | F1 | AUC |
|---|---|---|---|---|
| ✓ | ✓ | ✓ | 0.781 | 0.914 |
| ✓ | ✓ | ✗ | 0.765 | 0.903 |
| ✓ | ✗ | ✓ | 0.759 | 0.895 |
| ✓ | ✗ | ✗ | 0.747 | 0.887 |
| ✗ | ✗ | ✓ | 0.728 | 0.868 |
| ✗ | ✗ | ✗ | 0.712 | 0.860 |
三组件缺一不可,其中锚点模型贡献最大(从 0.712 加到 0.747)。
即插即用消融(Table 2):把 Anchor + NSF 加到现有方法上一致涨点,ABMIL 在 BRACS 上 F1 增益最高达 10.73%;仅 DSMIL 在 CAMELYON-16 上加 anchor 时 F1 微降 0.001。
关键发现¶
- 注意力不稳定是真实且普遍的失败模式,JSD 持续震荡与高交叉熵相关;ASMIL 让注意力稳定收敛并一致高亮癌变区。
- 锚点 + NSF 是可迁移的通用插件,给老方法直接涨点。
- 定位任务(FROC / Dice / 切片级特异度)上 ASMIL 热图比基线更完整覆盖所有癌变区,归功于 NSF 减轻了过度集中。
亮点与洞察¶
- 发现新问题本身就是贡献:首次系统识别并量化"注意力动态不稳定",用 JSD 把一个模糊直觉变成可测指标,这类"指出大家都没注意到的失败模式"的工作往往启发力强。
- EMA 锚点 + 注意力 KL 对齐这一组合把自监督里的 teacher-student/动量编码器思想迁移到 MIL 注意力稳定上,且论证清楚它与 MHIM-MIL 的 EMA teacher 的本质差异(对齐注意力分布 vs 挖掘难样本)。
- NSF 的理论刻画:用一条定理证明"选择性展平"是 softmax 单温度做不到的,给"换激活函数"这种工程改动提供了扎实的数学理由。
- 锚点推理时丢弃,零额外推理开销,工程上很友好。
局限与展望¶
- 锚点 + NSF 假设在线/锚点 token 一一对应,导致通用实例丢弃(如 MIL-Dropout)无法直接整合,方法的正则手段被绑定到自家 token 设计上,灵活性受限。
- 引入了 EMA 因子 m、损失权重 β、FEAT token 数、丢弃率 B、锚点更新频率等多个超参,调参负担不轻(虽有消融但跨数据集稳健性需更多验证)。
- DSMIL 上加 anchor 出现微降,说明该插件并非对所有架构都正收益,何时失效缺乏先验判据。
- 实验集中在三个相对小的病理数据集,向更大规模、多癌种 WSI 的泛化性仍待验证。
相关工作与启发¶
- 注意力 MIL 谱系:ABMIL(可学习实例权重 + 热图)→ TransMIL(transformer 建模示例间关系)→ DSMIL(双流 max-pooling + 非局部注意力)→ CLAM(类特定注意力池化 + 实例聚类监督)。
- 对抗过度集中:ACMIL 随机掩 top-K 实例、AEM 加熵正则展平注意力、CAMIL 用邻域约束抑噪——ASMIL 走了"换激活函数 + 锚点蒸馏"的不同路线。
- 对抗过拟合:DTFD-MIL 拆 pseudo-bag、MHIM-MIL 用 EMA teacher 做难负挖掘、MIL-Dropout 随机移除 top-K 实例——ASMIL 的 token 丢弃与之互补。
- 启发:把"训练动态稳定性"作为单独的优化目标,而非只盯最终精度,是一个值得在其他弱监督/小样本任务里复用的视角;EMA 参考分支 + 分布对齐是稳定任何"会震荡的中间表示"的通用模板。
评分¶
- 新颖性: ⭐⭐⭐⭐ 首次识别并量化注意力不稳定这一新失败模式,锚点 + NSF + token 丢弃的组合有清晰动机和理论支撑,虽单个组件都有渊源但整合与问题定义新颖。
- 实验充分度: ⭐⭐⭐⭐ 三数据集、两 backbone、11 基线、组件消融 + 即插即用消融 + 定位任务齐全,附录还覆盖生存预测和非 WSI 数据;不足是数据集规模偏小。
- 写作质量: ⭐⭐⭐⭐ 问题(PI/PII/PIII)分类清晰,Theorem 1 论证严谨,图示直观,方法与动机一一对应。
- 价值: ⭐⭐⭐⭐ 锚点 + NSF 作为可迁移插件给现有方法直接涨点(最高 +10.73%),且零推理开销,对计算病理社区实用价值高。