Learning to Weight Parameters for Training Data Attribution¶
会议: ICLR 2026
OpenReview: https://openreview.net/forum?id=EhUkQp9Yah
代码: 待确认
领域: 可解释性 / 训练数据归因(Data Attribution)
关键词: data attribution, influence functions, gradient-based attribution, parameter heterogeneity, self-supervised, diffusion models
一句话总结¶
本文指出梯度归因里"不同参数组的归因质量差异巨大"被现有方法忽视,提出一个统一框架用自监督直接从数据里学一组参数组权重 \(w\),无需标注就把 TracIn / TRAK / EK-FAC 等方法的归因精度系统性拉高,并能解耦 subject/style/background 等语义维度。
研究背景与动机¶
领域现状:训练数据归因(data attribution)要回答"模型这个输出,是哪些训练样本最该负责",对版权、隐私、数据治理至关重要。可扩展的主流做法是梯度类方法:TracIn 直接算 query 与训练样本的梯度内积;Influence Functions / TRAK 进一步引入 Hessian(或其低秩/Kronecker 近似 + 随机投影)来对梯度做二阶预处理。
现有痛点:这些方法对参数的处理要么一视同仁(TracIn 把所有参数的梯度等权拼接),要么靠 Hessian 近似做隐式加权。但 Hessian 在大模型上不可解、模型也很少真正收敛到最优、再叠加随机投影的信息损失——隐式加权来自一个"不精确且带噪"的信号,并不能可靠反映各参数组的真实重要性。
核心矛盾:作者用 LDS(Linear Datamodeling Score)实测发现,归因信号在网络里高度不均匀:UNet 里 up_blocks.1/2 的平均 LDS(5.58)远高于其它块,self-attention 的输出投影层(attn.to_out)显著强于 cross-attention 的 k/q/v 投影;更进一步,不同参数组还专门负责不同语义——style 归因集中在浅层,background 归因集中在特定 attention 组件。也就是说,归因质量随参数位置和功能系统性变化,而这恰恰是现有等权/隐式加权方案没有利用的信息。
本文目标:不再依赖噪声 Hessian 近似,而是显式、直接地从数据中学习每个参数组的重要性权重,让任意梯度归因方法都能即插即用地变强,同时让权重本身可解释。
核心 idea:[显式参数加权 + 自监督 bootstrap] 把归因写成"对角权重矩阵 \(\mathrm{Diag}(w)\) 调制的加权相似度"这一统一形式,再用现有方法的 top-k 排名当伪正样本,自监督地优化 \(w\) 去最大化归因信噪比。
方法详解¶
整体框架¶
方法分两步。第一步是统一加权形式:把模型参数切成 \(M\) 个不相交组(按层/张量),给每组一个非负标量权重 \(w_j\),让归因得分变成被 \(\mathrm{Diag}(w)\) 调制的加权相似度,这一形式同时涵盖 TracIn(核为单位阵)和 TRAK 类核方法。第二步是自监督学权重:没有归因 ground-truth,就用 base 方法当前排名的 top-k 当伪正样本,构造一个"信号/噪声"比例的损失去优化 \(w\),迭代中 top-k 集合随权重动态刷新,从弱信号里 bootstrap 出越来越好的权重。
flowchart TD
A[训练集 D + query 集 Q] --> B[base 归因方法<br/>TracIn/TRAK/EK-FAC...]
B --> C[按参数组拆梯度特征<br/>g_j(x), 预计算各组归因贡献]
C --> D[加权得分<br/>τ̃ = g(query)ᵀ·Diag(w)·K·g(xₙ)]
D --> E[取 top-k 当伪正样本<br/>I_top-k(w)]
E --> F[自监督损失 L_SSL<br/>正样本均分 / ℓ2 范数]
F -->|softmax 参数化保证 w≥0| G[更新 w]
G -->|每步重算 top-k| D
G --> H[学到的参数组权重 w*<br/>可整体/可按语义 subject·style·background]
关键设计¶
1. 统一的参数加权归因形式:把"该信任哪些参数组"显式写进得分。作者把任意梯度归因抽象成一个可加权的双线性形式。设 query 与训练样本 \(x_n\) 的拼接梯度特征为 \(g(\cdot)=[g_1,\dots,g_M]\),引入非负权重 \(w\) 后,加权得分为
其中 \(\mathrm{Diag}(w)\) 把 \(w_j\) 沿组 \(j\) 的所有维度复制展开,\(K\) 是相似度度量。当 \(K=I\) 退化成 TracIn 式加权内积;当 \(K=(\Phi^\top\Phi+\lambda I)^{-1}\) 就是 TRAK 这类核方法。关键的工程取舍是:权重只乘在 query 一侧,训练侧的 \(K\,g(x_n)\) 当作固定项预计算一次——因为若两侧对称加权,每次更新 \(w\) 都要重算所有训练样本的核项,代价不可承受;而在 \(K=I\) 时单侧/双侧加权本就等价。直观理解,\(w\) 编码的是"读取 query 梯度信号时,每个参数组有多可信"。
2. 自监督权重学习:用 base 方法的 top-k 当伪标签去最大化信噪比。既然拿不到真值归因,就假设"base 方法当前打分最高的 top-k 训练样本"是弱可信的伪正样本。对 query \(x_{query}\),设全部加权得分向量为 \(\tilde\tau(x_{query},D;w)\)、其 top-k 索引集为 \(I_{top\text{-}k}(w)\),损失定义为伪正样本的平均得分除以整体得分幅度:
分子相当于信号强度估计、分母 \(\ell_2\) 范数相当于总噪声估计,作者在附录证明最小化它等价于最大化归因得分的 SNR。优化时 \(I_{top\text{-}k}(w)\) 每步随更新后的 \(w\) 重新评估,使排名信号自举式变好;再对 query 分布 \(Q\) 取期望并加 \(\lambda\lVert w\rVert^2\) 正则,最终 \(w^*=\arg\min_{w\ge0} \mathbb{E}_{x_{query}\sim Q}[L_{SSL}]+\lambda\lVert w\rVert^2\),非负性用 softmax 参数化保证。
3. 极致高效:预计算组级贡献,整个学习"一分钟内收敛"。效率来自两点:一是待学权重极少(每个参数组一个标量,如每层一个);二是注意到得分对各参数组线性可分,于是把每组的归因贡献预计算并缓存,优化时只对组级标量得分施加权重,完全避免每步用加权梯度特征重算归因。这让权重学习通常一分钟内收敛。
4. 细粒度语义归因:换 query 集就能解耦 subject/style/background。同一套机制可学多组语义专用权重 \(w_{style}, w_{subject}, w_{background}\)。诀窍是构造只强调目标属性、其余属性留空的 query 集:例如学 style 权重时,prompt 只指定风格、不指定主体和背景,于是风格相关训练样本排名升高,优化自然把"持续贡献风格语义"的参数组权重抬高。学出的语义权重比通用权重更聚焦,呈现明显不同的分布模式。
实验关键数据¶
主实验表格¶
图像分类(ImageNet,LDS %,越高越好):
| 方法 | ResNet-18 w/o w | ResNet-18 w | ViT-B/16 w/o w | ViT-B/16 w |
|---|---|---|---|---|
| TracIn | 11.39 | 23.92 | 9.67 | 17.63 |
| TRAK | 16.86 | 23.30 | 14.77 | 16.74 |
语言建模(WikiText-103, GPT-2-small,LDS %):
| 方法 | w/o w | w |
|---|---|---|
| TracIn | 6.31 | 9.23 |
| TRAK | 12.69 | 14.63 |
| LoGRA | 11.42 | 12.86 |
| EK-FAC | 15.14 | 18.33 |
扩散模型(LDS %,四数据集,节选):
| 方法 | ArtBench-2 w/o→w | Naruto w/o→w | SB-Pokemon w/o→w | CIFAR-2 w/o→w |
|---|---|---|---|---|
| TracIn | 17.63→22.02 | 10.54→13.59 | 9.34→11.79 | 1.39→8.48 |
| TRAK | 18.39→22.15 | 14.61→17.02 | 10.68→12.24 | 8.51→10.59 |
| D-TRAK | 22.72→25.15 | 16.75→17.85 | 33.88→35.05 | 10.17→12.18 |
| DAS | 30.47→31.58 | 18.72→20.44 | 33.55→36.12 | 12.66→13.79 |
消融实验表格¶
下游任务验证(误标检测 AUC / tail-patch,越高越好):
| 任务 | 方法 | w/o w | w |
|---|---|---|---|
| 误标检测 AUC (ResNet-18) | TracIn | 54.40 | 61.46 |
| 误标检测 AUC (ViT-B/16) | TracIn | 71.27 | 83.58 |
| 误标检测 AUC (ViT-B/16) | TRAK | 80.08 | 83.48 |
| Tail-patch (WikiText) | TracIn | 4.66 | 5.60 (Δ+0.94) |
| Tail-patch (WikiText) | EK-FAC | 5.54 | 6.09 (Δ+0.55) |
细粒度归因(SB-Pokemon, D-TRAK, Recall@10 %):学语义专用权重后,对应语义维度的召回明显优于无权重基线,验证了 subject/style/background 的可解耦性。
关键发现¶
- 普适增益:在分类/语言/扩散三大任务、6 种 base 方法(TracIn/TRAK/EK-FAC/JourneyTRAK/D-TRAK/DAS)上加权后 LDS 一致提升,CIFAR-2 上 TracIn 从 1.39 暴涨到 8.48。
- 归因异质性是模型的内在稳定属性:附录的余弦相似度分析显示,per-group LDS 与学到的权重在不同数据集、不同归因方法间高度一致,说明这种异质性不是特定设置的 artifact。
- k 不敏感、训练极快:超参 \(k\) 在很宽范围内稳定,权重学习通常一分钟内收敛。
亮点与洞察¶
- 把"隐式加权"显式化:以往 Hessian 预处理本质上也是在加权,但它来自不可靠近似;本文把这一步剥离出来直接从数据学,绕开了噪声近似链条。
- 统一框架的优雅:一个 \(\mathrm{Diag}(w)\cdot K\) 形式同时覆盖内积法和核方法,让所有梯度归因方法即插即用受益。
- 单侧加权 + 预计算是让方法真正可扩展的关键工程洞察,把权重学习从"每步重算全量核项"降到分钟级。
- 可解释性是顺带的红利:学出的权重直接给出"哪些层负责风格、哪些负责主体"的语义地图,归因方法第一次能做语义解耦。
局限与展望¶
- 自监督依赖 base 方法的初始排名当伪标签,若 base 方法在某任务上极差,bootstrap 的弱信号可能不足以纠偏。
- 权重粒度是"参数组级"(每层一个标量),未探索组内更细粒度(单参数/方向级)加权是否带来额外收益。
- 语义解耦实验主要在合成 SB-Pokemon 等可控数据集上,真实开放域生成中"语义"边界更模糊,专用 query 集的构造可能更难。
- 训练侧核项固定为预计算值,意味着权重无法反作用于训练侧表示,理论上是次优近似(但换来可扩展性)。
相关工作与启发¶
- 梯度归因谱系:TracIn(梯度内积)→ Influence Functions / TRAK(Hessian/核预处理)→ 扩散专用 JourneyTRAK / D-TRAK / DAS、LLM 专用 LoGRA / TrackStar。本文与它们正交,是一层可叠加的"参数加权"增强。
- 参数重要性异质性:在剪枝、机器遗忘、知识编辑、量化(知识定位 / 混合精度)里早被利用,但本文首次把它显式引入数据归因。
- 启发:把"模型里不同部件承担不同功能"这一普遍观察,转化成一个轻量、自监督、可解释的加权层,是一个可推广到其它依赖梯度相似度任务(如检索、影响估计、数据筛选)的范式。
评分¶
- 新颖性: ⭐⭐⭐⭐ — 首次显式建模数据归因中的参数异质性,统一框架 + 自监督 bootstrap 思路清晰且与已有方法正交。
- 实验充分度: ⭐⭐⭐⭐ — 覆盖分类/语言/扩散三域、6 种 base 方法、多数据集,并补充误标检测/tail-patch 下游任务与语义解耦,证据扎实。
- 写作质量: ⭐⭐⭐⭐ — 动机(异质性实测)→ 统一形式 → 自监督目标的逻辑链顺畅,图表充分。
- 价值: ⭐⭐⭐⭐ — 即插即用、分钟级训练、一致涨点,且带来语义可解释性,对版权/数据治理等落地场景实用。