Trion: FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of LLMs¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=TkHjRwbMNl
代码: https://github.com/IST-DASLab/Trion
领域: LLM效率 / 优化器 / 内存高效训练
关键词: 低秩优化, DCT, 动态列选择, 优化器状态压缩, FFT
一句话总结¶
本文用一个固定的离散余弦变换(DCT)正交矩阵 + 动态列选择来替代 GaLore/Dion 等低秩优化器里昂贵的 SVD/QR 投影,每层只需存 \(r\) 个整数索引而非完整投影矩阵,由此得到运行时与秩无关、内存最多降 25% 而精度不掉的两款优化器 Trion 与 DCT-AdamW。
研究背景与动机¶
领域现状:AdamW 是训练 LLM 的事实标准,但它要为每个参数维护两个动量 buffer,显存随模型规模线性膨胀。为了省下这块开销,出现了一条「低秩优化器」路线:GaLore 用 SVD 把梯度压到低维子空间里更新动量,后续的 LDAdam、FRUGAL、FIRA、Q-GaLore,以及最近基于动量正交化的 Muon、Dion,都沿用了「用矩阵分解求投影矩阵」这一核心套路。
现有痛点:这些方法的瓶颈都在 SVD/QR 分解本身。第一,分解要对每一个线性层单独做(有的每步做、有的每隔几步做),在大模型上计算量巨大;第二,求出的投影矩阵要逐层显式存下来,又吃掉一块显存;第三,像 Dion 用 QR 正交化、Muon 用 Newton-Schulz 迭代,它们的运行时还随秩 \(r\) 增长,秩调大就更慢。
核心矛盾:低秩压缩本想省显存省时间,但「为每层动态求一个最优正交基」这件事本身既贵又占地方——投影质量(要贴合当前梯度)和投影代价(SVD/QR 太重)之间存在 trade-off。
本文目标:找一个便宜、可移植、精度不掉的替代品来顶替 SVD/QR 求出的正交矩阵,让它能插进各种内存高效优化器里。
切入角度:作者观察到——我们其实不必为每层从零算一个正交基。可以预先固定一个「万能」正交矩阵(DCT),然后针对每层的梯度,从它的列里动态挑出最对齐的 \(r\) 列当投影矩阵。DCT 在 JPEG 图像压缩里早已证明能高效逼近能量集中的子空间,且它有 FFT 快速算法。
核心 idea:用「固定 DCT 矩阵 + 按对齐度动态选列」替代「逐层 SVD/QR」,把每层的投影矩阵存储从「一个稠密矩阵」压缩成「\(r\) 个列索引」。
方法详解¶
整体框架¶
方法的核心是一个与具体优化器解耦的子例程——动态列选择(Dynamic Column Selection):给定一个固定的 \(n\times n\) 正交矩阵 \(Q\)(取 DCT)和当前层的梯度/动量矩阵 \(G\),计算相似度矩阵 \(S=GQ\),按列的 \(\ell_1/\ell_2\) 范数排序,挑出最大的 \(r\) 列索引,用这些索引去 \(Q\) 里抽列就得到该层专属的投影矩阵 \(Q_r\)。整套流程只需一次矩阵乘 + 一次排序,DCT 矩阵全程只在训练开始时算一次、每张 GPU 存一份。
作者把这个子例程分别塞进两类主流优化器,得到两个独立优化器:Trion(改 Dion,用 DCT 选列替掉 Power-Iteration,再对低秩动量做 Newton-Schulz)和 DCT-AdamW(改 LDAdamW 一类低秩 AdamW,用 DCT 投影替掉 SVD,并可选 8-bit 量化误差反馈)。两者共享同一个「动态列选择」内核,区别只在外层优化器逻辑。
下图以 Trion 为例展示一个训练步的数据流:
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["梯度 G + 上一步动量<br/>累加得 B"] --> B["相似度矩阵 S = B·DC<br/>(Makhoul FFT 或 matmul)"]
B --> C["动态列选择<br/>按范数选最对齐的 r 列索引"]
C --> D["索引 DCT 得投影矩阵 Q_r<br/>抽出低秩动量 b"]
D --> E["误差反馈<br/>把投影残差累回动量"]
D --> F["Newton-Schulz 正交化<br/>只对 r×r 的低秩 b 做"]
F --> G["升维回原尺寸 O = o·Qᵀ<br/>更新参数"]
关键设计¶
1. 动态列选择:把"求投影矩阵"变成"挑列索引"
这是替掉 SVD/QR 的关键一招,直击「逐层分解又贵又占显存」的痛点。给定固定正交矩阵 \(Q\in\mathbb{R}^{n\times n}\) 和梯度 \(G\in\mathbb{R}^{n\times n}\),先算相似度矩阵 \(S=GQ\),它第 \(i\) 列装的是 \(G\) 各行与 \(Q\) 第 \(i\) 列的内积——即各列基向量与当前梯度的「对齐度」。然后按每列的 \(\ell_1\) 或 \(\ell_2\) 范数排序,取最大的 \(r\) 个列索引 \(i_t\),用它们去 \(Q\) 里抽列就得到该层投影矩阵 \(Q_r\),把梯度投到 \(r\) 维:\(g=GQ_r\)。
它的妙处在于「动态」体现在索引集合上而非矩阵本身:每层从 \(\binom{n}{r}\) 种可能的列组合里挑最贴合当前梯度的那组,所以投影会随训练中梯度的变化而变;但代价侧只需存 \(r\) 个整数,而不是一个 \(C\times r\) 的稠密投影矩阵。配合 GaLore 的标准做法(把矩阵较小的那一维压到 \(r\),按层形状选左投影或右投影),整个模型在每张 GPU 上的额外显存就只剩「一份 \(d_{model}\times d_{model}\) 的 DCT 矩阵 + 每层 \(r\) 个索引」。第 4 节还从理论上论证了这种按对齐度选列能最小化投影误差。
2. 用 DCT 当固定正交矩阵 + Makhoul FFT 算法加速对齐
为什么固定矩阵选 DCT 而不是随便一个正交阵?因为 DCT 不仅在信号/图像压缩里被证明能很好地集中能量,更关键的是它有结构、能用快速算法算。本文用 DCT-II/III,矩阵元素 \(Q_{ij}=\sqrt{2/n}\cdot\cos\frac{i(2j+1)\pi}{2n}\)(第一行除以 \(\sqrt 2\) 以保证 \(Q^\top Q=I_n\) 正交),训练前一次性在 GPU 上物化。
算相似度矩阵 \(S=GQ\) 本来要 \(O(n^3)\) 的稠密矩乘,但因为 \(Q\) 是 DCT,可以用 Makhoul 的 N-point 算法借助 FFT 在 \(O(n^2\log n)\) 时间内完成。对大层而言,这能把算对齐的开销加速 \(8\sim 50\times\)(在低端/老一代 GPU 上尤其明显,因为它们的 tensor-core 弱、FFT 收益更突出;新卡 tensor-core 太快时 matmul 反而不输)。这一设计让 Trion 的运行时与秩无关——无论 \(r\) 是 128 还是 512,选列开销几乎不变,而 Dion 因为 QR 的缘故运行时随 \(r\) 增长。
3. Trion:低秩动量 + Newton-Schulz,给 Dion 提速降存
Dion 用 Power-Iteration 求低秩投影、再用 QR 正交化,运行时随秩增长。Trion 把这两步换成「DCT 动态列选择 + 对低秩动量做 Newton-Schulz」。具体每步:累加动量 \(B_t=M_{t-1}+G_t\),算相似度 \(S_t=B_t D_C\),选列得索引 \(i_t\) 与投影矩阵 \(Q_t=D_C[:,i_t]\),从 \(S_t\) 直接抽出低秩动量 \(b_t=S_t[:,i_t]\);投影残差 \(\Delta_t=B_t-b_tQ_t^\top\) 通过 \(M_t=\mu B_t+(1-\mu)\Delta_t\) 以误差反馈形式累回动量,避免被丢弃的方向信息流失。
最关键的一步是:Newton-Schulz 只作用在 \(R\times r\) 的低秩动量 \(b_t\) 上,而非原始的 \(R\times C\) 全尺寸动量 \(B_t\),正交化后再 \(O_t=o_tQ_t^\top\) 升维回去更新参数。这把 Newton-Schulz 的运算从全尺寸矩阵降到 \(r\times r\),作者称这是首次用「动量的低秩近似」来降低 Newton-Schulz 的复杂度。在 DDP 下还能只在源 GPU 算更新、跨卡只通信低秩项 \(o_t\)(而非全尺寸正交矩阵),进一步省通信。
4. DCT-AdamW:用两组索引替掉两个投影矩阵
为了让动量 buffer 始终在「相同的低维子空间」里累加新梯度(这是 LDAdamW 保证一致性的做法),LDAdamW 必须逐层存两个相邻的投影矩阵来旋转动量,显存因此接近满秩 AdamW。DCT-AdamW 把这两个矩阵换成两组 \(r\) 个索引,配合可选的 8-bit 量化误差反馈(实测 8-bit 是不掉点的最低分辨率)和 ZeRO-redundancy 技巧(一层只在一张 GPU 上更新再广播、收方不分配优化器状态),从而在保持子空间一致性的同时把显存大幅压下来。
损失函数 / 训练策略¶
方法不改训练目标,只换优化器内部的投影机制。Trion 沿用 Dion 的最优超参(学习率 \(\eta=0.01\)、权重衰减 \(\lambda=0.01\)),参数更新带形状缩放 \(\theta_{t+1}=(1-\lambda\eta_t)\theta_t-\eta_t\max(1,\sqrt{R/C})O_t\)。秩取 \(r\in\{128,256,512\}\),对应秩比 \(r/d\in\{1/16,1/8,1/4,1/2\}\)。
实验关键数据¶
预训练在 C4 上从头训 Llama 350M/800M/1.3B,按 Chinchilla 最优(20 tokens/参数)、序列长 512、8×H100 DDP、全局 batch 512。
主实验:Trion vs Dion(PT,节选 rank 256)¶
| 模型 | 指标 | Trion | Dion | Muon(参考) |
|---|---|---|---|---|
| 350M | Val PPL | 15.30 | 15.64 | 14.99 |
| 350M | Memory(GB) | 42.42 | 45.59 | 42.42 |
| 350M | Runtime | 1h53m | 2h3m | 1h52m |
| 800M | Val PPL | 12.22 | 12.42 | 12.05 |
| 800M | Memory(GB) | 67.45 | 71.75 | 67.45 |
| 1.3B | Val PPL | 11.28 | 11.47 | 11.13 |
| 1.3B | Memory(GB) | 63.62 | 68.58 | 63.64 |
Trion 在所有模型/所有秩上训练与验证 loss/PPL 都低于 Dion,显存约低 10%,且运行时几乎不随秩变化;Dion 的运行时则随秩明显增长。Trion 相对 Dion 提速:rank 128 约 2.5–4.5%、rank 256 约 4.5–9%、rank 512 约 8–18%(秩越大优势越大)。
DCT-AdamW vs LDAdamW(Llama-800M,100 tokens/参数)¶
| 优化器 | Val PPL | Mem(GiB) | Time |
|---|---|---|---|
| AdamW(满秩参考) | 11.73 | 73.72 | 1d13h22m |
| LDAdamW | 13.91 | 72.10 | 2d1h24m |
| DCT-AdamW | 13.69 | 57.82 | 1d15h17m |
DCT-AdamW 比 LDAdamW PPL 更低、显存从 72.10 降到 57.82 GiB、运行时快约 10h7m(≈25.75%);相比满秩 AdamW 仅慢约 5%。LDAdamW 因要存两个投影矩阵,显存几乎和 AdamW 一样,而 DCT-AdamW 只存两组索引。
关键发现¶
- 投影更准是性能更好的根因:作者直接测累加器 \(B_t\) 与各优化器实际更新之间的 \(\ell_2\) 投影误差(Llama-30M 首个 transformer block 各层),Trion 的投影误差持续低于 Dion,解释了为何 loss 更低。
- 运行时与秩解耦是 Trion 相对 Dion 最大的工程优势,规模越大、秩越高,收益越显著。
- 省显存来自"存索引不存矩阵":把逐层稠密投影矩阵换成 \(r\) 个整数,这是 10%(Trion)到 25%(DCT-AdamW)显存下降的来源。
- 作者也把 DCT 投影插进 FRUGAL/FIRA 替掉 SVD(Llama-800M / 16B tokens,附录 I),验证该子例程的可移植性。
亮点与洞察¶
- 「固定基 + 选列」替「逐层分解」:最巧的地方是认识到投影矩阵不必逐层从头算——一个万能 DCT 矩阵 + 动态挑列就能逼近 SVD/QR,且把存储从矩阵降到索引。这种「预定义正交基 + 数据相关选择」的思路可迁移到任何需要逐层低秩投影的场景。
- 借 JPEG 的经验到优化器:DCT 在图像压缩里是能量集中的老将,作者首次把它引入低秩自适应梯度优化,并用 Makhoul FFT 算法把对齐计算降到 \(O(n^2\log n)\)。
- 只对低秩动量做 Newton-Schulz:把昂贵的正交化从全尺寸矩阵搬到 \(r\times r\),是给 Muon/Dion 一类正交化优化器提速的通用 trick。
- DDP 下只通信低秩项:跨卡只传 \(o_t\) 再本地 \(O_t=o_tQ_t^\top\),因为 DCT 矩阵已在每卡复制,省通信带宽。
局限与展望¶
- 作者指出,本文实验模型的 embedding 维 \(d\) 偏小,Makhoul FFT 相对 matmul 的运行时优势在 H100 这类新卡上看不太出来——FFT 的收益主要在老一代/低端 GPU 或更大的 \(d\) 上才显著(附录 D/E)。这意味着「FFT 加速」这个卖点在主流大卡上当前更多是理论优势。
- 评测集中在 PT 的 PPL 与 FT 准确率(FT 在附录 K),缺少下游任务大规模验证。
- 用固定 DCT 基本质上限制了子空间的「字典」——当梯度的主方向恰好与所有 DCT 基都不太对齐时,选列逼近 SVD 的质量可能下降,论文未深入刻画这种最坏情况。
- FSDP 实现需要根据每层是左/右投影来决定分片方式,工程上不如 DDP 直接。
相关工作与启发¶
- vs GaLore / LDAdamW(SVD 系):它们逐层做 SVD 求投影并显式存矩阵;本文用固定 DCT + 选列逼近,存储降到索引,免去每步/每几步的分解开销,DCT-AdamW 直接对标 LDAdamW 并省 25% 显存。
- vs Dion(QR + Power-Iteration):Dion 运行时随秩增长且逐层存投影矩阵;Trion 用 DCT 选列(秩无关)+ 仅对低秩动量做 Newton-Schulz,loss 更低、显存更省、运行时几乎不随秩变。
- vs Muon(Newton-Schulz 满尺寸正交化):Muon 要在 GPU 上物化全尺寸矩阵做迭代,难并行;Trion 先低秩再正交化,把 Newton-Schulz 降到 \(r\times r\),作者称是首次用低秩近似降低 Newton-Schulz 复杂度。
评分¶
- 新颖性: ⭐⭐⭐⭐ 首次把 DCT/FFT 的「固定基 + 动态选列」引入低秩优化器状态压缩,思路简洁且可移植。
- 实验充分度: ⭐⭐⭐⭐ 覆盖 3 个 Llama 规模、多种秩、Trion/DCT-AdamW 两条线 + 投影误差分析,但模型偏小、缺下游任务。
- 写作质量: ⭐⭐⭐⭐ 动机清晰、算法伪代码完整,FFT 收益的适用条件也坦诚交代。
- 价值: ⭐⭐⭐⭐ 即插即用替换 SVD/QR,显存降 10–25%、运行时与秩解耦,对大规模高效训练实用。