Bayesian Test-Time Adaptation via Dirichlet feature projection and GMM-Driven Inference for Motor Imagery EEG Decoding¶
会议: ICLR2026
OpenReview: VDg6Pv4S3v
代码: 待确认
领域: 测试时适应 / 自监督 / EEG 脑机接口
关键词: 测试时适应, EEG 运动想象, Dirichlet 分布, 高斯混合, 贝叶斯校准
一句话总结¶
BTTA-DG 把每条 EEG 试次的逐时刻预测序列压成一个 Dirichlet 参数向量,用历史试次拟合的 GMM 当似然、深度模型输出当先验,做一次无梯度的贝叶斯后验校准,在运动想象脑机接口的跨被试/跨 session 迁移上达到 SOTA 且实时(15.7 ms/试次)。
研究背景与动机¶
领域现状:基于 EEG 的运动想象(Motor Imagery, MI)脑机接口靠解码感觉运动节律来控制外设。近年大规模 EEG 预训练模型把表征学习推得很高,但真要部署时会撞上 EEG 信号的非平稳性——同一套模型换个被试、换个 session,分布就漂了(cross-subject / cross-session shift),不微调几乎没法用。
现有痛点:测试时适应(TTA)本来是个对症的方向——推理阶段用无标签的在线数据现场调模型。但已有的 EEG-TTA 分两派,各有硬伤:① 梯度派(Tent 的熵最小化、伪标签、一致性正则、T-TIME、OTTA 这类)要反向传播更新参数,计算贵,而且在 EEG 在线场景里 batch size 常常是 1,单条噪声试次就会产生误导性梯度,把预训练结构覆盖掉,引发灾难性遗忘;② 数据对齐派(BN 统计重算、Euclidean Alignment 这类)不更新参数所以快,但只是浅层统计对齐,抓不住深层"时序预测嵌入"在新域里到底怎么变形了。
核心矛盾:既要计算高效、不毁预训练权重(避免灾难性遗忘),又要能建模深层的分布变化、还得有理论依据——梯度派和对齐派各占一头,没人同时占全。
切入角度:作者的关键观察是——域漂移更可靠地体现在"模型逐时刻预测序列的分布"里,而不是任何单次预测里。一条 EEG 试次经过编码器会输出一串随时间变化的类别概率向量 \(X=[x_1,\dots,x_T]\),与其盯着平均后的那个点估计,不如把整条概率轨迹的"集中程度"建模出来。
核心 idea:用 Dirichlet 分布("分布的分布")把这串时序概率轨迹投影成一个低维参数 \(\alpha\),再用 GMM 对历史 \(\alpha\) 建密度、用贝叶斯推断把它和模型先验融合——整个适应过程无梯度,只校准输出不动权重。
方法详解¶
整体框架¶
BTTA-DG 处理的是在线、单试次到达的跨被试 TTA:源域是若干被试的有标签 EEG,目标域是某一个被试逐条到来的无标签试次,目标是不碰源数据、不要目标标签就把预训练模型 \(f_\theta\) 适配过去。整条流水线是"轻量主干提特征 → 把时序预测轨迹投影成 Dirichlet 参数 → 用历史参数的 GMM 做贝叶斯后验校准"。
具体地:先用一个轻量的 SincAdaptNet 当主干,它的可学习 Sinc 带通滤波器把 MI 相关的 mu/beta/gamma 频段抽出来,编码器输出每个试次的逐时刻类别概率轨迹 \(X\in\mathbb{R}^{|L|\times T}\),时间平均后得到模型先验 \(f_{cls}(X)=\frac1T\sum_j x_j\)。然后 Dirichlet 特征投影 用极大似然把整条轨迹 \(X\) 压成一个 \(\alpha\in\mathbb{R}^{|L|}_+\),每个分量 \(\alpha_i\) 是对类 \(i\) 的"证据浓度"、总尺度 \(\alpha_0\) 反映整体不确定性。最后 GMM-驱动的贝叶斯推断 把历史高置信试次的 \(\alpha\) 按预测类别存进 memory bank、每类拟合一个 GMM 当似然 \(p_{GMM}(\alpha\mid y)\),和模型先验 \(p_\theta(y)\) 相乘归一化得到校准后验,取 argmax 作为最终预测,再把当前 \(\alpha\) 按置信度/熵阈值回写进 memory bank。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["在线 EEG 试次<br/>(batch=1)"] --> B["SincAdaptNet 主干<br/>可学习 Sinc 带通滤波"]
B --> C["逐时刻预测轨迹 X<br/>+ 时间平均得先验 pθ(y)"]
C --> D["Dirichlet 特征投影<br/>MLE 把 X 压成 α"]
D --> E["GMM-驱动贝叶斯推断<br/>历史 α 拟合 GMM 当似然"]
C -->|先验 pθ(y)| E
E -->|校准后验 argmax| F["最终预测 ŷ_cal"]
E -->|高置信/低熵才回写| G["memory bank 更新<br/>+ EM 重拟合 GMM"]
G -.历史 α.-> E
关键设计¶
1. SincAdaptNet 主干:用可解释带通滤波替自由卷积核,给后续投影喂干净特征
EEG 单试次在线适配里 batch size 经常是 1,普通卷积主干既依赖 batch 统计、学到的核又是黑箱,频段乱混。SincAdaptNet 用 Spat-Conv → Sinc-Conv → IncCh-Conv → Cls-Conv 的结构,并在时序滤波和通道扩展后插 LayerNorm(而非 BatchNorm),从根上摆脱单样本下 batch 统计不稳的问题。核心是 Sinc-Conv:受 SincNet 启发,它不学自由形态的卷积核,而是只学每个滤波器的低截止频率 \(f_{low}\) 和带宽 \(f_{band}\)(于是 \(f_{high}=f_{low}+f_{band}\)),据此生成一个加窗 sinc 带通核。这样用极少参数就把 MI 相关的 mu(8–13 Hz)、beta(13–30 Hz)、gamma(>30 Hz) 节律分离出来,频谱含义清楚(实验里 Table 6 显示约 25% 滤波器落在 μ 带、31% 落 β 带、36% 落 γ 带,<8% 落在 MI 节律之外)。受 SwAV 启发,编码器输出经 softmax 映射到归一化概率空间,得到每个时刻的瞬时类别概率向量 \(x_j\in\Delta^{|L|-1}\),整条轨迹既是后续 Dirichlet 投影的输入,时间平均又给出贝叶斯推断要用的先验。
2. Dirichlet 特征投影:把一整条时序预测轨迹压成一个可解释的浓度向量
浅层对齐抓不住"深层预测分布在新域里怎么变形",而单次预测又太吵。作者的解法是把 Dirichlet 分布——"类别分布上的分布"——引进 EEG-TTA:假设一条试次的 \(T\) 个时刻概率向量 \(x_j\) 是 \(\mathrm{Dir}(\alpha)\) 的 i.i.d. 采样,于是用一个投影 \(\mathcal{P}:\mathbb{R}^{|L|\times T}\to\mathbb{R}^{|L|}_+\),通过极大似然把整条轨迹估成一个 \(\hat\alpha_{MLE}=\arg\max_\alpha\sum_{j=1}^T\log D(x_j;\alpha)\),其中
\(\hat\alpha_{MLE}\) 由 Minka 的定点迭代高效求解:\(\alpha_i^{new}=\psi^{-1}\!\big(\psi(\alpha_0^{old})+\frac1T\sum_{j=1}^T\log x_{ij}\big)\)(\(\psi\) 为 digamma 函数),且 \(\alpha\) 是估出来的、不是手调的。为什么有效:这个低维向量比一个类别标签信息量大得多——\(\alpha_i\) 编码了对类 \(i\) 的证据浓度、\(\alpha_0\) 编码了整条试次的预测不确定性/方差,等于把"模型预测了什么 + 有多自信、多一致"一起压进一个语义丰富的向量;新域带来的信号差异会让这些预测分布产生位移,而 Dirichlet 投影正好把这种分布层面的漂移显式地、紧凑地刻画出来。可视化(Figure 3)佐证:Dirichlet 特征空间类间 KL 散度 >31.85、类内协方差 <0.27,类簇分得很开。
3. GMM-驱动的贝叶斯推断:用历史参数的密度当似然,无梯度地校准后验
有了每条试次的 \(\hat\alpha\),怎么"适应"而不动权重?作者维护一个按预测标签组织的 memory bank \(\mathcal{M}_y\),只存高置信试次的 \(\alpha\),对每一类用 GMM 做非参数密度估计,得到类条件似然
GMM 同时编码了历史试次的全局分布(各成分)和当前待校准试次的邻域信息(当前 \(\alpha\) 越靠近某簇,似然越大)。对当前试次,把 GMM 似然和深度模型先验 \(p_\theta(y)=f_\theta(s_i)\) 按贝叶斯公式融合:
最终预测 \(\hat y_{cal}=\arg\max_y p_{cal}(y\mid\hat\alpha_{MLE})\)。之后再把当前 \(\hat\alpha\) 按置信度阈值 \(\tau_{conf}\) 和熵阈值 \(\tau_{ent}\) 决定是否回写 memory bank:满了就丢最老的,只让高置信、低不确定性的试次进来,每条最多贡献 \(1/M\) 给 GMM,防止少数近期试次主导后验;每次插入后用标准 EM 在对应 bank 上重拟合 GMM(因为 Dirichlet 特征维度极低,全量 EM 开销很小,比在线增量聚类更简单且无精度损失)。为什么有效:整个校准只是密度估计 + 贝叶斯融合,完全绕过梯度优化,既躲开了灾难性遗忘,又避免了反传的计算开销,还自带理论依据。
损失函数 / 训练策略¶
源域上正常预训练 SincAdaptNet 分类器;测试阶段不做任何梯度更新——适应完全靠 Dirichlet 投影(MLE 定点迭代)+ GMM(EM 拟合)+ 贝叶斯后验融合完成。预处理只用 1–48 Hz 带通滤波 + Euclidean Alignment(EA)。关键超参为 GMM 成分数 \(K\)、置信度阈值 \(\tau_{conf}\)、熵阈值 \(\tau_{ent}\),三者在合理区间内都不敏感。
实验关键数据¶
主实验¶
在 MOABB 的三个 MI 数据集(BNCI2014001/002, BNCI2015001)和 SHU MI 上做跨被试 LOSO,所有方法独立跑 10 次取均值。
| 数据集(跨被试 LOSO) | 指标 | BTTA-DG | 之前最好基线 | 提升 |
|---|---|---|---|---|
| BNCI2014001 | Acc(%) | 78.70 | OTTA 77.58 | +1.12 |
| BNCI2014002 | Acc(%) | 80.29 | OTTA 78.29 | +2.00 |
| BNCI2015001 | Acc(%) | 77.92 | OTTA 76.20 | +1.72 |
| BNCI2014001 跨 session | Acc(%) | 86.50 | OTTA 83.91 | +2.59 |
值得注意的是,多个梯度派基线(Tent、PL)在 TTA 后反而掉点——因为在线单试次(batch=1)下噪声试次诱导误导梯度更新 BN、覆盖预训练结构(灾难性遗忘);BTTA-DG 冻结网络、在概率参数空间适应,规避了这一失败模式。
计算效率(Table 7,BNCI2014001):BTTA-DG 平均推理 15.7 ms/试次、141.6 MFLOPs,比 T-TIME 快 17.8%、比 OTTA 快 24.2%;只有纯重算 BN 统计的 BN-adapt(5.1 ms)更快但性能差一截。
消融实验¶
Table 8 逐组件拆解(多数据集均值):
| 配置 | 2014001 跨session | 2014002 | 说明 |
|---|---|---|---|
| SincAdaptNet (Source Only) | 80.62 | 76.40 | 主干基线 |
| BTTA-DG w/o EA | 81.88 | 77.55 | 去 EA,保留 Dirichlet+GMM,仍涨 → 模块自带适应力 |
| SincAdaptNet + EA | 82.33 | 78.05 | 标准域对齐 |
| + EA + GMM | 82.47 | 78.25 | 直接在均值概率上做 GMM,仅微涨 |
| + EA + Dirichlet | 84.04 | 78.88 | Dirichlet 投影是涨点主力 |
| BTTA-DG (Full) | 86.50 | 80.29 | 三件套齐活,比源模型绝对涨约 2–6% |
关键发现¶
- Dirichlet 投影是核心增益来源:在均值概率上直接套 GMM(+EA+GMM)几乎没用(82.47 vs 82.33),但把投影换成 Dirichlet(+EA+Dirichlet)一下涨到 84.04——说明价值在于"建模整条时序轨迹的分布",而非 GMM 本身。
- 超参不敏感:\(K\in[2,12]\) 精度稳定在 78.2–78.7%;提高 \(\tau_{conf}\)(0.53–0.65) 过滤低置信试次反而把精度从 78.0% 推到 78.7%;低熵阈值(0.65–0.70) 同样稳。
- 类不平衡下不崩反而专精:测试集类比从 1:1 推到 1:0.25 时,整体精度平滑下降(78.70%→68.77%),但少数类精度反而升(80.40%→85.19%)——模型会向稀有事件特化,对真实 BCI 场景是好性质。
- 生理可解释性:学到的 Sinc 滤波器累计响应在 mu(11.2 Hz)/beta(30.5 Hz)/gamma(55.3 Hz) 处能量集中,<8% 落在 MI 节律外,空间核也呈现 CSP 式的额/中央/顶/枕区拓扑。
亮点与洞察¶
- "适应点估计 → 适应分布表征"的范式转换最值得记:把每条试次从一个类别预测升级成一个 Dirichlet 浓度向量,等于同时编码了"预测什么 + 多自信 + 多一致",这比单次输出对域漂移鲁棒得多。这个"用低维分布参数刻画域漂移"的思路可迁到任何有时序/多样本预测的 TTA 任务。
- 把"适应"从优化问题改写成密度估计+贝叶斯融合,是绕开灾难性遗忘和反传开销的巧妙一招——冻结网络、只在一个极低维参数空间里做 GMM+EM,既快又稳,还自带概率解释。
- Sinc 带通主干用极少参数换来频谱可解释性,对需要生理可信度的医学/神经信号场景是可复用的好 trick。
局限与展望¶
- 依赖大致平衡的 EEG 数据:memory bank 上的密度估计在类别极端不平衡或强非平稳时会吃力,作者承认更极端的不平衡处理是未来工作。
- GMM 假设的局限:用高斯混合刻画 Dirichlet 参数密度,对高度多模态或重尾的真实分布是否够,论文未深究;\(K\) 虽不敏感但仍是预设。
- 未与大预训练 EEG 模型联用:作者把"BTTA-DG 叠加到大 EEG 预训练模型上看增益是否可加"列为未来方向,目前主干仍是轻量自研网络。
- 跨模态泛化仅是设想:把 Dirichlet+GMM 校准迁到 fNIRS/ECoG 等其他神经模态尚未验证。
相关工作与启发¶
- vs 梯度派 TTA(Tent / T-TIME / OTTA):他们靠熵最小化/伪标签反传更新参数,在线 batch=1 时易被噪声试次带偏、灾难性遗忘且慢;本文冻结网络、无梯度校准,既避免遗忘又快 17–24%,且性能更高。
- vs 数据对齐派(BN-adapt / Euclidean Alignment):他们只做浅层统计对齐,抓不住深层预测分布的形变;本文用 Dirichlet 投影显式建模整条时序轨迹的分布漂移,消融显示这正是涨点主力。
- vs SincNet / SwAV 的借鉴:主干借 SincNet 的可学习带通核拿到频谱可解释性,借 SwAV 的归一化概率空间思想把预测映射到 simplex 上再做 Dirichlet 建模——把两条已有思路嫁接到 EEG-TTA 这个新场景。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 首次把 Dirichlet 分布引入 EEG-TTA,用分布参数替点估计做无梯度贝叶斯校准,思路确实新。
- 实验充分度: ⭐⭐⭐⭐ 四数据集 + 跨被试/跨 session + 效率/可解释/类不平衡/敏感性多角度,扎实;但主干较轻、未与大预训练 EEG 模型对比。
- 写作质量: ⭐⭐⭐⭐ 动机—方法—理论分析链条清晰,公式与可视化到位。
- 价值: ⭐⭐⭐⭐ 实时、抗遗忘、可解释,对实际 BCI 部署有直接意义,方法范式也可外推到其他神经模态。