CAT: Circular-Convolutional Attention for Sub-Quadratic Transformers¶
会议: NeurIPS 2025
arXiv: 2504.06704
代码: 无(作者未公开)
领域: LLM/NLP / Transformer 效率
关键词: 循环卷积, FFT注意力, 次二次复杂度, softmax保持, EIT框架
一句话总结¶
CAT 将标准自注意力中的 \(N \times N\) 注意力矩阵替换为一个由 \(N\) 维向量生成的循环矩阵(circulant matrix),利用 FFT 实现 \(O(N \log N)\) 复杂度的注意力计算,在严格保持 softmax 行归一化结构的前提下,在 ImageNet-1k(avg pool 下 CLIP-L 准确率 0.694 vs 标准注意力 0.646)和 WikiText-103 masked LM(PPL 8.32 vs 9.82)上匹配或超越标准注意力。
研究背景与动机¶
领域现状:Transformer 的标准自注意力机制复杂度为 \(O(N^2)\),制约了长序列任务的可扩展性。现有高效注意力方案主要有两条路线:(1) Linear Transformer(如 Performer)用核函数近似 softmax,复杂度降至 \(O(N)\) 但牺牲了 softmax 结构,训练不稳定;(2) 稀疏注意力(如 BigBird、Longformer)只在局部 token 子集上应用 softmax,丢失全局上下文覆盖。
现有痛点:Linear Transformer 的核近似经常导致数值不稳定和精度下降;稀疏方法需要精心调节块大小、稀疏模式等超参数,且不同序列长度需要重新调优;Mamba 等 SSM 方法完全放弃注意力机制,与现有 Transformer 训练栈和推理优化(FlashAttention、KV Cache)不兼容。
核心矛盾:所有现有方法都在"降低复杂度"和"保持 softmax 注意力的表示能力/兼容性"之间做取舍——要么放弃 softmax 换取效率,要么保留 softmax 但引入依赖序列长度的超参数。
本文目标 找到一种方法,既严格保持全局 softmax 行归一化(与标准注意力数学形式一致),又实现低于 \(O(N^2)\) 的计算复杂度,且不引入额外超参数。
切入角度:作者观察到注意力矩阵的每一行本质上是对所有 token 的加权,而循环矩阵(circulant matrix)天然满足"每行是相同权重集合的循环移位"这一结构——且循环矩阵的乘法可以通过 FFT 在 \(O(N \log N)\) 内完成。
核心 idea:将 \(Q K^\top\) 生成的 \(N \times N\) 注意力矩阵替换为由单向量 softmax 后生成的循环矩阵,利用 FFT 实现次二次复杂度的全局注意力。
方法详解¶
整体框架¶
CAT 作为标准多头自注意力的 drop-in 替换。输入 \(\mathbf{X} \in \mathbb{R}^{N \times D}\),标准注意力需要 \(\mathbf{W_Q}, \mathbf{W_K}, \mathbf{W_V}\) 三个投影矩阵(共 \(3D^2\) 参数),CAT 仅需 \(\mathbf{W_A} \in \mathbb{R}^{D \times 1}\)(注意力投影)和 \(\mathbf{W_V} \in \mathbb{R}^{D \times D}\)(值投影),共 \((D+H)D\) 参数(\(H\) 为头数)。
关键设计¶
-
循环注意力矩阵构造:
- 功能:将 \(N \times N\) 注意力权重矩阵替换为由 \(N\) 维向量生成的循环矩阵
- 核心思路:用单一投影 \(\mathbf{Z} = \mathbf{X} \mathbf{W_A} \in \mathbb{R}^{N \times 1}\) 得到注意力得分向量,对其做行 softmax 得到 \(\mathbf{Z}^\star = \text{softmax}(\mathbf{Z})\)。然后构建循环矩阵 \(\text{circ}(\mathbf{Z}^\star)\),其第 \(i\) 行是 \(\mathbf{Z}^\star\) 的第 \(i\) 次循环移位。关键性质:\(\text{softmax}(\text{circ}(\mathbf{Z})) \equiv \text{circ}(\text{softmax}(\mathbf{Z}))\),循环矩阵与 softmax 可交换,所以每行仍然是合法的 softmax 归一化权重
- 设计动机:循环矩阵是"保持 softmax 行归一化的最简结构"——每行共享相同权重集合,只是位置不同。这同时赋予了模型显式的相对位置编码能力,因为注意力权重只依赖 token 间的相对偏移
-
FFT 加速计算:
- 功能:利用循环卷积定理将 \(O(N^2)\) 的矩阵乘法加速为 \(O(N \log N)\)
- 核心思路:循环矩阵乘向量等价于循环卷积。根据卷积定理,\(\text{circ}(\mathbf{Z}^\star) \mathbf{V} = \text{IFFT}[\text{FFT}(\mathbf{Z}^\star) \odot \text{FFT}(\mathbf{V})]\),其中 \(\odot\) 为逐元素 Hadamard 乘积。FFT/IFFT 的复杂度为 \(O(N \log N)\),Hadamard 乘积为 \(O(N)\),总复杂度 \(O(N \log N)\)
- 设计动机:这是唯一不近似、不截断、完全精确的次二次注意力计算方式——循环矩阵乘法与 FFT 在浮点精度内完全等价
-
CAT-Alter 混合架构:
- 功能:在网络中交替使用 CAT 层和标准注意力层
- 核心思路:将一半注意力层替换为 CAT,另一半保留标准注意力,交替排列。CAT 层提供 \(O(N \log N)\) 的高效计算和循环移位正则化,标准注意力层保留完全灵活的全局交互能力
- 设计动机:纯 CAT 在需要高度灵活全局交互的任务(如 causal LM with token pooling)中偶尔不如标准注意力,混合方案兼得两者优势。类似于 Jamba 等 Transformer/SSM 混合架构的思路
损失函数 / 训练策略¶
CAT 不改变训练损失——在视觉任务中使用标准交叉熵分类损失,在语言建模中使用标准交叉熵 LM 损失。训练时直接替换注意力层,使用相同超参数。ImageNet-1k 训练 50 epochs(4 × V100),WikiText-103 训练 50 epochs,均使用 AdamW。Causal LM 场景下需要对 \(\mathbf{Z}\) 做移位以保持因果性,但这会将复杂度退化回 \(O(N^2)\)。
实验关键数据¶
主实验¶
| 任务 / 模型 | Pool/LM类型 | 方法 | 可学习参数 | 复杂度 | 指标 |
|---|---|---|---|---|---|
| ImageNet CLIP-L | avg pool | Attention | \(3D^2\) | \(O(N^2)\) | Acc 0.646 |
| ImageNet CLIP-L | avg pool | CAT | \((D+H)D\) | \(O(N\log N)\) | Acc 0.694 (+4.8%) |
| ImageNet CLIP-L | avg pool | CAT-Alter | \((2D+H/2)D\) | \(O(N^2)\) | Acc 0.681 |
| ImageNet CLIP-L | token | Attention | \(3D^2\) | \(O(N^2)\) | Acc 0.574 |
| ImageNet CLIP-L | token | CAT-Alter | \((2D+H/2)D\) | \(O(N^2)\) | Acc 0.593 |
| WikiText-103 GPT-2 | masked | Attention | \(3D^2\) | \(O(N^2)\) | PPL 9.82 |
| WikiText-103 GPT-2 | masked | CAT | \((D+H)D\) | \(O(N\log N)\) | PPL 8.32 (-15.3%) |
| WikiText-103 GPT-2 | causal | Attention | \(3D^2\) | \(O(N^2)\) | PPL 27.84 |
| WikiText-103 GPT-2 | causal | CAT-Alter | \((2D+H/2)D\) | \(O(N^2)\) | PPL 27.68 |
消融实验¶
Q/K/V 参数化方式对比(CLIP-L, avg pool):
| 循环卷积变体 | 可学习参数 | 复杂度 | Acc ↑ |
|---|---|---|---|
| qkv (Averaged-Key) | \(3D^2\) | \(O(N\log N)\) | 0.696 |
| qv (CAT, 默认) | \((D+H)D\) | \(O(N\log N)\) | 0.694 |
| q only | \((N+H)D\) | \(O(N\log N)\) | 0.637 |
| v only | \((N+D)D\) | \(O(N\log N)\) | 0.625 |
| 标准 Attention | \(3D^2\) | \(O(N^2)\) | 0.646 |
关键发现¶
- avg pooling 下 CAT 大幅领先标准注意力:CLIP-L 上提升 4.8%(0.694 vs 0.646),说明循环卷积的移位结构在全局 token 混合场景下特别有效
- masked LM 场景增益惊人:GPT-2 上 PPL 从 9.82 降至 8.32(-15.3%),Transformer-XL 上从 13.94 降至 10.28(-26.3%),暗示循环结构对 mask 预测有天然优势
- qv 合并是最佳 trade-off:与保留完整 qkv(仅多 0.2%)精度相当,但参数量从 \(3D^2\) 降至 \((D+H)D\),减少约 2/3
- causal 场景是弱点:因果约束迫使 CAT 退回 \(O(N^2)\),且精度略低于标准注意力,需要 CAT-Alter 混合方案补救
- Linear Attention 训练崩溃:在 CLIP-L 上多次出现 NaN loss,佐证了 softmax 保持的重要性
亮点与洞察¶
- 循环矩阵与 softmax 的交换律是全文最精妙的理论洞察:\(\text{softmax}(\text{circ}(\mathbf{Z})) = \text{circ}(\text{softmax}(\mathbf{Z}))\),这意味着只需对 \(N\) 维向量做 softmax 就能得到 \(N \times N\) 的合法注意力矩阵,一步到位解决了"保持 softmax"和"降低复杂度"的矛盾
- "注意力系数从 \(N^2\) 降到 \(N\)"的视角:标准注意力需要运行时生成 \(N^2\) 个系数,CAT 只需生成 \(N\) 个然后循环移位。这不仅降低计算量,还可能起到正则化作用,避免过拟合
- 对 ViT avg pooling 趋势的呼应:近期研究发现 avg pooling 在 ViT 中经常优于 class token pooling,CAT 恰好在 avg pooling 场景下表现最强,暗示循环结构与全局 token 混合有深层亲和性
局限与展望¶
- causal LM 无法保持次二次复杂度:因果 mask 破坏循环结构,使 CAT 退化为 \(O(N^2)\)。这是最大限制——GPT 式生成模型无法从中受益
- 实际加速有限:FFT 实现在 \(N=256\) 时开销与理论增益抵消,gather-based 实现仍为 \(O(N^2)\)(虽然有 ~10% 加速)。需要定制 GPU kernel 或更长序列才能体现 \(O(N \log N)\) 优势
- 所有 token 共享相同注意力权重集合:循环矩阵的每行只是权重的移位,不能学习真正 token-specific 的注意力模式。这在需要高度异质化注意力的任务中可能是瓶颈
- 未在超长序列任务上验证:ImageNet ViT 的 \(N=256\),WikiText 的 \(N=256\),都很短。CAT 的理论优势需要在 \(N > 10000\) 的长文本/长视频任务上验证
- 无代码开源:可复现性受限
相关工作与启发¶
- vs Performer / Linear Attention: Performer 用随机特征近似 softmax 实现 \(O(N)\),但破坏了 softmax 结构导致训练不稳定。CAT 严格保持 softmax 且更稳定,代价是复杂度为 \(O(N \log N)\) 而非 \(O(N)\)
- vs BigBird / Longformer: 稀疏方法保持 softmax 但仅在 token 子集上,丢失全局覆盖且需要调参。CAT 覆盖所有 token 对且零超参数
- vs Mamba (SSM): Mamba 完全放弃注意力,与 Transformer 推理栈不兼容。CAT 是 drop-in 替换,与 FlashAttention、KV Cache 等完全兼容。有趣的是 CAT 的 qv 合并设计与 Mamba 的输入驱动控制有概念相似性
- 与 Jamba 等混合架构的呼应:CAT-Alter 的"交替混合"策略与 Jamba(SSM + Attention 混合)异曲同工,都说明单一机制难以在所有场景最优
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 循环矩阵替代注意力矩阵的想法优雅且理论基础扎实,EIT框架定义有指导意义
- 实验充分度: ⭐⭐⭐ 任务覆盖视觉和语言,但序列长度太短,缺少长序列任务验证和误差棒
- 写作质量: ⭐⭐⭐⭐ 理论推导清晰,消融设计合理,但部分概念重复叙述略冗长
- 价值: ⭐⭐⭐⭐ 为高效注意力设计提供了新思路,但causal场景的限制降低了对LLM的实际适用性