跳转至

Self-Correction Distillation for Structured Data Question Answering

会议: AAAI 2026
arXiv: 2511.07998
代码: 无
领域: 图学习
关键词: 知识蒸馏, 结构化数据问答, CoT蒸馏, 错误纠正, 小模型

一句话总结

提出自纠正蒸馏(SCD)方法,通过错误提示机制(EPM)和两阶段蒸馏策略,将大规模LLM(GPT4)的结构化数据问答能力高效迁移到小规模LLM(8B),在5个基准上取得最优蒸馏性能。

研究背景与动机

结构化数据问答(Structured Data QA)包括表格问答(Table QA)、知识图谱问答(KGQA)和时序知识图谱问答(Temporal KGQA),是NLP领域的重要方向。近年来,TrustUQA等统一框架利用LLM生成结构化查询来回答自然语言问题,取得了显著进展。

然而,这些框架面临两个核心挑战:

部署限制:许多实际场景缺乏部署100B+参数大模型的硬件条件,且出于数据隐私和API稳定性考虑,用户更倾向于本地部署小模型

小模型能力不足:直接将统一QA框架适配到小规模LLM(<10B参数)面临严峻挑战——小模型在生成结构化查询时容易出现多种错误,包括调用未定义函数、参数非法、函数嵌套调用等

作者观察到,小模型的错误类型具有明确的模式性(如Figure 1所示),这启发了"先识别错误类型,再针对性纠正"的思路。现有CoT蒸馏方法要么只学习正确输出(Naive-SFT),要么错误采样范围有限(PERsD),无法充分解决小模型的结构化查询生成问题。

方法详解

整体框架

SCD基于TrustUQA的两层查询框架,核心创新在于:(1) 设计嵌入查询执行器的错误提示机制EPM;(2) 提出教师蒸馏+自蒸馏的两阶段训练策略,同时提升小模型的查询生成和错误纠正能力。

关键设计

1. 错误提示机制(EPM)

EPM嵌入在查询执行器中,负责检测LLM生成查询中的错误并提供定制化的错误信息。错误分为两大类:

解析错误(Parsing Error):在查询执行前的解析阶段发现,包括5种类型: - 未定义函数名:如调用subtract()而非合法函数 - 非法参数:如max(set=…, key=…)key不合法 - 参数不一致:同时为互斥参数赋值 - 非法比较符:使用<<等非法比较运算符 - 非原子操作:函数嵌套调用sum(set=set_negation(…))

执行错误(Execution Error):在查询通过解析后的执行阶段发现: - Python执行器异常:如类型不匹配的运算 - 中间步骤结果为空:中间查询返回空集,可能是关系或实体映射错误

每种错误类型都有对应的定制化错误消息模板,为LLM提供具体的纠正指导。EPM基于正则表达式精确匹配,解析准确率达100%。

2. 多轮纠正流程

给定问题和数据schema,LLM生成初始查询 → 查询执行器解析执行 → 若无错误则输出 → 若有错误,EPM报告错误信息 → LLM分析错误原因并纠正 → 重新解析执行 → 循环直到无错误或达到最大纠正次数 \(MCT\)

设第 \(i\) 轮纠正后的错误分析为 \(CoT^{(i)}\),更新后的查询为 \(q_{upd}^{(i)}\),经过 \(n\) 轮后得到可正确执行的查询 \(q_{cor}\)

3. 两阶段蒸馏策略

第一阶段——教师蒸馏

学生从教师学习查询生成和错误纠正两种能力。教师(GPT4)和学生同时生成初始查询,收集两者的错误查询,由教师进行多轮纠正。

查询生成损失:

\[\mathcal{L}_q = -\sum_j \log P_\mathcal{M}(q_{cor(j)}^t | \text{P\_Q}(Q,S); q_{cor(<j)}^t)\]

错误纠正损失(关键设计——每轮都以最终正确查询为目标):

\[\mathcal{L}_c = -\sum_{i=1}^{n}\sum_j \log P_\mathcal{M}([CoT_t^{(i)};q_{cor}^t]_{(j)} | \text{P\_C}(Q,S,q_{upd}^{t(i-1)},err^{(i-1)}))\]

总损失:\(\mathcal{L}_1 = \mathcal{L}_q + \mathcal{L}_c\)

关键创意:每轮训练都以最终正确查询(而非当轮更新查询)为目标,形成难度递增的多级课程——后期轮次中 \(CoT^{(i)}\)\(q_{cor}\) 匹配度更高,纠正更容易;前期轮次匹配度低,学生需要更深层的理解。

第二阶段——自蒸馏

学生用自己的输出迭代提升。对给定问题,学生生成查询并自行纠正,增大正确查询的生成概率,减小错误查询的生成概率:

\[\mathcal{L}_2 = -\sum_{i=1}^{n}(\mathcal{S}(q_{cor}^s) - \mathcal{S}(q_{upd}^{s(i-1)}))\]

其中 \(\mathcal{S}(q) = \sum_j \log P_\mathcal{M}(q_{(j)}|\text{P\_Q}(Q,S);q_{(<j)})\),本质是对比学习,让模型远离自身易犯的错误。

损失函数 / 训练策略

  • 使用LoRA微调Llama3.1-8B-Instruct
  • 2×NVIDIA A100 40GB GPU,batch size=1,梯度累积=8,训练3个epoch
  • AdamW优化器,学习率0.0001,cosine调度,warmup ratio=0.1
  • 使用SentenceBERT做演示检索,检索数=15,查询生成演示数=8
  • 最大纠正次数MCT=3,纠正演示数=8

实验关键数据

主实验

在5个数据集(WikiSQL、WTQ、MetaQA、WebQSP、CronQuestion)上评估:

方法 WikiSQL WTQ MetaQA-1hop MetaQA-3hop WebQSP CronQ-Complex
SOTA 89.5 66.7 98.4 99.4 85.7 95.4
GPT4 w/ EPM 91.1 53.2 98.6 99.3 91.7 97.3
Llama3.1 74.5 35.1 94.6 83.4 75.2 80.9
Naive-SFT 79.3 42.2 96.1 88.1 78.2 86.4
PERsD 84.2 44.3 97.2 97.3 79.7 89.1
KPOD 86.1 46.5 97.1 97.8 82.3 92.4
SCD (Ours) 86.9 48.9 97.4 98.2 83.7 95.1

GPT4+EPM在大多数数据集上超过SOTA,SCD在所有蒸馏方法中最优。

消融实验

配置 WikiSQL WTQ MetaQA-3hop WebQSP CronQ-Complex 说明
SCD完整 86.9 48.9 98.2 83.7 95.1 完整方法
- 自蒸馏 86.1 46.2 97.2 82.2 93.6 去掉第二阶段
- EPM错误信息 83.7 44.8 97.7 79.5 90.9 EPM只报告有错,不提供详细信息

关键发现

  1. 错误纠正率差异:解析错误的纠正率较高(教师40.2%,学生45.5%),执行错误较低(教师27.6%,学生15.7%),因为解析错误更表面化,执行错误的根因更难定位
  2. 泛化性验证:在未见数据集TabFact上,SCD的8B小模型达到83.4%准确率,接近基于GPT3.5的StructGPT(87.6%)
  3. 超参数分析:大模型只需2轮纠正即可纠正大部分错误,小模型需要3轮以上;纠正演示数和纠正次数的增加对经过SCD训练的模型有效,但对未做错误纠正训练的模型无效

亮点与洞察

  1. 错误分类学的启发:将LLM查询错误系统性地分类为解析错误和执行错误两大类6小类,每类配备定制化错误消息,是一种非常实用的工程设计
  2. 多轮纠正中以最终正确查询为目标的设计巧妙——自然形成了难度递增的课程学习,无需额外设计难度调度
  3. 两阶段蒸馏互补:教师蒸馏解决"从零到一"的能力迁移,自蒸馏解决"自身易犯错误的针对性预防",设计逻辑清晰
  4. EPM的通用性:EPM不仅帮助小模型,也让大模型(GPT4)超越了现有SOTA,说明精确的错误反馈本身就有重要价值

局限与展望

  1. 当正确答案为"null/none"时,EPM会误报"空中间步骤"执行错误(假阴性)
  2. 对于需要"LLM函数"的复杂推理,8B模型受限于固有能力上限
  3. 实体/关系对齐错误不在EPM处理范围内
  4. 依赖TrustUQA框架的特定查询语言,泛化到SQL等其他查询语言需要额外适配

相关工作与启发

  • TrustUQA的Condition Graph表示和两层查询方法为统一结构化QA提供了基础架构
  • PERsD通过教师纠正学生代码实现个性化蒸馏,但错误采样范围有限
  • KPOD模拟人类认知的渐进式学习,SCD的多轮难度递增设计是另一种实现课程学习的途径
  • 方法可扩展到其他需要结构化输出的任务(如NL2SQL、代码生成)

评分

  • 新颖性: ⭐⭐⭐⭐ — EPM和两阶段蒸馏结合创新性强
  • 实验充分度: ⭐⭐⭐⭐⭐ — 5个数据集、3种数据类型、完善的消融和分析
  • 写作质量: ⭐⭐⭐⭐ — 结构清晰,图表丰富
  • 价值: ⭐⭐⭐⭐ — 小模型部署结构化QA的实用方案