跳转至

Alignment-Enhanced Integration of Connectivity and Spectral Sparsity in Dynamic Sparse Training of LLM

会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=jZplmg7Ad9
代码: 补充材料中提供(未公开仓库链接)
领域: 模型压缩 / 参数高效预训练
关键词: 动态稀疏训练, 低秩分解, 连接稀疏, 谱稀疏, 抵消效应, 对齐损失

一句话总结

本文首次把动态连接稀疏(CHTs)与动态低秩谱稀疏真正融合到一个统一的稀疏预训练框架里,发现两个分支朴素相加会产生输出"互相抵消"的现象,并用一个简单的对齐损失把它们拉到同一方向协作,得到的 CHTsL 在 LLaMA-60M/130M 上仅保留 10%~30% 参数即逼近 dense 训练。

研究背景与动机

领域现状:从零训练 LLM 极度消耗显存与算力,参数高效稀疏预训练应运而生。它分两支:连接稀疏训练(在权重矩阵的连接结构上强制稀疏,代表是动态稀疏训练 DST,如 SET/RigL/MEST/CHT/CHTs)和谱稀疏训练(用低秩分解约束权重子空间,如 LoRA 的预训练版 ReLoRA/GaLore,以及训练+推理全程保持低秩的 CoLA)。两支各有所长,DST 在 10% 参数下能逼近 dense,低秩则擅长捕捉整体子空间。

现有痛点:把两支结合的工作几乎没有。唯一的先驱 SLTrain 有两个硬伤——(1) 它的稀疏分支是静态的,只当低秩的"补充项",没用上动态连接稀疏的真正威力;(2) 它只是把稀疏输出和低秩输出直接求和,没有任何让两者协作的机制。

核心矛盾:作者观察到,当稀疏分支输出 \(S\) 和低秩分支输出 \(L\) 一起训练时,二者经常指向相反方向——一个把某个特征往正方向推,另一个往负方向推,净效果被中和掉,白白浪费表达能力。这种 cancellation effect(抵消效应) 让朴素相加 \(S+L\) 无法真正承载两个分支各自的信息,尤其在注意力的 Q、K 矩阵上最严重,因为它们的点积直接决定注意力权重,对不一致极其敏感。

本文目标:建立一个真正融合动态连接稀疏与动态谱稀疏的统一框架,量化并缓解抵消效应,让两个分支协同而非互斥。

核心 idea用一个对齐损失把稀疏分支和低秩分支的输出拉到同方向,配合对低秩分支的激活稳定化,使二者从"互相抵消"变成"互补协作"。

方法详解

整体框架

框架由三步构成:先用 OCR 指标识别并量化抵消效应;再用对齐损失 + 激活调整组成的训练框架让两个分支稳定协作;最后把先进的动态稀疏方法 CHTs 与低秩分支实例化为 CHTsL。每层输出是稀疏分支与低秩分支之和 \(O^{(l)}=S^{(l)}+L^{(l)}\),前向时两分支并行计算、反向时由对齐损失牵引彼此靠拢。

flowchart LR
    X[输入 x] --> S["动态连接稀疏分支<br/>CHTs: 连接演化"]
    X --> L["谱稀疏分支<br/>L = B·σ(A·x), σ=SiLU"]
    S --> ADD["逐元素相加 O = S + L"]
    L --> ADD
    S -.对齐.-> AL["对齐损失<br/>‖S − L‖_F"]
    L -.对齐.-> AL
    ADD --> OUT[层输出]
    AL --> LOSS["总损失 L = L_task + λ·L_align"]

关键设计

1. OCR 指标:把"抵消"量化成一个可观测的数——作者先把直觉变成度量,定义 重叠抵消比 Overlap Cancellation Ratio\(\mathrm{OCR}=\frac{\sum_i \min(|S_i|,|L_i|)\cdot\mathbb{1}\{S_iL_i<0\}}{\sum_i \min(|S_i|,|L_i|)+\varepsilon}\)。分子只统计那些两分支符号相反\(S_iL_i<0\))的位置上,被抵消掉的重叠信号量(取两者绝对值的较小者,即真正被中和的部分),分母是全部重叠信号。OCR 落在 \([0,1)\),越大说明抵消越严重。这个指标让"两分支打架"从模糊感觉变成可以逐层、逐训练步追踪的曲线,也是后续验证对齐有效性的核心证据。

2. 对齐损失:把两分支拉到同方向协作——既然抵消来自方向冲突,最直接的办法就是惩罚两分支输出的差异。作者对每层定义 \(L^{(l)}_{\text{align}}=\frac{1}{BN}\lVert S^{(l)}-L^{(l)}\rVert_F\)\(B\) 是 batch size,\(N\) 是单样本该层输出的元素数),所有层求和得 \(L_{\text{align}}=\sum_l L^{(l)}_{\text{align}}\)。这个 Frobenius 范数惩罚让稀疏与低秩输出趋于一致,从而减少破坏性干涉——但注意目标不是让两者完全相同(那样就退化成单分支),而是降低方向冲突,让每个分支专注于表达的互补侧面。它最终以系数 \(\lambda\) 加权进总目标。

3. 低秩激活稳定化:让低秩分支可靠出力——低秩分解虽省参数,但在极端稀疏下输出容易不稳定、尺度失控。借鉴 CoLA,作者在低秩两个因子矩阵之间插入一个温和的非线性:\(L^{(l)}=B^{(l)}\,\sigma(A^{(l)}x)\),其中 \(\sigma\) 取 SiLU。这里激活的作用不是增强表达,而是维持合理尺度、防止数值发散,确保低秩分支能稳定地与稀疏分支并肩工作。消融显示单加激活(Act)已能大幅改善朴素相加(Naive)的崩溃,再叠加对齐(Act+Align)进一步提升。

4. CHTsL 实例化与统一目标——把上述框架落地:稀疏分支用 CHTs 的连接演化规则(基于 Cannistracci-Hebbian 理论从脑连接组得到启发的动态稀疏),低秩分支带激活调整,对齐损失逐层施加。总目标为 \(L=L_{\text{task}}+\lambda L_{\text{align}}\)\(\lambda\) 平衡对齐强度(LLaMA-60M/OpenWebText 与 130M 取 0.5,60M/C4 取 0.3)。两分支在统一目标下联合优化,既稳定低秩训练又促成协作,从根上缓解抵消。

实验关键数据

模型:LLaMA-60M / 130M;数据:OpenWebText、C4;预算:保留 dense 的 10%/20%/30% 参数(总稀疏度 0.9/0.8/0.7)。稀疏度统一定义为 \(s=1-\#\text{params}/\#\text{params}_{\text{dense}}\),融合方法总稀疏度 \(s_{\text{total}}=1-d_{\text{conn}}-d_{\text{spec}}\),保证各方法可训练参数量相同。

主实验表格(验证集 PPL↓,节选)

数据集 方法 60M s=0.9 60M s=0.8 60M s=0.7 130M s=0.9 130M s=0.8 130M s=0.7
OpenWebText Dense 26.56 19.46
CHTs 33.03 29.84 28.12 24.75 22.67 21.48
CoLA 37.58 30.87 28.53 27.07 23.24 21.61
SLTrain 33.90 29.83 27.86 25.33 22.81 21.25
CHTsL 31.77 29.11 27.40 24.07 21.87 20.65
C4 Dense 33.21 24.55
CHTs 40.62 37.55 35.23 31.00 28.69 27.46
SLTrain 41.05 37.00 34.89 31.38 28.28 26.78
CHTsL 39.29 35.95 34.19 30.03 27.59 26.19

CHTsL 在所有模型/数据/稀疏度组合下都是最优稀疏方法,且最接近 dense(如 130M/OpenWebText/s=0.7 时 20.65 vs dense 19.46)。

消融实验表格(整合策略对比,PPL↓)

模型/数据 总稀疏度 Naive(直接相加) Act(+激活) Act+Align(+对齐)
60M/OpenWebText 0.9 32.64 32.21 31.77
60M/C4 0.9 189.55 39.66 39.29
60M/C4 0.7 591.42 34.55 34.33
130M/OpenWebText 0.9 119.35 24.45 24.07
130M/C4 0.7 920.16 26.55 26.19

Wilcoxon 符号秩检验:Act+Align vs Naive 的 \(p=0.00049\),vs Act 的 \(p=0.00049\),均 \(<0.05\),差异显著。可见朴素相加在极端稀疏下会灾难性崩溃(PPL 高达数百),激活先救回稳定性,对齐再稳步提升。

关键发现

  • 抵消主要发生在 Q、K:OCR 逐层曲线显示对齐损失显著降低 Query/Key 层的 OCR,而 V/O 和 FFN 因有残差连接更宽容。Q、K 决定注意力权重,对不一致最敏感,对齐稳住注意力图、缓解梯度冲突。
  • 稀疏配置敏感性:固定总稀疏度 0.7 时,OpenWebText(同质语料)偏好连接稀疏占多,低秩占比过高会崩溃;C4(异质多样语料)则更受益于较高低秩比例,因多样语言模式需要对整张权重矩阵做更广的适配。

亮点与洞察

  • 把模糊直觉做成可测指标再针对性优化:OCR 不只是诊断工具,它把"两分支打架"变成可观测、可优化的对象,方法论上很干净。
  • 极端稀疏下的崩溃现象很有说服力:Naive 在 s=0.9 时 PPL 飙到 189/591/920,直观说明朴素相加不是"次优"而是"不可用",凸显对齐/激活的必要性。
  • 抵消集中在 Q、K 的解释自洽:从注意力点积的敏感性出发解释为何 Q、K 受益最大,并被 OCR 曲线佐证,因果链完整。
  • 统一了两条互不相通的稀疏范式:首次让动态连接稀疏与动态低秩谱稀疏真正"动态地"协作,而非像 SLTrain 那样静态补充。

局限与展望

  • 规模偏小:仅验证到 LLaMA-130M,OpenWebText/C4 上的 PPL,未在更大模型(1B+)或下游任务上检验,结论的可扩展性待证。
  • 对齐损失是启发式:直接惩罚 \(\lVert S-L\rVert_F\) 让两分支趋同,但"对齐到何种程度才最优"缺乏理论刻画,\(\lambda\) 需逐设置搜索(0.3~0.5),泛化到新任务需重新调参。
  • 稀疏配置需网格搜索:稀疏/低秩参数分配按 5% 步长系统搜索后报告最优,实际部署成本不低,且不同数据集最优配置差异大(同质 vs 异质语料)。
  • 额外计算开销:逐层对齐损失和低秩激活带来的训练时算力/显存代价未量化分析。

相关工作与启发

  • 动态连接稀疏:SET(随机重连)→ RigL(按梯度再生)→ MEST(权重+梯度)→ CHT/CHTs(Cannistracci-Hebbian 脑连接组启发,SOTA),本文稀疏分支即取 CHTs。
  • 谱稀疏/低秩:LoRA(微调低秩)→ ReLoRA/GaLore(预训练低秩,但前向仍需 dense)→ CoLA(训练+推理全程低秩),本文低秩分支与激活设计借鉴 CoLA。
  • 混合尝试:SLTrain 是最早的连接+谱混合,但静态稀疏+朴素相加,本文正是针对其两大缺陷做改进。
  • 启发:当多个互补模块被简单加法组合时,"输出方向冲突"可能是隐藏的性能杀手;定义一个抵消度量 + 一致性正则,是把模块从竞争拉向协作的通用思路,可迁移到 MoE、多分支网络、集成等场景。

评分

  • 新颖性: ⭐⭐⭐⭐ — 首次真正动态融合连接稀疏与谱稀疏,OCR 指标和对齐损失的组合简洁有效,切中 SLTrain 的痛点。
  • 实验充分度: ⭐⭐⭐ — 两模型两数据三稀疏度 + 消融 + Wilcoxon 检验 + OCR 逐层证据扎实,但模型规模偏小、无下游任务、无大模型验证。
  • 写作质量: ⭐⭐⭐⭐ — 问题—诊断—方法—验证逻辑清晰,OCR 与 Q/K 解释自洽;个别句子有语法瑕疵但不影响理解。
  • 价值: ⭐⭐⭐⭐ — 为参数高效稀疏预训练提供了"如何融合两条范式"的可落地方案,抵消效应的视角对其他多分支组合也有借鉴意义。