跳转至

Caracal: Causal Architecture via Spectral Mixing

会议: ICML 2026
arXiv: 2605.00292
代码: 见论文 Appendix E
领域: LLM效率 / 序列建模 / 长上下文
关键词: FFT、注意力替代、因果建模、长序列、SSM 对比

一句话总结

Caracal 用 \(\mathcal{O}(L \log L)\) 的多头傅立叶(MHF)模块替换 Transformer 的 \(\mathcal{O}(L^2)\) 注意力,通过"pad-FFT-multiply-iFFT-truncate"实现频域内的严格因果掩码,并完全去掉位置编码,仅用标准 FFT 算子(不依赖 Mamba 那样的 CUDA kernel)就在 Tiny→Large 全尺度上与 Llama / Mamba / Mamba-2 / Jamba 性能相当。

研究背景与动机

领域现状:长序列建模有两条主流路线 —— Transformer 的注意力(强表达力但 \(\mathcal{O}(L^2)\) 复杂度且需位置编码);以 Mamba 为代表的 SSM(线性复杂度但靠定制 CUDA 内核,可移植性差)。傅立叶系(FNet、AFNO、SPECTRE)有 \(\mathcal{O}(L \log L)\) 复杂度,但因频域因果掩码难写,几乎都局限在 encoder-only。

现有痛点:(1) sparse attention(Longformer/BigBird)牺牲信息覆盖;(2) RoPE/YaRN/ALiBi 等位置编码都是"打补丁",外推性总有上限;(3) Mamba 类要写 SSD-style 算子,门槛高、调试困难、不同 GPU 上行为不一致;(4) 已有谱方法(FNet、Hyena)要么不因果要么 filter 是静态 position-based 的,缺乏 data-dependent mixing。

核心矛盾:自回归生成的因果性约束与 FFT 的"全局原子运算"天然冲突 —— 注意力可以中途把 weight 矩阵的 upper-triangle 置零,但 FFT 没有显式权重矩阵可掩。要因果就只能改输入(对每个 \(t\) 跑长度 \(t\) 的 FFT),结果反而比 \(\mathcal{O}(L^2)\) 还慢(\(\mathcal{O}(L^2 \log L)\))。

本文目标:(1) 让 FFT-based mixing 在自回归训练里做到单次并行 forward 也保持因果;(2) 去掉位置编码并保持外推能力;(3) 只用标准 torch/numpy FFT 算子,不要 Mamba 那样的硬件依赖;(4) 引入 data-dependent gating 弥补 FFT 静态权重的表达力短板。

切入角度:作者从"频域乘法 = 时域因果卷积"的等价性出发:把输入 pad 到 \(2L\) → 做 FFT → 元素乘 → iFFT → truncate 回 \(L\),整条流水在数学上等价于一次严格因果卷积,但所有步骤都用并行 FFT 完成。

核心 idea:把注意力替换成"内容自适应卷积核 × FFT 加速 × 频域因果"的统一模块,并保留少量 sliding-window attention 用于局部精度。

方法详解

整体框架

Caracal 在结构上和 GPT-2 几乎一样,只做两处修改:(1) 把全局 masked multi-head attention 换成 MHF 模块;(2) 删掉位置编码(FFT 的正弦基天然带位置信息)。为保留局部精度,每两个 MHF 层之后插一个 Sliding-Window Attention (SWA) 层(窗口 256),整体复杂度仍是 \(\mathcal{O}(L \log L + L \cdot W)\)。Feed-forward / LN / 残差不变,能直接复用现有 Transformer 生态。

关键设计

  1. Multi-Head Fourier (MHF) 模块:

    • 功能:用频域乘法实现 token 之间的 \(\mathcal{O}(L \log L)\) 全局信息混合,且支持自回归。
    • 核心思路:4 步 pipeline。Step 1:用因果 depthwise 1D conv(kernel=3)注入局部归纳偏置,弥补移除位置编码后的局部模式损失。Step 2:LayerNorm 后并行投影出 value 流 \(x_v = \text{Linear}_V(x_{norm})\) 与 gate 流 \(x_g = \text{Conv1d}_{G2}(\sigma(\text{Linear}_{G1}(x_{norm})))\),gate 流的 group conv(\(n_{head}\) 组)实现 intra-head 通道交互。Step 3:把序列 zero-pad 到 \(N=2L\),做 FFT 拿 \(V_{fft}, G_{fft}\),频域元素乘 \(X_{fft} = V_{fft} \odot G_{fft}\),等价于时域因果卷积 \(r_t = \sum_{j=0}^{t} v_j g_{t-j}\)。Step 4:iFFT 后 truncate 回长度 \(L\) 去掉 padding 引入的"未来"伪信号,再过 \(\text{Linear}_O\)
    • 设计动机:把"注意力是 query/key 做的 data-dependent 权重 sum"改写成"gate 流当卷积核做的 data-dependent 权重 sum",保留 selectivity 的同时避开了 SSM 的串行 scan,全程用标准 FFT 算子。
  2. 频域因果掩码(pad-FFT-multiply-iFFT-truncate):

    • 功能:让 FFT 在保持并行性的同时严格满足"输出 \(t\) 只依赖 \(\leq t\) 的输入"。
    • 核心思路:纯 FFT 因果是数学上的硬骨头 —— 不能像 attention 那样掩权重。作者绕开了:把长度 \(L\) 序列右侧 zero-pad 到 \(2L\),做 FFT、乘 gate、iFFT 后只保留前 \(L\) 个元素。由于 \(2L\) 长 FFT 对应的圆卷积,在被 truncate 到前 \(L\) 维时退化为线性卷积 \(r_t = \sum_{j=0}^{t} v_j g_{t-j}\),对未来 token 的依赖被自动截掉。
    • 设计动机:把"看似无解"的因果性问题转化为对 padding/truncation 的几何安排,本质是用 \(2\times\) 序列长度换来 forward 一次完成因果卷积,训练时不需要为每个 \(t\) 单独跑 FFT。
  3. 去位置编码 + Hybrid SWA 局部回补:

    • 功能:彻底拿掉 RoPE/ALiBi 等显式位置编码,又用 SWA 保住局部分辨率。
    • 核心思路:FFT 的基 \(e^{-i \frac{2\pi}{L} tj}\) 内置序列位置信息,下游的 SWA 层也无需 PE。SWA 用 FlashAttention 实现,窗口 256,控制成本不爆炸。MHF:SWA 比例 2:1,足以兼顾全局长程依赖和局部短语级模式。
    • 设计动机:现代 PE(RoPE、YaRN)日益复杂仍解决不了外推根本问题;让模型从架构上自带位置感知,理论上更适合任意长上下文。

损失函数 / 训练策略

沿用标准 next-token prediction CE loss,没有架构外的辅助损失。训练用 GPT-3 风格 hyperparam 设置(Tiny 63M → Large 724M),所有 baseline 都启用硬件优化 kernel(Mamba 用 mamba_ssm、Llama 用 FlashAttention)。

实验关键数据

主实验

9 项 zero-shot common-sense 推理 + LM 评测(LMB / Hellaswag / ARC-e/c / Wino / BoolQ / PIQA / SIQA),4 个尺寸全 sweep:

Size Model LMB ppl↓ Avg acc↑
Tiny Llama (64M) 164.19 40.87
Tiny Mamba (66M) 129.88 41.12
Tiny Caracal (63M) 219.90 41.14
Small Llama (124M) 79.94 43.02
Small Mamba (129M) 86.33 43.60
Small Mamba2 (125M) 100.76 42.64
Small Caracal (120M) 92.05 43.35
Medium Llama (360M) 32.65 47.07
Medium Caracal (345M) 38.50 46.47
Large Llama (757M) 24.92 48.73
Large Caracal (724M) 29.39 49.01

Caracal 在所有 size 上 Avg accuracy 与 Llama / Mamba / Jamba 同档,Large 上 49.01 略超 Llama 的 48.73。

消融实验

与更广泛 baseline 在 345M 参数、15B token、4096 上下文设定下对齐:

Model LMB ppl↓ Avg acc↑
Transformer++ 41.08 42.92
RetNet 49.73 42.54
GLA 43.02 44.09
Mamba 40.21 43.59
Gated DeltaNet 30.94 45.42
Moneta 29.31 46.45
Yaad 29.11 45.94

Caracal 与 Mamba / DeltaNet 同处第一梯队,明显优于早期 Transformer++/RetNet。

关键发现

  • 算法上的"中间方案"取代硬件 trick:用 \(\mathcal{O}(L \log L)\) 换 SSM 的 \(\mathcal{O}(L)\),性能不掉但实现复杂度大幅降低,所有运算都是标准 FFT 算子。
  • Tiny 上 LMB ppl 偏高 (219.90) 是 Caracal 的弱点 —— 小模型容量下 dynamic gating 拟合不充分;但 Avg acc 仍并列第一,说明 ppl ≠ task 表现。
  • 去掉位置编码不掉点说明 FFT 基的隐式位置信息足够,给长上下文外推留下空间(论文未做直接外推实验,是个明显缺口)。
  • SWA 是必要的:消融显示纯 MHF 在 ARC-c 上偏弱,加入 2:1 比例的 SWA 后局部能力补齐。

亮点与洞察

  • 数学优雅的因果性 trick:pad-2L → FFT → multiply → iFFT → truncate 是经典 DSP 技巧的复用,但在生成式 LM 上下文里被首次完整论证、并配套了 data-dependent gating,把多年来 Fourier-based generative model 的"老大难"问题翻过去了。
  • "内容自适应卷积核"的统一视角:把 attention、SSM、FFT 都看作 \(r_t = \sum_j w_{tj} v_j\) 的不同 weight 来源 —— attention 是 query/key 算的,S4 是 static,Mamba 是 input-dependent state,Caracal 是 gate-generated content-aware filter。这种 framing 让人能清楚理解三类架构的本质异同。
  • 硬件无关是真正的工程价值。可以即插即用部署到任何带 FFT 的硬件(包括 TPU、专用 NPU),不像 Mamba 那样被绑死到 NVIDIA GPU。
  • 整体思路("频域乘法 + 因果 padding")可迁移到:speech autoregressive、长视频生成、protein generation 等所有需要因果性 + 长上下文的任务。

局限与展望

  • 作者自己承认理论 \(\mathcal{O}(L \log L)\) 慢于 SSM 的 \(\mathcal{O}(L)\),在 100k+ token 极长上下文下还是吃亏;论文也没做百万 token 级实验。
  • 没有显式 length extrapolation 实验,"FFT 基天然带位置"的卖点只是理论论证,没有 50k→200k 这种 zero-shot 拉长的对比。
  • 2L padding 浪费一半算力:实际 wall-clock throughput 是否真比 FlashAttention 强,要看具体 FFT 实现,论文没汇报针对短上下文 (1k–4k) 的真实速度对比。
  • 改进方向:(a) 用 RFFT (real FFT) 进一步减半算力;(b) 探索更激进的 MHF:SWA 比例(如 4:1)做超长上下文;(c) 把这套用到 image autoregressive 上做 sub-quadratic 自回归 ViT。

相关工作与启发

  • vs Mamba/Mamba-2:同样是 attention 替代品,但 Caracal 不需要硬件 kernel,可移植性强;性能在中小模型上打平。
  • vs Hyena:Hyena 也用 FFT,但 filter 是 position-based (由 MLP 从 \(t\) 生成),不是 content-aware;Caracal 的 gate 流由 input 动态生成,更接近 Mamba 的 selectivity。
  • vs FNet / FNO / AFNO:那些纯 encoder 模型完全不因果,无法做生成;Caracal 是首批严格因果的 FFT replacement。
  • vs Monarch Mixer:M2 用 GEMM 近似卷积追求硬件利用率,Caracal 用标准 FFT 追求实现简单;二者取舍不同。
  • vs FlashButterfly / SPECTRE:FlashButterfly 是 static global kernel,没有外推能力;SPECTRE 用 fixed sliding window 切断长程依赖;Caracal 通过 dynamic filter 解决这两个问题。

评分

  • 新颖性: ⭐⭐⭐⭐ 频域因果 + content-aware gating 的组合首次完整落地于自回归 LM,单独看每个零件都不算全新但拼出了优雅的新架构。
  • 实验充分度: ⭐⭐⭐ 4 个尺寸 sweep + 多 baseline 对比扎实,但缺真实长上下文 (≥32k) 与训练 throughput 的硬数据。
  • 写作质量: ⭐⭐⭐⭐⭐ 从注意力/FFT 第一性原理推到因果掩码困境,再到 pad-truncate trick,论证链条清晰,是一篇极适合教学的架构论文。
  • 价值: ⭐⭐⭐⭐ 给"非 NVIDIA 硬件"用户提供了一个真正 portable 的 SSM 替代方案,工业落地友好。