RaBiT: Residual-Aware Binarization Training for Accurate and Efficient LLMs¶
会议: ICML 2026
arXiv: 2602.05367
代码: 待确认
领域: 模型压缩 / LLM 量化 / 二值化
关键词: 残差二值化, 量化感知训练, LLM, 路径协同适应, matmul-free 推理
一句话总结¶
本文针对残差二值化 LLM 中"并行二值路径学到冗余特征"这一被作者命名为 inter-path adaptation 的失败模式,提出 RaBiT——用单一共享的全精度权重在线派生所有二值路径并配合函数感知初始化,从而结构性地强制残差层级,使 2-bit Llama2-7B 在 matmul-free 架构下首次反超 VQ 强基线(Wiki2 PPL 5.78 vs QTIP 5.86),同时获得 4.49× 推理加速。
研究背景与动机¶
领域现状:LLM 部署到极致压缩比时,4-bit 量化(GPTQ、AWQ)已经成为工业标准,但前沿正在向 2-bit 推进。2-bit 区域有两条主要路线:(i) 向量量化 VQ(AQLM、QuIP#、QTIP)通过查表或复杂旋转保留较高精度,但硬件开销大;(ii) 残差二值化通过堆叠多个 \(\{\pm1\}\) 二值层,天然支持 matmul-free(只用加减)的极致高效执行。残差二值化的核心承诺是"后续路径补偿前面路径的误差",从而以二值的代价获得接近多比特的表达力。
现有痛点:尽管残差结构看起来很美,但在 QAT 训练中始终不稳定。作者深入分析发现,标准 QAT 把同一个全局梯度同时作用到所有并行路径,这会驱使每条路径在"竞速降低同一全局损失"中学到几乎一样的特征——即 Hinton 2012 命名的"feature co-adaptation"在残差二值化中的具体表现,作者称为 inter-path adaptation。结果是误差补偿层级被破坏,模型表达力被严重削弱。
核心矛盾:MSE 分解告诉我们路径间必须负相关、第二条路径必须主动对齐第一条路径的残差,模型才能真正发挥多路径的容量;但标准 QAT 的对称结构和共享梯度让路径几乎独立、关联接近零,于是堆叠多路径只是徒增参数,没有起到误差补偿的作用。以往工作(DB-LLM、MBOK)依赖启发式约束(路径冻结、机械分裂)来打破这种对称,但要么牺牲了联合优化空间,要么虽然制造了负相关但残差对齐很差。
本文目标:(i) 给出 inter-path adaptation 的形式化诊断指标;(ii) 在算法层面而非启发式层面把残差层级写进训练循环;(iii) 解决 2-bit QAT 初始化对最终精度的强敏感性。
切入角度:既然问题根源是"两条路径各自维护独立的潜变量权重 + 共享全局梯度",那就反过来——只保留一个全精度权重 \(\mathbf{W}_{\mathrm{FP}}\) 作为锚点,每个 step 现场从它派生出第一条路径与残差,再从残差派生第二条路径,让"第二条路径补偿第一条"成为图结构上的硬约束而不是 loss 上的软鼓励。
核心 idea:用一个共享全精度权重在线串联派生所有二值路径(耦合前向),让残差层级在每个 step 都被自动重建;再用 Iterative Residual SVID + I/O 通道重要性预条件化提供一个"保功能而非保权重"的稳定初始化。
方法详解¶
整体框架¶
RaBiT 用两类设计组合解决 2-bit 残差二值化的训练病:(a) 训练时不再为每条路径单独维护潜变量权重,而是只保留一个共享 \(\mathbf{W}_{\mathrm{FP}}\),每个前向 step 现场派生 \(\mathbf{B}_1=\text{sign}(\mathbf{W}_{\mathrm{FP}})\)、\(\mathbf{R}_1=\mathbf{W}_{\mathrm{FP}}-\hat{\mathbf{W}}_1\)、\(\mathbf{B}_2=\text{sign}(\mathbf{R}_1)\);(b) 每条路径仍保留独立的可学习逐通道缩放 \(\{\mathbf{g}_i,\mathbf{h}_i\}\) 以保留容量;(c) 用 Iterative Residual SVID + 输入/输出通道重要性预条件化做函数感知初始化;(d) 推理时把训练好的 \(\mathbf{B}_i\) 冻结、丢弃 \(\mathbf{W}_{\mathrm{FP}}\),回到原来的并行 matmul-free 架构。
二值基本块写作 \(\hat{\mathbf{W}}=\mathbf{g}\odot\mathbf{B}\odot\mathbf{h}\),其中 \(\mathbf{B}\in\{-1,+1\}^{d_{\text{out}}\times d_{\text{in}}}\)、\(\mathbf{g}\in\mathbb{R}^{d_{\text{out}}}\)、\(\mathbf{h}\in\mathbb{R}^{d_{\text{in}}}\),矩阵-向量乘 \(\mathbf{y}=\mathbf{g}\odot(\mathbf{B}(\mathbf{h}\odot\mathbf{x}))\) 只用加减实现。2-bit 时堆叠 \(k=2\) 条这样的二值块求和。
关键设计¶
-
共享 FP 权重 + 在线派生的耦合前向 (Coupled Forward Pass):
- 功能:把"残差补偿"从损失上的偏好变成图结构上的硬约束,从根本上消除 inter-path adaptation。
- 核心思路:训练阶段只存 \(\mathbf{W}_{\mathrm{FP}}\),每个 step 三步走——(1) \(\mathbf{B}_1=\text{sign}(\mathbf{W}_{\mathrm{FP}})\),组合得到 \(\hat{\mathbf{W}}_1=\mathbf{g}_1\odot\mathbf{B}_1\odot\mathbf{h}_1\);(2) 残差 \(\mathbf{R}_1=\mathbf{W}_{\mathrm{FP}}-\hat{\mathbf{W}}_1\);(3) \(\mathbf{B}_2=\text{sign}(\mathbf{R}_1)\)。最终有效权重 \(\hat{\mathbf{W}}^{(2)}=\hat{\mathbf{W}}_1+\hat{\mathbf{W}}_2\)。反向用一个 STE 把 \(\nabla_{\hat{\mathbf{W}}^{(2)}}\mathcal{L}=(\partial\mathcal{L}/\partial\mathbf{Y})\mathbf{X}^{\top}\) 直接灌给 \(\mathbf{W}_{\mathrm{FP}}\),缩放向量 \(\{\mathbf{g}_i,\mathbf{h}_i\}\) 按常规链式求导更新且把动态派生的 \(\mathbf{B}_i\) 视为常数。
- 设计动机:以往做法用两套独立潜权重让两条路径独立优化,结构上根本不区分"主路径"与"补偿路径",所以梯度一锤同抡时它们必然学到冗余特征;本文把第二条路径写成第一条路径残差的函数,即使梯度同抡,结构本身也强制 \(\mathbf{B}_2\) 永远在追 \(\mathbf{R}_1\)。一个意外的副产物:只存一套全精度权重直接把优化器状态(如 Adam 的动量/方差)减半,节省了 LLM 微调中最稀缺的显存。
-
函数感知初始化 (Iterative Residual SVID + I/O Importance Preconditioning):
- 功能:解决 2-bit QAT 起点对最终精度的强敏感性,把"先准确重建权重"换成"先准确重建功能输出"。
- 核心思路:先用基于校准集的输入激活幅度 \(\mathbf{s}_{\text{in}}\) 和输出梯度幅度 \(\mathbf{s}_{\text{out}}\) 对原始权重做预条件化 \(\mathbf{W}'=\mathbf{s}_{\text{out}}^{\alpha_{\text{out}}}\odot\mathbf{W}_{\mathrm{FP}}\odot\mathbf{s}_{\text{in}}^{\alpha_{\text{in}}}\),把分解资源集中到功能上敏感的通道(Fisher / K-FAC 的局部敏感性直觉)。随后用 Iterative Residual SVID 以 Gauss-Seidel 风格在 \(T\) 轮里迭代刷新每条路径的 \((\mathbf{B}_i,\mathbf{g}_i,\mathbf{h}_i)\):每轮把"其他路径已经吃掉的部分"从 \(\mathbf{W}'\) 里减去,再用 SVID(基于秩-1 SVD 的幅度分解)拟合剩余残差;最后把缩放映射回原始域。
- 设计动机:标准 SVID 是贪心的——第一条路径独占最优拟合,会把残差结构推到很差的局部极小,后续路径再怎么救都救不回来。SVID 的"残差迭代 + 通道重要性预条件化"两步分别解决了路径间贪心耦合和"等权拟合所有通道但实际只有少数通道功能上重要"两个互相加剧的初始化病。Table 7 / Figure 5 显示两者各自贡献明显,组合后起点 loss 最低、QAT 启动期最稳。
-
MSE 分解给出的可诊断指标 (Inter-Path Adaptation Diagnostic):
- 功能:给"路径协同适应"提供量化的诊断量,并指明 RaBiT 修复了哪一项。
- 核心思路:对 2-bit 残差网络 \(y_s=y_1+y_2\),将 MSE 展开为 \(\text{MSE}(y_t,y_s)=C'+2\sigma_1\sigma_2\cdot\text{Corr}(y_1,y_2)\),并补一个等价视角 \(\text{MSE}\approx\sigma_{R_1}^2+\sigma_{y_2}^2-2\sigma_{R_1}\sigma_{y_2}\cdot\text{Corr}(R_1,y_2)\),其中 \(R_1=y_t-y_1\) 是第一条路径的功能残差。两个相关性给出"路径对路径相关 Corr\((y_1,y_2)\) 应该足够负"和"残差对齐 Corr\((R_1,y_2)\) 应该足够正"两条独立判据。
- 设计动机:以往工作只看最终 PPL,看不出"为什么没起到补偿效果"。这套指标直接揭示了:标准 QAT 路径相关接近 0(根本没产生补偿);DB-LLM 机械负相关 -0.49 但残差对齐仅 0.26(互相抵消而非追误差);MBOK 略有改善仍偏弱;只有 RaBiT 同时取得高负相关(≈-0.35 到 -0.50)和高残差对齐(0.58–0.65),从机理上证明耦合训练真的让第二条路径在追功能残差。
损失函数 / 训练策略¶
总损失 \(\mathcal{L}_{\text{total}}=\mathcal{L}_{\text{kl}}+\gamma\sum_i\mathcal{L}_{\text{inter},i}\),KL 散度蒸馏 + 中间层 MSE 蒸馏(Llama 取 \(\gamma=100\);Gemma3 因激活幅度大取 \(\gamma=0\))。在 WikiText-2 + C4 的 2 亿 token 校准集上用 Muon 优化器训 6 个 epoch,上下文 4096。论文附录 B 把 MSE 分析中的最优性扩展到 KL 目标。
实验关键数据¶
主实验¶
在 Llama2-7B/13B、Llama3-8B、Gemma3-1B/4B/12B 上对比 SOTA 2-bit 方法。
| 模型 / 数据 | 指标 | RaBiT (2-bit) | 之前最佳 (2-bit) | 全精度基线 (16-bit) |
|---|---|---|---|---|
| Llama2-7B Wiki2 | PPL ↓ | 5.78 | QTIP 5.86 / DBF 6.10 / MBOK 6.99 | 5.12 |
| Llama2-7B QA Avg | Acc ↑ | 61.51 | QTIP 58.97 / DBF 58.42 | 62.26 |
| Llama2-13B Wiki2 | PPL ↓ | 5.15 | QTIP 5.11(仅次) | 4.57 |
| Llama3-8B Wiki2 | PPL ↓ | 7.34 | QTIP 7.52 / QuIP# 8.70 / BitStack 2.75e3(崩) | 5.75 |
| Llama3-8B QA Avg | Acc ↑ | 64.13 | AQLM 64.12 / QTIP 63.88 | 68.66 |
| Gemma3-1B Wiki2 | PPL ↓ | 11.27 | QTIP 13.14 / DBF 13.28 | 9.80 |
| Llama2-13B 难任务平均 (BBH+GPQA+MMLU-Pro+IFEval) | Acc ↑ | 27.14 | QTIP 25.38 | 29.27 |
| Llama2-7B 端到端解码加速 | Speedup ↑ | 4.49× vs FP16 | — | 1.00× |
消融实验¶
| 配置 | Llama2-7B Wiki2 PPL ↓ | 说明 |
|---|---|---|
| Standard QAT(独立潜权重) | 6.55 | 基线,inter-path adaptation 严重 |
| Standard QAT + Iterative SVID 初始化 | 6.21 | 仅换初始化也有收益 |
| Standard QAT + I/O 重要性预条件化 | 6.31 | 单独的功能感知预条件化 |
| Standard QAT + 两者组合初始化 | 6.18 | 两者协同 |
| Coupled QAT(仅本文耦合前向) | 5.84 | 解决 inter-path adaptation 是主要收益 |
| Coupled QAT + SVID | 5.80 | |
| Coupled QAT + 预条件化 | 5.81 | |
| RaBiT(全套) | 5.78 | 完整方案 |
Table 1 的 MSE 分解给出机理验证:在 Llama2-7B 第 5 / 15 / 25 层,RaBiT 的残差对齐 Corr\((R_1,y_2)\) 分别达到 0.65 / 0.58 / 0.62,显著高于 DB-LLM 的 0.26 / 0.25 / 0.25,证明残差层级真的被恢复了。
关键发现¶
- 耦合训练贡献最大:从 Standard QAT 6.55 到 Coupled QAT 5.84 只换前向结构就拿到 0.71 PPL,远超初始化单独带来的 0.34;说明 inter-path adaptation 是 2-bit 残差架构的首要瓶颈,而非初始化。
- 两条改动协同:耦合 + 函数感知初始化各自贡献 ≈0.7 / 0.4 PPL,组合后到 5.78,没有明显的边际递减;说明二者互补,前者优化结构,后者优化起点。
- 反超 VQ:Llama2-7B 上 RaBiT 5.78 PPL 微优于 QTIP 5.86,且保留 matmul-free,在 RTX 4090 上 4.49× 推理加速;在 Llama3-8B 上更明显(7.34 vs 8.70),QuIP# 等 VQ 方法在新架构上明显退化。
- 训练显存减半:只维护一个 \(\mathbf{W}_{\mathrm{FP}}\) 直接把优化器状态从 2 路减到 1 路,副作用是 QAT 显存压力大幅缓解。
亮点与洞察¶
- 把"残差补偿"从损失偏好升级为图结构强约束,这是非常优雅的范式切换——很多并行多路径架构(不只是二值化)都受困于路径冗余,"链式派生 + 共享锚点"的思路完全可以迁移到 MoE 路由、多分支蒸馏、低秩残差适配等场景。
- MSE 分解给出的双指标(Corr\((y_1,y_2)\) 与 Corr\((R_1,y_2)\))非常有解释力,揭示了 DB-LLM 那种"机械负相关"是假补偿,真正的补偿必须看残差对齐——这条诊断方法值得任何做残差/集成模型的人借鉴。
- 把缩放 \(\{\mathbf{g}_i,\mathbf{h}_i\}\) 留作独立可学习参数而不是每步重算 SVD,是一个"结构上严格 + 优化上松弛"的精妙折中:算法保证误差层级,优化器仍能用 momentum 等状态对幅度做细调,免去 SVD 的高额计算。
局限与展望¶
- 难任务上仍有显著差距:BBH/GPQA/MMLU-Pro/IFEval 平均 27.14 vs 全精度 29.27 在 Llama2-13B 上还可以,但 Llama3-8B 的 25.12 vs 31.03 差距更明显,说明 2-bit 在复杂推理上仍未完全过关。
- IFEval 是个不一致的弱项:Llama3-8B 上 RaBiT 15.42 比 QTIP 15.60 还略低且远低于基线 32.51,说明指令遵循类的"格式敏感性"在二值化下损失严重,需要更针对性的训练。
- Llama2-13B 上 RaBiT 5.15 PPL 略输 QTIP 5.11,作者归因为模型越大 VQ 越占便宜,但没有提出针对大模型的进一步改进。
- 共享权重训练时每步要重做派生,单 step 计算量增加;论文没有详细给出训练吞吐量与标准 QAT 的对比,工程实现上的额外开销在长训练 schedule 下值得关注。
- 框架只验证了 \(k=2\)(即 2-bit),\(k\ge 3\) 时链式派生的数值稳定性、梯度信号是否依然清晰未做实验。
相关工作与启发¶
- vs DB-LLM [Chen 2024]: 用启发式分裂强制路径负相关,但残差对齐很差;本文揭示了"负相关 ≠ 真补偿",残差对齐才是关键指标,并用结构约束让两者同时满足。
- vs MBOK [Tran & Nguyen 2025]: 用路径冻结避免协同适应,本质上是限制了联合优化空间;本文允许全部参数联合优化,靠图结构而非冻结来约束行为。
- vs DBF [Boža & Macko 2025]: 也强调函数保留,本文进一步把这一思想落到初始化的"I/O 重要性预条件化"上,并且和耦合训练正交可叠加。
- vs QTIP / QuIP# / AQLM 等 VQ: 这些方法用查表换精度,硬件不友好;本文证明 matmul-free 残差二值化在精度上也能追平甚至反超,是对"2-bit 必须靠 VQ"这一行业默认假设的有力反驳。
评分¶
- 新颖性: ⭐⭐⭐⭐ "用结构约束代替启发式"+"用残差对齐而非路径相关诊断"两点都很新。
- 实验充分度: ⭐⭐⭐⭐⭐ 覆盖 Llama2/3、Gemma3 共 6 个模型,含 PPL、QA、BBH/GPQA/MMLU-Pro/IFEval、推理速度、MSE 分解、消融。
- 写作质量: ⭐⭐⭐⭐⭐ 问题诊断、机理证明、消融对齐做得非常工整,是 LLM 量化少见的"机理论文"。
- 价值: ⭐⭐⭐⭐⭐ 首次让 matmul-free 2-bit 反超 VQ 强基线并保持 4.49× 加速,是 2-bit 部署落地的重要里程碑。