跳转至

OmniDraft: A Cross-Vocabulary Online Adaptive Drafter for On-Device Speculative Decoding

会议: NeurIPS 2025
arXiv: 2507.02659
代码: 暂无
领域: LLM效率
关键词: 推测解码, 跨词表, 在线蒸馏, 自适应起草, 端侧推理

一句话总结

提出 OmniDraft 框架,通过在线 n-gram 缓存实现跨词表推测解码、混合蒸馏损失在线对齐草稿模型与目标模型、并结合自适应起草长度控制,使单个轻量 Llama-68M 模型可为 Vicuna-7B、Qwen2-7B、Llama3-8B 等不同目标模型提供推测解码加速(1.5-2x)。

研究背景与动机

推测解码(Speculative Decoding)通过小模型(草稿模型)预测多个后续 token,再由大模型(目标模型)一次性验证来加速 LLM 推理。然而目前面临两个核心难题:

草稿与目标模型的紧耦合:现有方法要求草稿模型与目标模型来自同一模型家族(如都是 Llama 系列),共享相同的分词器和词表。一旦目标模型更换为其他家族(如 Qwen),草稿模型就无法使用。

在线部署场景的动态需求:用户在端侧使用时可能切换不同目标模型,且期望延迟随使用时间逐渐改善。

现有工作 UAG 提出了词表交集映射,但只处理直接映射的 token,无法解决"假拒绝"问题——草稿模型提出的多个子 token 在目标词表中对应一个合并 token 时会被错误拒绝。此外,离线蒸馏对齐方法假设目标模型固定,无法适应动态切换场景。

核心 idea:构建一个"one drafter for all"范式——通过 n-gram 缓存解决跨词表映射、通过在线蒸馏实现动态对齐、通过自适应起草控制效率。

方法详解

整体框架

OmniDraft 包含三个核心组件:(1)跨词表 n-gram 缓存用于草稿/目标词表间的token翻译;(2)混合蒸馏损失用于在线对齐草稿模型;(3)自适应起草头用于动态调整提议长度。整个流程:草稿模型生成 token → n-gram 缓存翻译到目标词表空间 → 目标模型验证 → 接受/拒绝结果反馈用于更新缓存和蒸馏草稿模型。

关键设计

  1. 跨词表 N-gram 缓存:核心思路是维护一个缓存 \(\mathcal{C} = \{(t_i, [d_j^i]_{j=1:n})\}\),记录目标 token \(t_i\) 与草稿 token 序列 \([d_1^i, d_2^i, \cdots, d_n^i]\) 之间的映射关系。在提议阶段,扫描草稿 token 序列并查找 n-gram 缓存进行合并映射,概率计算为:

    \(q'(t_i) = \begin{cases} q(d_i), & \text{直接映射} \\ \prod_j q(d_j^i), & \text{n-gram 映射} \end{cases}\)

对于修正阶段的残差分布计算,需要在整个目标词表上定义 \(q'\),对前缀子 token 做概率调整:\(q'(d_1^i) = q(d_1^i) - \prod_j q(d_j^i)\),确保概率质量的正确分配。缓存在推理过程中实时更新——每当出现新的未见映射实例就加入。设计动机:相比 UAG 仅处理词表交集,n-gram 缓存能处理合并 token 的情况,避免"假拒绝",提高接受率。

  1. 跨词表混合蒸馏损失:在线蒸馏分为两部分——对直接映射 token 使用反向 KL 散度以获得丰富的监督信号,对 n-gram token 使用负对数似然(NLL)因为只有可靠的点概率估计。总损失:

    \(\mathcal{L}_{\text{cross\_vocab\_distill}}(\theta) = \mathcal{L}_{\text{DM}}(\theta) + \lambda \mathcal{L}_{\text{N-gram}}(\theta)\)

其中 \(\mathcal{L}_{\text{DM}}\) 对直接映射 token 计算 KL 散度,\(\mathcal{L}_{\text{N-gram}}\) 对 n-gram token 计算 NLL。\(\lambda\) 可设为固定超参或动态权重(如目标模型对该 n-gram 的验证概率)。该设计使草稿模型能在在线推理过程中持续与(可能变化的)目标模型对齐。

  1. 在线自适应起草:使用轻量头网络 \(f_\phi\) 预测当前提议 token 的接受率。通过累积拒绝概率控制是否提前终止提议:

    \(P(\exists 1 \leq i \leq k, \text{s.t. } y_i \text{ rejected}) > \gamma \Rightarrow \text{exit}\)

提出两种训练变体:联合训练(蒸馏 + 自适应头同步更新)和交替训练(自适应头多次更新/蒸馏一次更新,使用更大 buffer 缓解分布漂移)。

损失函数 / 训练策略

  • 蒸馏采用在策略(on-policy)数据,即草稿模型自身生成的数据
  • 固定 \(\lambda = 0.2\) 对所有任务/实验
  • LoRA 微调作为轻量替代方案支持动态适配器切换
  • 自适应头使用加权 BCE 损失,以接受率 \(\min(1, p/q)\) 为标签

实验关键数据

主实验:跨词表在线蒸馏

目标模型 方法 GSM8K Acc/Speed MBPP+HE Acc/Speed Alpaca Acc/Speed XSum Acc/Speed
Llama3-8B SpD_DM (baseline) 0.10 / 0.94x 0.09 / 1.03x 0.09 / 0.96x 0.11 / 0.91x
Llama3-8B \(\mathcal{L}_{\text{DM}}\) + \(\lambda\mathcal{L}_{\text{N-gram}}\) 0.42 / 1.70x 0.27 / 1.33x 0.20 / 1.30x 0.24 / 1.24x
Qwen2-7B SpD_DM (baseline) 0.14 / 1.04x 0.09 / 0.91x 0.13 / 1.01x 0.12 / 0.96x
Qwen2-7B \(\mathcal{L}_{\text{DM}}\) + \(\lambda\mathcal{L}_{\text{N-gram}}\) 0.37 / 1.61x 0.26 / 1.36x 0.20 / 1.30x 0.22 / 1.22x

消融实验:自适应起草(Vicuna-7B 目标)

方法 GSM8K Acc/Speed MBPP+HE Acc/Speed Alpaca Acc/Speed XSum Acc/Speed
SpD (vanilla) 0.21 / 1.44x 0.14 / 1.22x 0.20 / 1.44x 0.20 / 1.42x
Distill Only 0.42 / 2.20x 0.35 / 1.92x 0.25 / 1.57x 0.23 / 1.53x
Joint Distill+Adapt 0.61 / 2.08x 0.51 / 1.91x 0.44 / 1.61x 0.42 / 1.59x
Interleaved Distill+Adapt 0.52 / 2.15x 0.48 / 1.94x 0.41 / 1.60x 0.38 / 1.58x

关键发现

  • N-gram 缓存即使不训练也能带来显著提升(cache hit 0.87),配合蒸馏效果最佳
  • 缓存大小很小(1-5 MB),适合端侧部署
  • 框架可扩展到更大目标模型(Qwen2.5-32B),加速可达 2.05x
  • 交替训练变体在加速上略优于联合训练,但联合训练接受率更高
  • LoRA 微调性能接近全参微调,支持多目标模型动态切换

亮点与洞察

  • "One drafter for all" 范式极具实际价值——端侧只需部署一个 68M 的草稿模型即可服务所有目标 LLM
  • N-gram 缓存是优雅的工程设计,将跨词表映射问题转化为在线缓存查找
  • 混合蒸馏损失对直接映射和 n-gram token 采用不同损失函数,体现了对问题结构的深入理解

局限与展望

  • 在线适应仅一次遍历数据流,对全新数据可能不稳定
  • 尚未解决特殊 token(如多模态 token)的跨词表映射
  • 自适应起草头在线训练不够稳定,可能低估最优提议长度
  • 缓存没有淘汰策略,内存受限设备需要优化

相关工作与启发

  • 与 UAG 相比,n-gram 缓存从词表交集扩展到多对一映射,解决"假拒绝"
  • 与 OSD(Online Speculative Decoding)相比,增加了跨词表能力
  • 启发:端侧推理场景下,轻量+通用+可适应的设计比重量级但专用的方案更有价值

评分

  • 新颖性: ⭐⭐⭐⭐ N-gram 缓存解决跨词表是新颖贡献,但自适应起草借鉴 SpecDec++
  • 实验充分度: ⭐⭐⭐⭐ 多任务多目标模型,消融完整,但缺少与更多基线对比
  • 写作质量: ⭐⭐⭐⭐ 框架清晰,公式推导详细,图示直观
  • 价值: ⭐⭐⭐⭐⭐ 端侧通用草稿模型是重要实际需求,框架完整且可落地