COFT: Counterfactual-Conformal Decoding for Fair Chain-of-Thought Reasoning in Large Language Models¶
会议: ICML 2026
arXiv: 2605.30641
代码: 无
领域: LLM安全/公平性
关键词: 反事实公平性, 共形预测, 链式推理去偏, 解码时干预, 无训练去偏
一句话总结¶
COFT 通过在解码时构造反事实掩码分支并与原始分支进行 logit 融合,再用双分支分裂共形预测过滤 token,以无训练、免梯度的方式在冻结 LLM 上实现了逐步 token 级别的反事实公平性保证,将偏见指标降低 30–55%(中位 38%)且几乎不损失任务性能。
研究背景与动机¶
领域现状:大语言模型(LLM)在 CoT(链式推理)生成过程中会逐 token 暴露并放大训练数据中的社会偏见——即使最终答案看似中立,推理轨迹中也可能包含有害的刻板联想。
现有痛点:已有的去偏方案各有局限。数据清洗和微调需要重新训练且可能损害通用能力;辅助分类器引导的方法(如 DExperts、GeDi)依赖外部模型并继承其盲区;表示空间去偏(如 INLP)做全局线性投影,无法适应特定 prompt 语义,且可能误删合法内容。
核心矛盾:上述方法都缺少两个关键属性的同时满足:(1) 逐步统计保证——在每一步解码时,无法保证所选 token 在敏感属性替换后仍然稳定;(2) 局部反事实对等——公平性目标通常只在聚合层面定义,而非逐 token 操作。
本文目标:设计一个解码时框架,同时实现三个属性——逐 token 反事实不变性、无梯度/模型无关(适用于冻结权重)、可审计的逐步边际保证。
切入角度:将每个 prompt 同时构造为原始(factual)和掩码(masked)两个分支,通过对比两者的 logit 分布差异来定位并消除敏感属性驱动的偏差。再利用共形预测(Conformal Prediction)的分布无关保证来过滤不安全 token。
核心 idea:用反事实掩码 + logit 凸插值融合 + 双分支共形过滤,三阶段联合实现无训练的逐 token 反事实公平性解码。
方法详解¶
整体框架¶
COFT 想解决的是:冻结的 LLM 在逐 token 生成推理链时会暴露并放大社会偏见,而我们既不想重训也不想接外部分类器,还希望每一步都有可审计的统计保证。它的做法是把每个 prompt 同时跑成"原始"和"去敏感"两个世界,再用三个串起来的阶段处理每一步解码:先把 prompt 中的敏感词替换成中性哨兵得到掩码 prompt \(\tilde{p}=M(p)\),再对原始和掩码两组 logit 做凸插值融合以衰减属性驱动的偏差,最后用离线校准的双分支共形阈值过滤候选,只从两个世界都高概率支持的 token 里采样。整条流水线只多一次缓存的前向传播,不碰梯度、不动权重、不依赖辅助模型。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400}}}%%
flowchart TD
A["输入 prompt p"] --> M["反事实掩码 M(p)<br/>敏感片段→[MASK],保持 token 数对齐"]
A --> F["原始分支(冻结 LLM)<br/>logit z_t^F"]
M --> CF["掩码分支(冻结 LLM)<br/>logit z_t^CF"]
F --> FUSE["反事实 Logit 融合<br/>ẑ_t = (1−λ)z_t^F + λz_t^CF"]
CF --> FUSE
FUSE --> FILT["双分支分裂共形过滤<br/>C_t = {v : min(π̂_t, π_t^CF) ≥ τ_t}"]
CF --> FILT
FILT -->|C_t 非空| SAMP["从限制在 C_t 上的 π̂_t 采样"]
FILT -->|C_t 为空| FALL["回退 argmax"]
SAMP --> OUT["输出 next token"]
FALL --> OUT
关键设计¶
1. 反事实掩码:构造一个与原始 prompt 严格对齐的"去敏感"分支
要做逐 token 的反事实对比,就必须有一个除了敏感属性以外一切都相同的对照世界。COFT 定义了一个确定性掩码算子 \(M\),把 prompt 里每个敏感片段 \(s\in S\)(性别、种族等标识词)替换成中性哨兵 token [MASK]。这里的关键是保持 token 数量不变:如果某个敏感片段被 tokenizer 切成 \(k\) 个 token,就替换成 \(k\) 个哨兵副本,从而保证原始分支和掩码分支在每个绝对位置上严格对齐,逐位的 \(z_t^F \leftrightarrow z_t^{CF}\) 配对比较才成立。之所以选"掩码"而不是别的——直接删掉敏感片段会破坏语法和注意力几何,替换成另一个身份又会注入一份新偏见,只有掩码既保住了结构又切断了与敏感属性的直接词汇关联。
2. 反事实 Logit 融合:在 logit 源头机械地抹掉属性驱动的概率偏差
有了两个对齐的分支后,两者在同一位置的 logit 之差 \(\Delta_t = z_t^F - z_t^{CF}\) 就恰好刻画了"这一步有多少是被敏感属性推着走的"。COFT 不去显式建模这个偏差,而是直接做凸插值得到融合 logit \(\hat{z}_t = (1-\lambda) z_t^F + \lambda z_t^{CF}\),其中 \(\lambda\in[0,1]\) 控制去偏强度。在概率空间这等价于两个分支分布的归一化几何混合 \(\hat{\pi}_t(v) \propto (\pi_t^F(v))^{1-\lambda}(\pi_t^{CF}(v))^{\lambda}\),\(\lambda\) 越大越靠近去敏感世界。\(\lambda\) 通过验证集偏见-效用 Pareto 曲线的拐点选取,通常落在 \(\lambda\approx 0.6\)。把融合放在过滤之前是有意为之:先在 logit 层面压掉虚假放大方向,后续的共形过滤就能在已经对齐的高概率区域上工作,少误拒、阈值也不必过度保守。
3. 双分支分裂共形过滤:给每一步采样集盖上一个分布无关的统计认证
光做融合还不够——它降低了偏差,却没有"这个 token 在反事实下也稳定"的保证。COFT 为此设计了双分支非一致性得分 \(s_t(v) = 1 - \min\{\hat{\pi}_t(v), \pi_t^{CF}(v)\}\):只有当 token \(v\) 在融合分布和掩码分布里都足够高概率时,得分才低。离线阶段在校准集上对所有真实 next-token 算出这个得分,取 \((1-\alpha)\) 分位数 \(q_t\) 当阈值;在线解码时构造候选集 \(C_t = \{v : \min\{\hat{\pi}_t(v), \pi_t^{CF}(v)\} \geq \tau_t\}\)(\(\tau_t = 1 - q_t\)),然后从限制在 \(C_t\) 上的 \(\hat{\pi}_t\) 条件分布采样,若 \(C_t\) 为空则回退到 \(\arg\max\)。借助分裂共形预测的分布无关性质,每步都拿到一个边际覆盖保证。单分支共形只看原始世界、保证不了反事实稳定,而双分支强制 token 同时被两个世界支持,正是把"反事实对等"直接操作化成了一个标准的分位数校准问题。
实验关键数据¶
主实验:偏见度量¶
| 数据集 | 指标 | Vanilla | SDD | DExperts | DT-CD | COFT | 改善(vs DT-CD) |
|---|---|---|---|---|---|---|---|
| StereoSet (LLaMA-13B) | Bias↓ | 0.41 | 0.36 | 0.33 | 0.31 | 0.26 | -16% |
| CrowS-Pairs (LLaMA-13B) | Acc↑ | 58.7 | 60.1 | 61.0 | 61.3 | 63.5 | +2.2 |
| BBQ (LLaMA-13B) | Bias Rate↓ | 0.27 | 0.22 | 0.20 | 0.19 | 0.14 | -26% |
| BOLD (LLaMA-13B) | Toxicity↓ | 0.123 | 0.105 | 0.099 | 0.094 | 0.079 | -16% |
| Utrecht (LLaMA-13B) | DP Gap↓ | 0.184 | 0.153 | 0.149 | 0.141 | 0.118 | -16% |
| COMPAS (LLaMA-13B) | Bias Gap↓ | 0.161 | 0.147 | 0.141 | 0.136 | 0.119 | -12% |
| BBQ (Mistral-7B-Inst) | Bias Rate↓ | 0.24 | 0.20 | 0.18 | 0.17 | 0.12 | -29% |
| Utrecht (Mistral-7B-Inst) | DP Gap↓ | 0.173 | 0.146 | 0.141 | 0.136 | 0.112 | -18% |
消融实验¶
| 配置 | BiasAvg↓ | UtilityAvg↑ | 说明 |
|---|---|---|---|
| COFT (完整) | 0.129 | 68.0 | 三阶段全开 |
| w/o 融合 (仅CP) | 0.171 | 68.2 | 去掉 logit 融合后偏见指标升高 32% |
| 单分支CP (仅factual) | 0.158 | 68.1 | 无法保证反事实稳定性 |
| 仅融合 (无CP) | 0.149 | 67.9 | 缺少统计认证,残留偏见 |
关键发现¶
- Logit 融合贡献最大:单独去掉融合后 BiasAvg 从 0.129 升至 0.171(+33%),是三个组件中贡献最大的,因为它在 logit 源头机械性地衰减了属性驱动的 log-odds 偏差
- 双分支 vs 单分支 CP:双分支 CP 比单分支额外减少 18% 偏见(0.158→0.129),验证了要求 token 同时在两个世界有高概率支持的必要性
- 任务性能几乎无损:COFT 在 GSM8K、StrategyQA、ARC-easy、PIQA 上与 Vanilla 差距 ≤ 0.2 点,PPL 和 MAUVE 也几乎无差异
- 效率开销可控:额外约 10.2% 的吞吐量开销(相当于一次缓存前向传播),峰值显存仅增加 ≤ 0.8 GB
- \(\lambda\) 和 \(\alpha\) 的敏感性:\(\lambda\) 在 0.4–0.8 范围内偏见-效用 Pareto 曲线较平稳,默认取拐点 \(\lambda \approx 0.6\);\(\alpha = 0.10\) 是共形过滤的最佳风险水平
亮点与洞察¶
- 三阶段解耦设计:掩码→融合→过滤的流水线使每个组件可独立分析和替换,融合先压缩 logit 差异空间再交给共形过滤,二者协同效果远超单独使用,这种"先去噪再认证"的范式可迁移到任何需要在解码时施加约束的场景
- 共形预测在公平性中的创新应用:将分布无关的统计保证从传统的"置信集"场景拓展到"反事实稳定性认证",通过双分支得分设计将公平性约束转化为标准的分位数校准问题,方法论上具有通用性
- 完全免训练的实用优势:只需一次额外缓存前向传播(≤11% 开销),适用于任何冻结 LLM 检查点,无需权重访问、辅助分类器或微调,对于 API-only 部署场景极具实用价值
局限与展望¶
- 敏感片段检测依赖外部工具:COFT 控制的是解码时对已识别敏感片段的使用,但本身不是万能的隐式偏见检测器;未被识别的代理词(proxy terms)可能逃逸
- 保证为边际覆盖而非条件覆盖:共形预测提供的是边际(marginal)而非输入条件(conditional)保证,在分布严重偏移时可能失效
- 序列级别保证需要额外处理:当前逐步保证不直接延伸到整条推理链,需要用联合上界或 rollout score 校准来获得端到端控制
- \(\lambda\) 选取需验证集:需要一个干净的验证分裂来做 Pareto 拐点选取,在新领域部署时可能需要重新调优
相关工作与启发¶
- 反事实公平性(Kusner et al. 2017)提供了核心理论框架,COFT 将其操作化为逐 token 局部对等;共形预测(Vovk et al. 2005)提供分布无关保证工具,COFT 创新地将其适配到自回归解码的双分支场景
- 与 DExperts/GeDi 等推理时方法相比,COFT 不需要外部分类器,且提供统计保证;与 INLP 等表示去偏方法相比,COFT 是 prompt 级别自适应的而非全局固定投影
- 启发:可将类似的"反事实掩码+共形过滤"范式推广到其他可信 AI 目标(如隐私保护、事实性约束),通过定义不同的掩码算子和非一致性得分来实现不同的安全属性