Multi-Head Low-Rank Attention (MLRA)¶
会议: ICLR 2026
arXiv: 2603.02188
代码: GitHub / HuggingFace
领域: 自动驾驶
关键词: KV Cache, 张量并行, 低秩注意力, 解码效率, Multi-Head Latent Attention
一句话总结¶
提出 Multi-Head Low-Rank Attention (MLRA),通过将 MLA 的单一 latent head 分解为多个可独立分片的 latent head,并对各分支注意力输出求和,实现原生 4-way 张量并行支持,在保持 SOTA 性能的同时获得 2.8× 的解码加速。
研究背景与动机¶
长上下文推理瓶颈:LLM 推理时,解码阶段需要在每一步从 HBM(高带宽内存)向 SRAM 反复传输 KV cache,数据搬运(而非计算)成为长上下文推理的主要延迟来源。
MLA 的成功与局限:DeepSeek 提出的 Multi-Head Latent Attention (MLA) 通过将 KV cache 压缩为单一 latent head(每 token 仅需 \(4.5 d_h\)),显著减少了 KV cache 总量。但其单一 latent head 不可分片——在张量并行(TP)解码时,每个设备被迫冗余加载完整 KV cache。
TP 下 MLA 的性能停滞:无论 TP 程度如何,MLA 的每设备 KV cache 加载量始终固定在 \(4.5 d_h\),TP 带来的权重分片优势被完全抵消。
GLA-2 的部分解决:GLA-2 将 latent head 二等分为两个较小的 latent head,在 2-way TP 下降低为 \(2.5 d_h\),但 TP > 2 时同样无法进一步降低。
算术强度的重要性:解码效率的关键指标是算术强度(FLOPs/byte),需要在减少内存加载的同时保持高计算密度,从而将工作负载从内存受限推向计算受限。
方差不匹配问题:MLA 中 RoPE key 与 NoPE key 存在显著的方差不匹配(\(\text{Var}(K^{\text{RoPE}}) / \text{Var}(K^{\text{NoPE}}) \approx d/d_c\)),当 latent 维度远小于隐藏维度时问题尤为突出。
方法详解¶
整体框架¶
这篇论文要解决的痛点是:MLA 把 KV cache 压成单一 latent head,总量虽小,却无法分片——做张量并行(TP)解码时,每个设备都被迫冗余加载完整的 \(4.5 d_h\) KV cache,TP 本该带来的权重分片红利被抵消,加载量随设备数停滞不前。MLRA 的破局思路是把这个不可分的 latent head 显式拆成多个独立的 latent block,让每块各自 up-project 成 NoPE KV、各跑一路分支注意力,最后把各分支输出求和;分支之间互不耦合,于是天然能一一映射到多个 TP 设备。
前向一条线走下来是:隐藏状态 \(H\) 先下投影得到 latent \(C^Q\)、\(C^{KV}\) 和共享的 RoPE key → 把 \(C^{KV}\) 沿通道切成多块、各块独立 up-project 后分路算注意力 → 各分支输出按分支数归一后求和得到 \(O\)。其中 latent 侧的缩放与输出侧的归一一起构成方差校准,保证拆分后 softmax logits 不失衡。论文给两个变体:MLRA-2 用 2 个 latent block、每块服务半数注意力头,产生 2 分支求和(适合 2-way TP);MLRA-4 用 4 个 latent block、每块服务全部注意力头,产生 4 分支求和,原生支持 4-way TP。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
H["隐藏状态 $H$<br/>下投影得 latent $C^Q$,$C^{KV}$ 与 RoPE key"]
SPLIT["块分解:$C^{KV}$ 沿通道切成多块<br/>各块独立 up-proj 成 NoPE KV,分路算分支注意力"]
CAL["方差校准:latent 侧按 $\sqrt{d/d_c}$ 缩放<br/>各分支输出按分支数归一后求和 → $O$"]
DEC["高效解码:weight absorption 后<br/>各 latent block 映射一个 TP 设备,每设备只搬 $1.5 d_h$,再 all-reduce"]
H --> SPLIT
SPLIT -->|"MLRA-2: 2 分支 / MLRA-4: 4 分支"| CAL
CAL --> DEC
关键设计¶
1. 块分解:把不可分片的求和拆成可独立计算的分支
MLA 无法分片,根因是它把整段 KV 信息压在一个 latent head 里,解码时只能整块搬运。MLRA 的破局点在于一个被忽视的代数恒等式:MLA 里每个头的 NoPE key/value 其实等价于若干子块乘积之和(原文 Eq. 2)。既然结果是一个求和,就能把这个求和从「KV 计算阶段」搬到「注意力输出阶段」。具体做法是把 KV latent 矩阵 \(C^{KV} \in \mathbb{R}^{n \times d_c}\) 沿通道切成块 \(C_{:,(b)}^{KV}\),对应地把 up-projection 矩阵 \(W^{UK}, W^{UV}\) 按行切成子块。这样每个头的输出就变成各分支注意力之和(以 MLRA-4 的 4 块为例):
关键在于:每个分支 \(b\) 只依赖自己那块 \(C_{:,(b)}^{KV}\)(大小 \(d_h\)),分支之间互不耦合,于是各分支可以原封不动地分配给不同 TP 设备各算各的,最后求和——这正是 MLA 做不到的事。同一套拆分思路按分支数衍生出两个变体:MLRA-4 用 4 块、每块经 up-projection 服务全部 \(h\) 个头,产生 4 分支求和;MLRA-2 沿用 GLA-2 的分组思路把 latent head 二等分,由分组映射函数 \(\gamma(i)\) 决定哪些头用第一组 latent、哪些用第二组,每块只服务 \(h/2\) 个头,产生 2 分支求和,在更轻的拆分开销下适配 2-way TP。
2. 方差校准:拆分会放大 RoPE/NoPE 的方差错位,得显式补偿
分支拆开不是免费的——它会加剧 MLA 本就存在的方差不匹配。理论上 NoPE key 的方差是 \(d_c \sigma_w^2\)、RoPE key 的方差是 \(d \sigma_w^2\),两者之比约为 \(d/d_c\),当 latent 维度远小于隐藏维度时已差一个量级;而把 latent 再切成多块、再把多分支输出加起来,又一次改变了输出的方差尺度,若放任不管 softmax 的 logits 分布就会失衡。MLRA 的对策是两处显式缩放:在 latent state 侧用 \(\alpha_q = \sqrt{d/d_c'}\)、\(\alpha_{kv} = \sqrt{d/d_c}\) 把 query/NoPE key 的方差拉回与 RoPE key 可比的尺度;在输出侧再按分支数归一,MLRA-2 的输出乘 \(1/\sqrt{2}\)、MLRA-4 乘 \(1/2\),抵消多分支求和带来的方差膨胀。这套缩放是从方差推导直接得来、而非凭经验调参;不过它依赖权重 i.i.d. 的 Assumption 1,训练中该假设并不严格成立,因此最终以消融验证为准。
3. 高效解码:各块天然落到不同设备,每设备只搬 \(1.5 d_h\)
前两步的设计最终要兑现成解码时的真实加速。因为各 latent block 彼此独立,它们可以自然地一一映射到 TP 设备上——以 MLRA-4 的 4-way TP 为例,每个设备只加载一个 latent block(\(d_h\))外加各设备共享的 RoPE key(\(0.5 d_h\)),于是每设备 KV 加载量从 MLA 的 \(4.5 d_h\) 降到 \(1.5 d_h\)(GLA-2 在 2-way TP 下也只能降到 \(2.5 d_h\),且 TP>2 后停滞)。落地时沿用 MLA 的 weight absorption 技巧,把 up-projection 吸收进 query 侧,每个设备就退化成一个标准的 MQA-style 解码,算完各分支后再 all-reduce 求和即可。这样 TP 带来的好处第一次真正传导到了 KV cache 搬运上,把解码从内存受限推向计算受限。
损失函数 / 训练策略¶
- 初始化:输出投影参数使用零初始化(优于标准 \(\mathcal{N}(0, 0.02)\)),其余参数标准初始化
- 优化器:AdamW,\((\beta_1, \beta_2)=(0.9, 0.95)\),weight decay 0.1,梯度裁剪 1.0
- 学习率:peak \(1.6 \times 10^{-4}\),前 2000 步线性 warmup,之后余弦退火至 10%
- 训练规模:2.9B 参数,FineWeb-Edu-100B 数据集,98.3B token,上下文长度 2048,8 × H100 GPU
- 可选门控机制:在注意力输出投影前添加 gating,可进一步降低困惑度
实验关键数据¶
主实验¶
验证集困惑度(7 个数据集平均,越低越好):
| 方法 | Wikipedia | C4 | Pile | RefinedWeb | Cosmopedia | FineWeb | FineWeb-Edu | 平均 |
|---|---|---|---|---|---|---|---|---|
| MHA | 14.624 | 16.575 | 12.929 | 18.698 | 9.102 | 15.656 | 9.434 | 13.860 |
| GQA | 15.057 | 16.628 | 13.758 | 18.885 | 9.504 | 15.713 | 9.427 | 14.139 |
| MLA | 14.567 | 16.345 | 12.965 | 18.523 | 8.966 | 15.440 | 9.284 | 13.727 |
| GLA-2 | 14.605 | 16.323 | 13.225 | 18.509 | 9.118 | 15.424 | 9.249 | 13.779 |
| MLRA-4 | 14.407 | 16.286 | 13.124 | 18.398 | 8.937 | 15.361 | 9.193 | 13.672 |
零样本常识推理准确率(%):
| 方法 | ARC-E | ARC-C | OBQA | BoolQ | HellaSwag | Winogrande | PIQA | 平均 |
|---|---|---|---|---|---|---|---|---|
| MHA | 69.11 | 39.16 | 40.80 | 62.26 | 60.82 | 57.62 | 74.86 | 57.81 |
| GQA | 67.13 | 39.42 | 42.00 | 63.39 | 61.29 | 56.91 | 75.08 | 57.89 |
| MLA | 68.22 | 39.16 | 42.60 | 64.10 | 61.39 | 60.06 | 75.68 | 58.75 |
| MLRA-4 | 67.63 | 41.38 | 43.00 | 61.74 | 62.16 | 61.48 | 74.48 | 58.84 |
消融实验¶
门控机制对困惑度的影响(7 个数据集平均):
| 方法 | 无门控 | 有门控 | 提升 |
|---|---|---|---|
| GQA | 14.139 | 13.806 | -0.333 |
| MLA | 13.727 | 13.642 | -0.085 |
| GLA-2 | 13.779 | 13.701 | -0.078 |
| MLRA-2 | 13.804 | 13.651 | -0.153 |
| MLRA-4 | 13.672 | 13.621 | -0.051 |
其他消融发现: - 零初始化 vs \(\mathcal{N}(0, 0.02)\):零初始化在所有模型上均优于随机初始化 - 方差缩放:MLA 和 GLA-2 收益显著,MLRA-2 收益边际(因分支已天然减缓方差不匹配) - 加倍注意力头数:在参数总量不变的条件下加倍 GQA/MLA/GLA-2 的头数,均未带来改善反而损害性能
关键发现¶
- MLRA-4 全面最优:在困惑度(13.672)和零样本推理准确率(58.84%)上均超越所有基线,包括 MLA
- 2.8× 解码加速:相比 MLA,MLRA-4 在 128K-2M token 长上下文解码中实现稳定 2.8× 加速
- 1.05-1.26× 相比 GQA:在长上下文解码中超越 GQA,且差距随上下文长度增加而扩大
- TP=4 即可达到 1.5 \(d_h\):GQA 和 GTA 需要 8-way TP 才能达到类似的每设备 KV 加载量,MLRA 仅需 4-way
- 门控进一步提升:加入门控后 MLRA-4 困惑度降至 13.621,仍保持最优
亮点与洞察¶
- 优雅的数学动机:从 MLA 的块分解出发,将 KV 层面的求和移至注意力输出层面求和,数学上简洁且直觉清晰
- 方差分析扎实:从理论上推导了各组件的方差,给出了明确的校准策略,避免了凭经验调参
- 实用性极强:MLRA 与 MLA 共享相同的 KV cache 总量(每 token \(4.5 d_h\)),区别仅在于分布式解码时是否可分片,对现有 MLA 系统的迁移成本低
- 算术强度分析:MLRA-4 的算术强度约为 \(2h\)(MLA 也是 \(2h\)),说明在减少内存加载的同时不牺牲计算效率
- 完整的工程实现:基于 FlashAttention-3 实现 MLRA-4 kernel,并在真实 H100 集群上验证
局限与展望¶
- 仅在 2.9B 规模验证:未在 7B+ 或更大规模上实验,大模型上 MLRA 的优势是否保持有待验证
- 预训练数据单一:仅使用 FineWeb-Edu-100B,未在多语言或代码混合数据上验证
- Assumption 1 的局限:方差分析假设权重 i.i.d. 分布且与输入独立,训练过程中此假设不严格成立
- 固定 4-way 分解:MLRA-4 绑定了 4-way TP,对于 2-way 或 8-way TP 场景需要分别使用 MLRA-2 或进一步扩展
- 多分支求和的近似性:将求和从 softmax 内部移至外部本质上改变了注意力分布,虽然实验效果好但理论上并非等价变换
- 未评估指令微调和对齐后的效果:仅关注预训练,未覆盖 SFT/RLHF 阶段
相关工作与启发¶
- MLA (DeepSeek-V2/V3):MLRA 的直接前驱,通过 latent compression 压缩 KV cache,但不支持 TP 分片
- GQA:Grouped Query Attention,通过减少 KV head 数实现效率提升,但 KV cache 仍随头数线性增长
- GLA-2 (Zadouri et al., 2025):首个尝试拆分 MLA latent head 的工作,但仅支持 2-way TP
- TPA (Zhang et al., 2025):Tensor Product Attention,用共享 head 的线性组合构造 KV,TP 支持有限
- FlashMLA / FlashAttention-3:高效注意力 kernel,MLRA 的 kernel 基于 FlashAttention-3 实现
- LongCat (2025):首次观察到 RoPE key 的方差不匹配问题,MLRA 沿用并扩展了其缩放策略
- 启发:这种"将内部求和外提为多分支独立计算"的思路可能适用于其他需要 latent compression + TP 的场景(如 KV cache 量化、稀疏注意力等)
评分¶
- 新颖性: ⭐⭐⭐⭐ — 洞察精准,将 MLA 的"不可分片"问题通过分支分解优雅解决
- 实验充分度: ⭐⭐⭐⭐ — 多组消融+多数据集评估+解码速度/吞吐量测试,但仅 2.9B 规模
- 写作质量: ⭐⭐⭐⭐⭐ — 数学推导严谨,符号清晰,从背景到方法的逻辑链条完整
- 价值: ⭐⭐⭐⭐ — 直接解决 MLA 部署的实际痛点,对 DeepSeek 系列和大规模推理部署有重要实用价值