COFT: Counterfactual-Conformal Decoding for Fair Chain-of-Thought Reasoning in Large Language Models¶
会议: ICML 2026
arXiv: 2605.30641
代码: 无
领域: LLM安全/公平性
关键词: 反事实公平性, 共形预测, 链式推理去偏, 解码时干预, 无训练去偏
一句话总结¶
COFT 通过在解码时构造反事实掩码分支并与原始分支进行 logit 融合,再用双分支分裂共形预测过滤 token,以无训练、免梯度的方式在冻结 LLM 上实现了逐步 token 级别的反事实公平性保证,将偏见指标降低 30–55%(中位 38%)且几乎不损失任务性能。
研究背景与动机¶
领域现状:大语言模型(LLM)在 CoT(链式推理)生成过程中会逐 token 暴露并放大训练数据中的社会偏见——即使最终答案看似中立,推理轨迹中也可能包含有害的刻板联想。
现有痛点:已有的去偏方案各有局限。数据清洗和微调需要重新训练且可能损害通用能力;辅助分类器引导的方法(如 DExperts、GeDi)依赖外部模型并继承其盲区;表示空间去偏(如 INLP)做全局线性投影,无法适应特定 prompt 语义,且可能误删合法内容。
核心矛盾:上述方法都缺少两个关键属性的同时满足:(1) 逐步统计保证——在每一步解码时,无法保证所选 token 在敏感属性替换后仍然稳定;(2) 局部反事实对等——公平性目标通常只在聚合层面定义,而非逐 token 操作。
本文目标:设计一个解码时框架,同时实现三个属性——逐 token 反事实不变性、无梯度/模型无关(适用于冻结权重)、可审计的逐步边际保证。
切入角度:将每个 prompt 同时构造为原始(factual)和掩码(masked)两个分支,通过对比两者的 logit 分布差异来定位并消除敏感属性驱动的偏差。再利用共形预测(Conformal Prediction)的分布无关保证来过滤不安全 token。
核心 idea:用反事实掩码 + logit 凸插值融合 + 双分支共形过滤,三阶段联合实现无训练的逐 token 反事实公平性解码。
方法详解¶
整体框架¶
COFT 是一个纯推理时框架,作用于冻结的因果语言模型。对于给定 prompt \(p\),COFT 在每一步解码时执行三个阶段:(1) 将 prompt 中的敏感片段替换为中性哨兵 token 生成掩码 prompt \(\tilde{p} = M(p)\);(2) 分别对原始和掩码 prompt 做前向推理,对两组 logit 做凸插值融合以衰减属性驱动的偏差;(3) 用离线校准的双分支共形预测阈值过滤 token,仅从两个分支都高概率支持的候选集中采样。整个流程只需多做一次缓存的前向传播,不需要训练、梯度或外部分类器。
关键设计¶
-
反事实掩码(Counterfactual Masking):
- 功能:生成与原始 prompt 在结构上完全对齐的"去敏感化"版本,作为反事实分支
- 核心思路:定义确定性掩码算子 \(M\),将 prompt 中每个敏感片段 \(s \in S\)(如性别、种族标识词)替换为中性哨兵 token
[MASK]。关键设计是保持 token 数量不变:若敏感片段被 tokenizer 切分为 \(k\) 个 token,则替换为 \(k\) 个哨兵副本,确保两个分支在绝对位置上严格对齐,使 \(z_t^F \leftrightarrow z_t^{CF}\) 的逐位配对比较有效 - 设计动机:删除敏感片段会破坏语法和注意力几何;替换为另一身份会注入新属性;唯有掩码既保留结构又切断与敏感属性的直接词汇关联
-
反事实 Logit 融合(Counterfactual Logit Fusion):
- 功能:在 logit 空间中衰减由敏感属性驱动的 token 概率偏差
- 核心思路:定义逐 token 属性敏感度 \(\Delta_t = z_t^F - z_t^{CF}\),通过凸插值生成融合 logit \(\hat{z}_t = (1-\lambda) z_t^F + \lambda z_t^{CF}\),其中 \(\lambda \in [0,1]\) 控制去偏强度。等价于 \(\hat{\pi}_t(v) \propto (\pi_t^F(v))^{1-\lambda} (\pi_t^{CF}(v))^{\lambda}\),即两个分支概率分布的归一化几何混合。\(\lambda\) 通过验证集 Pareto 曲线的拐点选取(通常 \(\lambda \approx 0.6\))
- 设计动机:融合先于过滤执行,可提前移除虚假放大方向,使后续共形过滤在已对齐的高概率区域上操作,减少误拒和过度保守阈值
-
双分支分裂共形过滤(Dual-Branch Split-Conformal Filtering):
- 功能:为每步解码构造一个经统计认证的候选 token 集,提供分布无关的边际覆盖保证
- 核心思路:定义双分支非一致性得分 \(s_t(v) = 1 - \min\{\hat{\pi}_t(v), \pi_t^{CF}(v)\}\)——仅当 token \(v\) 在融合分布和掩码分布中都有足够高概率时,该得分才小。离线在校准集上计算所有真实 next-token 的得分,取 \((1-\alpha)\) 分位数 \(q_t\) 作为阈值。在线解码时,候选集 \(C_t = \{v : \min\{\hat{\pi}_t(v), \pi_t^{CF}(v)\} \geq \tau_t\}\)(\(\tau_t = 1 - q_t\)),然后从 \(\hat{\pi}_t\) 限制在 \(C_t\) 上的条件分布采样;若 \(C_t = \emptyset\) 则回退到 \(\arg\max\)
- 设计动机:单分支共形预测无法保证反事实稳定性,双分支要求 token 同时被两个世界支持,直接操作化了反事实对等
实验关键数据¶
主实验:偏见度量¶
| 数据集 | 指标 | Vanilla | SDD | DExperts | DT-CD | COFT | 改善(vs DT-CD) |
|---|---|---|---|---|---|---|---|
| StereoSet (LLaMA-13B) | Bias↓ | 0.41 | 0.36 | 0.33 | 0.31 | 0.26 | -16% |
| CrowS-Pairs (LLaMA-13B) | Acc↑ | 58.7 | 60.1 | 61.0 | 61.3 | 63.5 | +2.2 |
| BBQ (LLaMA-13B) | Bias Rate↓ | 0.27 | 0.22 | 0.20 | 0.19 | 0.14 | -26% |
| BOLD (LLaMA-13B) | Toxicity↓ | 0.123 | 0.105 | 0.099 | 0.094 | 0.079 | -16% |
| Utrecht (LLaMA-13B) | DP Gap↓ | 0.184 | 0.153 | 0.149 | 0.141 | 0.118 | -16% |
| COMPAS (LLaMA-13B) | Bias Gap↓ | 0.161 | 0.147 | 0.141 | 0.136 | 0.119 | -12% |
| BBQ (Mistral-7B-Inst) | Bias Rate↓ | 0.24 | 0.20 | 0.18 | 0.17 | 0.12 | -29% |
| Utrecht (Mistral-7B-Inst) | DP Gap↓ | 0.173 | 0.146 | 0.141 | 0.136 | 0.112 | -18% |
消融实验¶
| 配置 | BiasAvg↓ | UtilityAvg↑ | 说明 |
|---|---|---|---|
| COFT (完整) | 0.129 | 68.0 | 三阶段全开 |
| w/o 融合 (仅CP) | 0.171 | 68.2 | 去掉 logit 融合后偏见指标升高 32% |
| 单分支CP (仅factual) | 0.158 | 68.1 | 无法保证反事实稳定性 |
| 仅融合 (无CP) | 0.149 | 67.9 | 缺少统计认证,残留偏见 |
关键发现¶
- Logit 融合贡献最大:单独去掉融合后 BiasAvg 从 0.129 升至 0.171(+33%),是三个组件中贡献最大的,因为它在 logit 源头机械性地衰减了属性驱动的 log-odds 偏差
- 双分支 vs 单分支 CP:双分支 CP 比单分支额外减少 18% 偏见(0.158→0.129),验证了要求 token 同时在两个世界有高概率支持的必要性
- 任务性能几乎无损:COFT 在 GSM8K、StrategyQA、ARC-easy、PIQA 上与 Vanilla 差距 ≤ 0.2 点,PPL 和 MAUVE 也几乎无差异
- 效率开销可控:额外约 10.2% 的吞吐量开销(相当于一次缓存前向传播),峰值显存仅增加 ≤ 0.8 GB
- \(\lambda\) 和 \(\alpha\) 的敏感性:\(\lambda\) 在 0.4–0.8 范围内偏见-效用 Pareto 曲线较平稳,默认取拐点 \(\lambda \approx 0.6\);\(\alpha = 0.10\) 是共形过滤的最佳风险水平
亮点与洞察¶
- 三阶段解耦设计:掩码→融合→过滤的流水线使每个组件可独立分析和替换,融合先压缩 logit 差异空间再交给共形过滤,二者协同效果远超单独使用,这种"先去噪再认证"的范式可迁移到任何需要在解码时施加约束的场景
- 共形预测在公平性中的创新应用:将分布无关的统计保证从传统的"置信集"场景拓展到"反事实稳定性认证",通过双分支得分设计将公平性约束转化为标准的分位数校准问题,方法论上具有通用性
- 完全免训练的实用优势:只需一次额外缓存前向传播(≤11% 开销),适用于任何冻结 LLM 检查点,无需权重访问、辅助分类器或微调,对于 API-only 部署场景极具实用价值
局限与展望¶
- 敏感片段检测依赖外部工具:COFT 控制的是解码时对已识别敏感片段的使用,但本身不是万能的隐式偏见检测器;未被识别的代理词(proxy terms)可能逃逸
- 保证为边际覆盖而非条件覆盖:共形预测提供的是边际(marginal)而非输入条件(conditional)保证,在分布严重偏移时可能失效
- 序列级别保证需要额外处理:当前逐步保证不直接延伸到整条推理链,需要用联合上界或 rollout score 校准来获得端到端控制
- \(\lambda\) 选取需验证集:需要一个干净的验证分裂来做 Pareto 拐点选取,在新领域部署时可能需要重新调优
相关工作与启发¶
- 反事实公平性(Kusner et al. 2017)提供了核心理论框架,COFT 将其操作化为逐 token 局部对等;共形预测(Vovk et al. 2005)提供分布无关保证工具,COFT 创新地将其适配到自回归解码的双分支场景
- 与 DExperts/GeDi 等推理时方法相比,COFT 不需要外部分类器,且提供统计保证;与 INLP 等表示去偏方法相比,COFT 是 prompt 级别自适应的而非全局固定投影
- 启发:可将类似的"反事实掩码+共形过滤"范式推广到其他可信 AI 目标(如隐私保护、事实性约束),通过定义不同的掩码算子和非一致性得分来实现不同的安全属性