跳转至

Improving Sparse Autoencoder with Dynamic Attention

会议: CVPR 2026
论文: CVF Open Access
代码: https://github.com/qyj-bkjx/Sparsemax-SAE
领域: 可解释性 / 机制可解释性 / 稀疏自编码器
关键词: 稀疏自编码器, sparsemax, 交叉注意力, 动态稀疏, 概念解耦

一句话总结

这篇论文把稀疏自编码器(SAE)重写成一个共享概念向量的交叉注意力结构,并用 sparsemax 取代 softmax,让每个样本按自身复杂度自动决定激活几个概念,从而摆脱 TopK 里"K 该设多少"的老问题,在图像和文本上都拿到更低重构误差和更清晰的概念。

研究背景与动机

领域现状:大模型里的神经元是"多义"的(一个神经元同时响应多个无关概念,即 superposition)。稀疏自编码器把激活解耦成一组稀疏、单义、可解释的概念,是当前机制可解释性的主流工具。

现有痛点:SAE 的核心难题是"每个特征该用几个概念"——给太多概念损害可解释性,给太少又损害重构,两边都会让概念学得不好。现有激活函数各有硬伤:ReLU 系(含 GatedReLU/JumpReLU)要配 L1/L0 正则,且 L1 会导致 feature shrinkage(激活整体被拉向 0),平衡系数还得手调;TopK / BatchTopK 直接保留最大的 K 个概念、其余清零,省了正则但把 K 当超参,K 设错就会在复杂样本上漏概念、在简单样本上塞死概念(dead concepts)

核心矛盾:稀疏度本该是数据依赖的(复杂图像需要更多概念、简单图像更少),但 ReLU 的正则和 TopK 的固定 K 都是"一刀切"的全局设定,无法逐样本自适应。

本文目标:设计一种 SAE,让稀疏度按每个样本的内容复杂度自动确定,且只用重构损失训练、不依赖额外正则或 K 调参。

切入角度:作者注意到 sparsemax 这个激活——它把输入投影到概率单纯形上、能给低分项精确赋 0,且处处可微、有闭式阈值解。这正好和 SAE"稀疏激活"的诉求同源:阈值 \(\tau\) 可以由样本自身算出来,相当于逐样本的动态 K。

核心 idea:用交叉注意力框架重写 SAE(特征当 query、字典概念当 key/value,编码解码共享同一组概念向量),并把注意力里的 softmax 换成 sparsemax,让稀疏度动态自适应。

方法详解

整体框架

传统 SAE 是单层 MLP 编码-解码:\(z=\sigma(W_{enc}(x-b_{enc})),\ \hat{x}=W_{dec}z+b_{dec}\),把多义特征 \(x\in\mathbb{R}^d\) 解成 \(M\gg d\) 个概念 \(C=\{c_1,\dots,c_M\}\) 的稀疏组合,\(W_{dec}\) 的列就是概念,\(\sigma\) 决定稀疏模式。本文在两点上改造它:(1) 用交叉注意力把编码器和解码器用同一组概念向量连起来,而不是两个独立 MLP;(2) 把注意力里的 softmax 换成 sparsemax,让每个样本动态决定激活几个概念。整个模型只用重构损失训练,无需稀疏正则、无需调 K。这是一个单次前向的机制改造(非多阶段 pipeline),下面直接用公式讲清两个设计。

关键设计

1. Transformer 化 SAE:用共享概念向量连接编解码

针对传统 SAE 把 \(W_{enc}\)\(W_{dec}\) 当两个独立投影、导致编码权重和解码概念脱节的问题,作者把 SAE 重写成交叉注意力:把待学字典当一组概念向量,经投影同时充当 key 和 value;把每个潜在特征当 query,做交叉注意力得到重构特征:

\[Q=x^\top W_Q,\quad K=C^\top W_K,\quad V=C^\top W_V,\quad \hat{x}=\sigma\!\left(\frac{QK^\top}{\sqrt{d}}\right)V\]

注意力权重的计算天然就是 SAE 的"编码"阶段——它度量 query 与各概念的相关性分数(\(z\) 越高表示特征与概念在嵌入空间越近);用这些权重加权 value(同一组概念)就是"解码"。关键在于 key 和 value 来自同一个概念集 \(C\):编码时算相关性用的概念、解码时加权重构用的概念是同一批,于是权重和概念在加权(解码)阶段强协同,比 MLP 式 SAE 把 \(W_{enc}/W_{dec}\) 当两套独立参数更连贯,重构能力更强、概念质量更高。

2. Sparsemax 注意力:逐样本动态决定激活概念数

针对 TopK 把 K 写死、softmax 又输出稠密分布的问题,作者用 sparsemax 替换注意力里的 softmax。设 \(z=QK^\top\in\mathbb{R}^M\) 是 query 与 \(M\) 个概念的相似度,sparsemax 把 \(z\) 投影到概率单纯形上、取欧氏距离最近的点:

\[\text{sparsemax}(z)=\arg\min_{p\in\Delta^{M-1}}\|p-z\|^2\]

它的闭式解是软阈值 \(\text{sparsemax}(z)_m=\max(z_m-\tau,0)\),阈值 \(\tau\) 由"被选中项之和为 1"这一约束解出:把 \(z\) 降序排成 \(z_{(1)}\ge\cdots\ge z_{(M)}\),取 \(k=\max\{r: z_{(r)}+\frac{1-\sum_{i=1}^r z_{(i)}}{r}>0\}\),则 \(\tau=\frac{\sum_{i=1}^k z_{(i)}-1}{k}\)。和 TopK 设死阈值不同,这里的 \(\tau\) 是按输入内容复杂度动态算出来的:query 特征若包含多个概念,\(z\) 里会有很多接近的值、支撑集 \(S\)(即被激活的概念集)就大;若是纯概念,\(S\) 很小。论文示例里 sparsemax 会给复杂图像分配 6 个概念、给简单图像只分 2 个。sparsemax 处处可微、有良定义雅可比,能直接梯度优化(它其实是 \(\alpha\)-entmax 在 \(\alpha=2\) 时的特例),因此可以看作"样本级的 BatchTopK"——把 K 从 batch 级细化到样本级,更灵活也更准。

一个例子:复杂图 vs 简单图

对一张内容丰富的复杂图像,query 与字典里多个概念都高相关,\(z\) 中出现一批量级相近的大值,sparsemax 解出的阈值 \(\tau\) 较低、支撑集 \(S\) 较大(如激活 6 个概念);对一张内容单一的简单图像,只有少数概念高相关,\(z\) 里大值很少,\(\tau\) 相对更"切"、\(S\) 很小(如只激活 2 个概念)。同一个模型、同一组参数,激活数随图自适应——这是 TopK 固定 K 做不到的。

损失函数 / 训练策略

只用重构损失训练,不加任何稀疏正则、不调 K。视觉侧用 CLIP ViT-B/16,取倒数第二层注意力残差流输出,按 PatchSAE 设概念数 \(M=49152\)(ViT 隐维的 64 倍),ImageNet 上训练、batch 32、共喂 2,621,440 个 patch。文本侧用 GPT-2 Small,取第 8 层残差流,OpenWebText 上训练、序列长 128、batch 128、共喂 \(10^9\) token,字典 \(M\in\{3072,6144,12288,24576\}\)。统一 Adam,lr=\(3\times10^{-4}\)\(\beta_1=0.9,\beta_2=0.99\);对比基线按各自论文取 K=32(TopK 系)、稀疏权重 1e-3(ReLU 系)。

实验关键数据

主实验

文本重构用 NMSE(归一化均方误差,越低越好)和 CE degradation(把 GPT-2 中间特征换成 SAE 重构后输出的交叉熵退化,越接近 0 越好)衡量。下表为 OpenWeb 上不同字典规模 \(M\) 的 NMSE:

Method M=3072 M=6144 M=12288 M=24576
ReLU 0.064 0.064 0.064 0.059
JumpReLU 0.051 0.050 0.050 0.051
Gated 0.078 0.092 0.129 0.489
TopK 0.014 0.059 0.010 0.055
BatchTopK 0.014 0.061 0.060 0.060
Sparsemax SAE (Ours) 0.005 0.038 0.004 0.039

跨所有字典规模,Sparsemax SAE 的 NMSE 显著低于所有基线(在 WikiText-103 上同样如此),CE degradation 也更小——说明动态稀疏注意力既能把多义特征解成可解释概念,又能用更低信息损失重构输入。零样本图像分类(用 top-n 概念替换 ViT 中间嵌入做 11 数据集分类)上,Sparsemax SAE 在所有 top-n 设置(n=1/5/10/50)的平均表现最佳,尤其在极小 n(1/5/10)时明显领先次优。

消融实验

ImageNet 上拆开"transformer 架构"和"sparsemax 激活"两个贡献(top-n 概念分类准确率):

配置 on 1 on 5 on 10 on 50
ReLU SAE 3.12 15.83 22.17 34.87
Transformer + ReLU 3.86 16.85 24.08 36.33
MLP + Sparsemax 7.91 29.87 39.73 55.32
Sparsemax SAE (Ours) 10.93 33.47 42.13 59.95

关键发现

  • 两个设计都正向、且 sparsemax 是主力:从 ReLU SAE 单独加 transformer(→Transformer+ReLU)只小涨,单独换 sparsemax(→MLP+Sparsemax)大涨(on 1 从 3.12 升到 7.91),二者叠加(完整模型)最好(on 1 达 10.93),说明动态稀疏激活贡献最大、共享概念的交叉注意力架构进一步加成。
  • 能反哺现有 SAE 选 K:Sparsemax SAE 算出的逐样本稀疏度可作为 TopK 系 SAE 的调参向导(在 Food101 上,Sparsemax 的 on-1 准确率 26.11 远超固定 K=24/32 的 TopK/BatchTopK 的 0.99~8.64),即"动态 K"可以指导"固定 K"该设多少。
  • 概念更干净:可视化显示,相比 BatchTopK,Sparsemax SAE 学到的概念掩码图和 top-5 参考图更清晰、更可解释;在 EuroSAT、DTD 这类与预训练自然图差异大的数据上,SAE 概念甚至超过原始 CLIP,说明学到的概念有泛化性。

亮点与洞察

  • 把"稀疏度选择"从超参变成模型的内生计算:sparsemax 的阈值 \(\tau\) 有闭式解、随样本复杂度自动浮动,等于把 BatchTopK 的"batch 级 K"细化到"样本级 K",这是从"调参"到"自适应"的范式转变。
  • 共享概念向量连接编解码很巧:传统 SAE 的编码权重和解码概念是两套参数、容易脱节;用交叉注意力让 key/value 同源于一组概念,编码相关性和解码重构天然协同——这个"权重即相似度、字典即 key/value"的视角值得借鉴。
  • 可迁移点:sparsemax 替 softmax 这一招,凡是"需要稀疏、可解释、且稀疏度该随输入变"的注意力/路由场景(如 MoE 专家选择、检索 top-k)都可借用,且无需额外正则、处处可微好优化。

局限与展望

  • sparsemax 是 \(\alpha\)-entmax 在 \(\alpha=2\) 的特例,\(\alpha\) 写死,未探索可学习 \(\alpha\) 或其它 entmax 变体是否更优。⚠️
  • 阈值 \(\tau\) 需对相似度分数排序求解,字典 \(M\) 很大(如 49152)时排序开销和效率论文未深入分析。
  • 评估集中在 CLIP ViT 和 GPT-2 Small 两类中等规模模型,能否扩到更大 LLM、扩散模型、多模态 LLM 上仍待验证。
  • "动态决定概念数"虽灵活,但缺少对激活数稳定性 / 训练收敛性的系统分析,极端样本下是否会激活过多概念未讨论。

相关工作与启发

  • vs TopK / BatchTopK SAE:它们靠固定 K(样本级/batch 级)选概念,K 设错就漏概念或塞死概念;Sparsemax SAE 把 K 细化到逐样本动态阈值,无需调 K,重构和概念质量都更好,还能反过来指导它们选 K。
  • vs ReLU / GatedReLU / JumpReLU SAE:这些靠 L1/L0 正则促稀疏、L1 还引 feature shrinkage、平衡系数难调;本文只用重构损失、稀疏由 sparsemax 内生,无正则无平衡系数。
  • vs PatchSAE 等视觉 SAE:PatchSAE 等只是把 ReLU/TopK SAE 搬到视觉域、没改架构;本文提出全新的交叉注意力 SAE 架构 + sparsemax 激活,且图像文本通用。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 交叉注意力 SAE + sparsemax 动态稀疏,从根上换掉了 SAE 的稀疏选择机制
  • 实验充分度: ⭐⭐⭐⭐ 图像/文本双域、多字典规模、含架构与激活的拆分消融,但缺大模型与效率分析
  • 写作质量: ⭐⭐⭐⭐⭐ 动机清晰、sparsemax 推导完整、可视化有说服力
  • 价值: ⭐⭐⭐⭐ 解掉 SAE 选 K 痛点且能反哺现有方法,对机制可解释性社区实用