Skipping the Zeros in Diffusion Models for Sparse Data Generation¶
会议: ICML 2026
arXiv: 2605.01817
代码: https://github.com/PhilSid/sparsity-exploiting-diffusion
领域: 扩散模型 / 稀疏数据生成 / 科学计算 / 生成式建模
关键词: 稀疏扩散, 隐扩散, 自回归解码, 单细胞测序, 量能器图像
一句话总结¶
SED 把扩散模型从"对所有维度做全密集去噪"改成"只在非零维度上跑扩散+自回归解码维度-值对",让计算量从随维度线性增长变成几乎随非零数恒定,同时严格保留科学数据中"显式零"这一语义信息。
研究背景与动机¶
领域现状:扩散模型 (DM) 在图像、音频、文本等密集连续数据上做到了 SOTA,DDPM/LDM 是事实标准。但许多科学数据本质稀疏:粒子物理量能器 (~95% 零)、单细胞 RNA 测序 scRNA (90-98% 零)、推荐系统、稀疏图像等——大部分坐标确实是零(没有信号),而不是"接近零"。
现有痛点:(1) 稀疏被抹平——把稀疏数据喂给 DDPM/LDM,输出会在所有零坐标处产生虚假非零值,破坏稀疏模式(见 Figure 1 的 MNIST 演示)。这对生物上的 dropout、物理上的"无能量沉积"等具有明确物理含义的零是灾难性的。(2) 计算被浪费——率失真分析 (Figure 3) 显示扩散模型对零维度只分配几乎零比特率,但去噪网络依然在所有维度上跑前向传播——"信息容量集中在信息维度,计算却不集中"。(3) 现有补丁式方法都有缺陷:thresholding 输出(DDPM-T/LDM-T)保留稀疏但牺牲细节;domain-specific 模型(SARM)依赖手工螺旋采样的零位置先验,泛化差;离散 DM 处理不了连续值;最近 Ostheimer 2025 的稀疏感知 DM 给每个维度配一个二进制指示符,反而把维度翻倍。
核心矛盾:信息(rate)是稀疏的,但密集 DM 的计算和参数化是 dense 的,二者天然不匹配。要么保留 dense 架构丢稀疏语义,要么放弃 Transformer 的可扩展性手工编码稀疏先验,没有兼得方案。
本文目标:(1) 让 DM 在只处理非零维度的紧凑表示上做扩散,使计算量随信号密度而非环境维度增长;(2) 严格保留零模式,输出零的位置必须与真实数据匹配;(3) 不依赖手工先验,跨多个科学域 (物理/生物/视觉) 都能用。
切入角度:作者把每个稀疏样本表示成"(维度索引集合, 对应非零值集合)"对,用 Transformer 编码器把这个变长集合 pool 成一个定长的密集潜变量 \(\mathbf{z}\);扩散在这个密集潜空间上跑(成熟稳定),解码时自回归生成"下一个维度-值对"直到 [EOS]。
核心 idea:把"扩散应该跨所有维度"这个隐含假设打破——扩散在潜空间保持密集稳定,但输入空间表示和解码都跳过零,让算力跟着信号走。
方法详解¶
整体框架¶
两阶段 LDM-style 训练:(1) SAVAE (Sparsity-Aware VAE)——非零提取器 NZE 把 \(\mathbf{x}^{(i)} \in \mathbb{R}^s\) 转成 \((\mathbf{d}^{(i)}, \mathbf{v}^{(i)})\)(长度 \(l_i \ll s\));Transformer 编码器 \(q_\phi\) 处理变长的"维度-值"token 序列,输出经均值池化得定长 \(\mathbf{z}\);自回归解码器 \(p_\theta = p_{\theta_1}(\mathbf{d}) p_{\theta_2}(\mathbf{v})\) 依次预测下一个维度(多项分布)和对应值(高斯)。(2) 潜空间扩散——SAVAE 训练完冻结,标准 DM (DDPM/DDIM) 在 \(\mathbf{z}\) 空间上学习,记作 SEDP/SEDI。生成时从 \(\mathcal{N}(0,I)\) 采样去噪得 \(\mathbf{z}_0\),再用 \(p_\theta\) 自回归解码出维度-值序列填回稀疏向量。
关键设计¶
-
Sparse-to-Dense 潜编码(SAVAE):
- 功能:把高维稀疏数据压缩到固定大小的密集潜表示,让扩散模型可以在低维致密空间稳定训练。
- 核心思路:用 NZE 提取非零索引集合 \(\mathbf{d}^{(i)} = \{j | \mathbf{x}^{(i)}_j \neq 0\}\) 和对应值 \(\mathbf{v}^{(i)}\),长度 \(l_i = \|\mathbf{x}^{(i)}\|_0 \ll s\)。引入 Dimension Encoding (DE)——形式上类似位置编码 \(\text{DE}_{(dim, 2i)} = \sin(dim / k^{2i/d_{model}})\)(\(k=20000\))但编的是维度索引而非序列位置。值用线性投影 embed,DE 与值 embedding 相加输入 Transformer 编码器。编码器输出 mean pooling 得定长 \(\mathbf{z}\)(作者也试过加 [CLS] token,性能近似但稳定性弱,故用 mean pooling)。用重参数化采样 \(\mathbf{z} \sim q_\phi\)。
- 设计动机:Transformer 输入序列长度随非零数 \(l_i\) 而非环境维度 \(s\),所以计算量与稀疏度(而非维度)线性相关——这是 SED 计算效率优势的根源。对扩散后端友好:得到的 \(\mathbf{z}\) 是密集低维向量,可以直接套用成熟的 DDPM/DDIM。
-
自回归稀疏解码(dim-value 对):
- 功能:把潜变量 \(\mathbf{z}\) 解回稀疏空间,必须既决定哪些维度非零(变长),又决定它们的值。
- 核心思路:解码器 \(p_\theta(\mathbf{d}, \mathbf{v} | \mathbf{z})\) 分解成两个头:\(p_{\theta_1}\) 在剩余维度上输出多项分布预测下一个非零维度索引,\(p_{\theta_2}\) 在该位置输出高斯分布预测对应值。两个头联合训练,按维度索引升序的规范顺序解码,遇到 [EOS] 停止。重要的是:训练时用 teacher forcing 并行评估所有目标对,只有采样时才必须串行——因此训练效率不被自回归拖累。
- 设计动机:稀疏样本的非零数 \(l_i\) 因样本而异(一个细胞可能少数活跃基因,另一个很多),固定长度无法表达;自回归是结构性需求。规范升序排序消除了排列歧义。
-
稀疏感知潜扩散 SED + 自条件训练:
- 功能:在 SAVAE 提供的密集低维潜空间上做扩散,损失只关注信息维度。
- 核心思路:固定 SAVAE 后训练扩散 \(\mathcal{L}_{\text{SED}}(\theta) = \mathbb{E}\|\mathbf{z}_0 - f_\theta(\mathbf{z}_t, t, \tilde{\mathbf{z}}_0)\|^2\),其中 \(\mathbf{z}_t = \sqrt{\gamma(t)}\mathbf{z}_0 + \sqrt{1-\gamma(t)}\boldsymbol{\epsilon}\),\(\tilde{\mathbf{z}}_0\) 是自条件 (Chen 2023) 的先前估计。骨干网用 MLP-based 时间条件 U-Net(不用卷积,因为 \(\mathbf{z}\) 没有网格空间结构)。采样时支持 DDPM/DDIM 两种 sampler,分别得到 SEDP/SEDI。
- 设计动机:扩散本身保持成熟稳定,所有"稀疏定制"集中在 SAVAE;解耦让模块化、可替换。在密集 \(\mathbf{z}\) 上跑扩散比在高维稀疏空间训练稳定得多。
损失函数 / 训练策略¶
SAVAE 用 \(\beta\)-VAE 形式:\(\mathcal{L}_{\text{SAVAE}} = -\log p_\theta(\mathbf{d}, \mathbf{v}|\mathbf{z}) + \beta \cdot D_{\text{KL}}(q_\phi \| p)\),\(\beta = 10^{-6}\) 轻正则;负对数似然分解为维度部分(多项)和值部分(高斯)。两阶段训练:先训 SAVAE 到收敛,再冻结后训扩散。\(\gamma(t)\) 从 1 到 0 单调递减,用 log-SNR 参数化。
实验关键数据¶
主实验¶
跨三域六数据集:物理——muon 信号/背景量能器图像 (\(32 \times 32\), ~95% 零);生物——Tabula Muris (98% 零) 和 Human Lung PF (96% 零) scRNA;视觉——MNIST (81% 零), Fashion-MNIST (50% 零)。指标:物理用 Wasserstein 距离 \(W_P\) 对 \(P_T\) 和 invariant mass;scRNA 用 SCC 和 MMD;视觉用稀疏度直方图匹配。
| 任务 | 模型 (参数) | 指标 | 值 | 备注 |
|---|---|---|---|---|
| Muon Signal | DDPM (37M) | \(W_P (P_T)\)↓ | 220.32 | dense 完全失败 |
| Muon Signal | DDPM-T (37M) | \(W_P (P_T)\)↓ | 24.22 | thresholding 缓解 |
| Muon Signal | SARM (25M, domain) | \(W_P (P_T)\)↓ | 28.01 | 用螺旋先验 |
| Muon Signal | SEDP (15M) | \(W_P (P_T)\)↓ | 16.31 | 参数最少且最优 |
| Tabula Muris | DDPM (5M) | SCC↑ / MMD↓ | 0.50 / 3.60 | dense 失败 |
| Tabula Muris | scDiffusion (5M, domain) | SCC↑ / MMD↓ | 0.71 / 1.53 | 需预训练 cell corpus |
| Tabula Muris | SEDP (4M) | SCC↑ / MMD↓ | 0.74 / 0.55 | 无需 domain pretraining |
| Human Lung PF | SEDP (4M) | SCC↑ / MMD↓ | 0.82 / 0.54 | 胜 scDiffusion |
消融实验¶
| 配置 | 关键指标 | 说明 |
|---|---|---|
| SED 完整 (SAVAE + 潜扩散 + AR 解码) | 最优 | — |
| LDM (无稀疏感知) | LDM SCC=0.87 但 MMD=5.82 | 形状对了但分布距离大 |
| LDM-T (thresholded) | SCC 暴跌至 0.26 | thresholding 破坏 LDM 的细节 |
| DDPM/DDIM 原始 | 完全失败 | 既不保稀疏也距离大 |
| SARM (物理 domain prior) | 弱于 SED | 手工螺旋先验不通用 |
| Sampling 时间 (95% 稀疏) | SED 24ms vs DDPM 453ms | 高稀疏度下加速 19× |
| 维度排序正确率 (Fashion-MNIST) | 100% | 序列长但简单不出错 |
| 维度排序正确率 (Muon BG) | 87.9% | 复杂结构下出错率最高 |
关键发现¶
- 计算几乎随维度恒定:scRNA 数据保持 1000 活跃基因,添加额外零基因到 27k 维,DDPM/LDM 线性增长,SED 几乎平坦(Figure 2/9)。
- 稀疏度越高加速越明显:Muon (95%) 加速近 20×,MNIST (81%) 加速 7×,Fashion-MNIST (50%) 几乎无加速——SED 的优势严格随稀疏度递增。
- scRNA 任务上 SED 同时胜过 dense baseline、thresholded 变体、domain-specific scDiffusion——后者还要昂贵的 cell corpus 预训练。
- 自回归错排序在长序列上不会系统性恶化(Fashion-MNIST 长序列 100% 正确),错率高的是复杂结构 (Muon BG 87.9%)——说明难点是数据复杂度而非长度。
亮点与洞察¶
- "信息维度才需计算"是 dense DM 时代被忽视的原则:作者用率失真分析直观证明 DDPM 给零维度几乎零比特率却给它们全部算力,是非常漂亮的诊断分析。
- Dimension Encoding 复用位置编码思想:把"序列位置"换成"特征索引"是非常简洁的工程改造,让 Transformer 直接吃稀疏 (index, value) 对。
- 两阶段解耦:SAVAE 解决"如何表示稀疏",扩散负责"如何生成密集潜变量"——干净分离,每个模块可独立替换或升级。
- 这个思路可迁移到 graph generation(边稀疏)、3D 点云(空间稀疏)、稀疏注意力中的 KV cache 压缩、稀疏激活 MoE 等。
局限与展望¶
- 依赖自回归解码——采样时必须串行,对超长非零序列仍有延迟开销;作者明确点名要找非自回归替代。
- 维度排序错误会产生不真实样本(Figure 7 MNIST 演示),尤其在复杂稀疏模式(如粒子物理)上 12% 样本被影响;但作者验证物理实验数据上这种错误不影响整体生成质量。
- 低稀疏度下优势消失:Fashion-MNIST (50% 零) 上 SED 的采样时间几乎与 DDPM 持平,LDM 反而最快——SED 是 high-sparsity 专用。
- scRNA 上 SED 略弱于某些 LDM 配置的 SCC(但 MMD 更好)——说明保稀疏与匹配整体分布之间有微妙的取舍。
- 缺少与 sparse Transformer 类工作(如 XTrimoGene)的更细比较。
相关工作与启发¶
- vs DDPM/LDM (dense baselines):他们在所有维度上跑去噪,SED 只在非零上跑,效率随稀疏度线性加速且严格保零模式。
- vs DDPM-T / LDM-T (post-hoc thresholding):thresholding 是 hack——保稀疏但破坏边界细节;SED 是结构性方案。
- vs SARM (Lu 2021, domain-specific):SARM 用螺旋采样硬编码物理零位置先验,泛化差;SED 不用 domain prior 且性能更好。
- vs Discrete DM (Austin 2021):离散 DM 能精确生成零但只能在离散态空间,处理不了连续稀疏值;SED 能同时建模"哪里有信号 + 信号值多少"。
- vs scDiffusion (Luo 2024):scDiffusion 需要在大规模 cell corpus 上预训练 autoencoder;SED 端到端无 domain pretraining 即可超过它。
- vs Sparse Transformer (XTrimoGene, scGPT):那些是 representation learning,用 MSE 只在 masked 基因上算 loss;SED 是生成式,需要从头预测维度索引。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ "扩散应该跨稀疏维度跑"这一隐含假设被打破,自回归 dim-value 对的解码视角是真正原创
- 实验充分度: ⭐⭐⭐⭐⭐ 跨物理/生物/视觉三个截然不同的稀疏域,对比覆盖 dense/thresholded/domain-specific 全谱
- 写作质量: ⭐⭐⭐⭐⭐ 率失真诊断 → 方法动机 → 实验验证的逻辑链非常清晰,图示直观
- 价值: ⭐⭐⭐⭐⭐ 在科学计算(高稀疏)场景下提供可立刻使用的计算+保真双赢方案,开源代码已放出