Critical Patch-Aware Sparse Prompting with Decoupled Training for Continual Learning on the Edge¶
会议: CVPR 2026
arXiv: 2604.07399
代码: https://github.com/laymond1/cps-prompt
领域: 模型压缩 / 持续学习
关键词: 持续学习, 边缘设备, Prompt-based CL, Token Reduction, 训练效率
一句话总结¶
提出 CPS-Prompt 框架,通过任务感知的关键 patch 采样(CPS)和解耦 prompt-分类器训练(DPCT)两个模块,在边缘设备上实现 Prompt-based 持续学习的训练时内存和计算效率提升约 1.6 倍,同时准确率仅下降约 2%。
研究背景与动机¶
领域现状:持续学习(CL)在边缘设备(家用机器人、无人机、手机)上需要在有限内存和算力下不断适应新任务。Prompt-based 持续学习(PCL)通过冻结 ViT 骨干+轻量可学习 prompt 实现参数高效学习,但既有工作主要关注精度和推理效率。
现有痛点:PCL 方法如 C-Prompt 虽然精度高,但训练时内存开销巨大(4.3× 于本文方法),不适合部署在内存受限的边缘设备上。OS-Prompt 虽然简化了两阶段流水线,但反向传播时峰值内存仍然很高。
核心矛盾:现有 token reduction 方法(ToMe、PatchDropout)在与 PCL 结合时会丢弃任务关键 patch,导致精度严重下降——因为它们是"任务无关"的。
本文要解决:如何在 PCL 的两阶段架构中实现训练时内存和计算的显著节省,同时保持竞争力的精度?
切入角度:利用冻结 query encoder 最后一层的注意力和 value 信号来估计 patch 重要性,做任务感知的稀疏化;再通过解耦训练消除稀疏训练与全 patch 推理之间的表征错位。
核心 idea:任务感知的 patch 采样 + 解耦的 prompt/分类器训练 = 训练高效 + 精度保持。
方法详解¶
整体框架¶
CPS-Prompt 要解决的核心问题,是让 Prompt-based 持续学习能跑在内存和算力都吃紧的边缘设备上,而不是只在数据中心刷精度。它沿用了 PCL 标准的两阶段架构:先用冻结的 query encoder \(f_q\) 跑一次前向,从图像里提取出"任务线索",再把这条线索注入到 prompt-injected backbone \(f_p\) 里做分类。CPS-Prompt 的两个改动正好卡在这条流水线的两个关口上——在两阶段之间插入关键 patch 采样(CPS)模块,借第一次前向已经算好的注意力信号挑出真正关键的 patch,把进入第二阶段 backbone 的 token 砍掉一大半;再用解耦 prompt-分类器训练(DPCT)策略把 prompt 和分类器拆成两段训练,专门补偿"训练时只看稀疏 patch、推理时却看全 patch"带来的表征错位。前者省内存省算力,后者把省下来的精度找回来。
%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400, 'subGraphTitleMargin': {'top': 8, 'bottom': 16}}}}%%
flowchart TD
A["输入图像 x"] --> B["冻结 Query Encoder f_q<br/>跑一次前向产出任务线索"]
subgraph CPS["关键 Patch 采样(CPS)"]
direction TB
C["取末层 class→patch 注意力 a<br/>与 value 范数 ‖V‖₂"] --> D["临界分数 s = a · ‖V‖₂"]
D --> E["温度 softmax 转采样概率 p"]
E --> F["无替换多项式采样<br/>选 k=⌊(1−r)·N⌋ 个 patch"]
end
B --> CPS
CPS --> G["稀疏 token 序列 X_sampled"]
subgraph DPCT["解耦 Prompt-分类器训练(DPCT)"]
direction TB
H["阶段一:稀疏 patch 输入<br/>联合训 prompt φ + 分类器 θ"] --> I["冻结 prompt φ"]
I --> J["阶段二:全 patch 输入<br/>只微调分类器 θ"]
end
G --> DPCT
DPCT --> K["推理:全 patch 分类"]
关键设计¶
1. Critical Patch Sampling:用冻结 backbone 已有的注意力信号免训练地挑关键 patch
直接把通用 token reduction(ToMe、PatchDropout)套到 PCL 上会出事——它们是"任务无关"的,砍 patch 时不知道哪些对当前任务的类别判别最关键,容易把任务相关区域误删,精度直线下降。CPS 的切入点是:PCL 本来就要让 query encoder 跑一次前向,那一层的注意力其实已经隐含了"哪些 patch 重要"的判断,白白浪费太可惜。具体地,从 query encoder 最后一层取出 class token 对每个 patch token 的注意力权重 \(A^L_{\text{cls},j}\),再取该 patch 的 value 向量 L2 范数 \(\|V^L_j\|_2\),两者相乘得到临界分数 \(s_j = A^L_{\text{cls},j} \cdot \|V^L_j\|_2\)——注意力反映这个 patch 对类别表征贡献多大,value 范数反映它的特征本身有多显著,乘起来才是"既被关注、信息又足"的综合重要性。分数不是直接 Top-k 截断,而是经温度缩放 softmax 转成采样概率
再以无替换多项式采样选出 \(k = \lfloor(1-r) \cdot N\rfloor\) 个 patch(\(r\) 是削减率),每个 mini-batch 重新采一次。这里特意用带温度的多项式采样而不是直接取分数最高的 Top-k,是被消融验证过的选择:固定 Top-k 每轮只喂同一批"最高分" patch,相当于把模型锁死在一小撮区域上;带温度的随机采样让每轮见到的 patch 略有不同,这种受控随机性在训练中起到类似数据增强的探索作用,尤其在削减比例较大时对泛化到持续到来的新任务帮助明显。温度 \(\tau\) 越小分布越尖锐越接近确定性,越大越偏向随机探索(实验中 \(\tau=0.1\) 最佳)。整个打分过程都基于冻结 backbone,零额外训练、可无缝插进现有 PCL 流水线。
2. Decoupled Prompt and Classifier Training:把 prompt 和分类器拆两段训,消除稀疏训练与全 patch 推理的错位
CPS 省了内存,但留下一个隐患:训练时 prompt 只见过稀疏 patch,推理时却要面对全 patch,prompt 学到的表征和实际推理分布对不上,精度会被拖下来。DPCT 的办法是把总共 \(E\) 个 epoch 切成两段:前 \(\lfloor \lambda \cdot E \rfloor\) 个 epoch 用稀疏 patch 输入,联合优化 prompt \(\phi\) 和分类器 \(\theta\),目标是
剩下的 epoch 把 prompt 冻住,只用全 patch 输入单独微调分类器,目标是
第一段在稀疏输入下学高效的 prompt,第二段让分类器在真实的全 patch 分布上把表征重新对齐,正好把 CPS 引入的错位补回来。而且冻结 prompt 之后梯度不再回传到 prompt,第二段的反向传播也更省算力——精度和效率两头都照顾到。
损失函数 / 训练策略¶
- 使用标准交叉熵损失
- Prompt 阶段和分类器阶段各用 Adam 优化器
- 学习率 cosine decay,起始 0.001
- 最优超参:patch 削减率 \(r=0.4\),温度 \(\tau=0.1\)
实验关键数据¶
主实验¶
| 数据集 | 指标 | CPS-Prompt | C-Prompt (SOTA) | CODA-Prompt | 差异说明 |
|---|---|---|---|---|---|
| CIFAR-100 | ACC↑ | 66.89 | 68.34 | 67.06 | 仅差 1.45% |
| ImageNet-R | ACC↑ | 49.96 | 53.32 | 50.24 | 差 3.36% |
| CUB-200 | ACC↑ | 52.85 | 52.64 | 53.96 | 与 CODA 持平 |
效率比较(Jetson Orin Nano 上测量):
| 方法 | 峰值内存倍率 | 训练时间倍率 | 能耗倍率 |
|---|---|---|---|
| CPS-Prompt | 1× | 1× | 1× |
| CODA-Prompt | ~1.6× | ~1.5× | ~1.6× |
| C-Prompt | ~4.3× | ~3.1× | ~3.3× |
消融实验¶
| 配置 | ACC↑(ImageNet-R) | 内存 | 训练时间 | 说明 |
|---|---|---|---|---|
| CODA-Prompt 基线 | 50.24 | 440MB | 1788s | 基线 |
| + PD (随机丢 patch) | 45.32 | 253MB | 1388s | 精度大幅下降 |
| + CPS (任务感知) | 47.16 | 253MB | 1389s | 比 PD 好 1.8% |
| + PD + DPCT | 47.96 | 253MB | 1126s | DPCT 恢复精度 |
| + CPS + DPCT (完整) | 49.28 | 253MB | 1126s | 最优配置 |
关键发现¶
- CPS 和 DPCT 提供互补收益:CPS 提升 patch 质量,DPCT 消除表征错位
- 即使内存削减超过 60%,CPS-Prompt 仍保持基线 90% 以上精度
- 随机采样在低 phase ratio 下表现尤其优于确定性 Top-k
- 温度 \(\tau=0.1\)(较尖锐分布)在所有数据集上表现最佳
亮点与洞察¶
- 真正的边缘部署视角:在 Jetson Orin Nano 上做了完整的实测(内存、时间、能耗),而非仅仅报告理论 FLOPs
- 任务感知 token reduction:巧妙利用 PCL 两阶段架构中本已存在的 query forward pass 信号,零额外训练开销
- 解耦训练的简洁性:分两阶段分别用稀疏/全 patch 训练 prompt/分类器,设计简单但有效
局限与展望¶
- 仅在 ViT-Tiny/16 上验证,更大模型(ViT-Base/Large)上的表现未知
- 固定的 patch 削减比 \(r=0.4\),未探索动态自适应策略
- 仅考虑 class-incremental 设定,未涉及 task-incremental 或 domain-incremental
- 未与更新的 VLM-based CL 方法对比
相关工作与启发¶
- 与 ToMe(token merge)和 PatchDropout 的对比表明,任务无关的 token reduction 在 PCL 中表现糟糕
- DPCT 的思路与知识蒸馏中的"训练-推理不一致"问题异曲同工
- 可以启发将 CPS 思路推广到其他需要 token reduction 的 ViT 下游任务
评分¶
- 新颖性: ⭐⭐⭐⭐ 任务感知 patch 采样+解耦训练的组合新颖,但单独看每个模块技术贡献有限
- 实验充分度: ⭐⭐⭐⭐⭐ 三个数据集+真实边缘硬件+完整消融+效率分析
- 写作质量: ⭐⭐⭐⭐ 结构清晰,算法流程图和伪代码完备
- 价值: ⭐⭐⭐⭐ 对边缘持续学习有实际意义,但整体 scope 偏小众