Conference: NeurIPS'25

github: https://github.com/Theia-4869/CDPruner


My Thoughts

相较于 Attention-based methods 中的 attention shift 问题和 Similarity-based methods 中忽略了 query 的问题,这篇论文从 增加 visual tokens 全局多样性(同时保持对 query 的关注) 的角度出发,提出了 CDPruner,通过将 token 多样性建模为 DPP(Determinantal Point Process)求解问题,实现了 SOTA 的 pruning 效果。


Motivations

在多模态大语言模型(MLLMs)中,视觉 token 的输入长度往往远大于文本 token,从而带来高昂的推理开销。例如:

  • LLaVA-1.5 将一张 336×336 图像转换为 576 tokens
  • LLaVA-NeXT 的高分辨率版本在输入加倍的情况下生成 2,880 tokens
  • LongVA 处理 2,000 帧视频时生成超过 200K visual tokens
  • LongVILA 能处理 6,000 帧并产生 超过 1M visual tokens,导致巨大的计算成本。

Challenges

现有的视觉 token 剪枝方法主要分为两类:

  1. Attention-based pruning:利用 text-visual attention 分数衡量视觉 token 的重要性。
  2. Similarity-based pruning:依据视觉 token 间的相似性移除冗余部分。

然而,这两种方法均存在固有缺陷。Attention-based 方法只考虑重要性,容易保留大量重复 token;Similarity-based 方法忽略了指令(instruction)的关联性,导致无法针对问题进行动态剪枝,从而性能次优。

“However, as pointed out by Zhang et al. [2024b] and Wen et al. [2025a], such methods suffer from attention shift, which compromises pruning accuracy.”

⚠ 什么是 attention shift?

(gpt回答 不一定正确)

Attention shift 指在 token 剪枝或输入变化后,模型的注意力分布发生偏移,导致原本重要的 token 被低估或信息丢失,从而降低剪枝准确性。由于 Transformer 的 self-attention 是全局依赖的,删除部分 token 会改变 query-key 分布,使剩余 token 的注意力重新分布,如果不考虑这种 shift,基于注意力的剪枝方法容易出现决策失真,保留重复 token 或遗漏关键 token,从而影响推理性能。(Zhang et al., 2024b, “[CLS] Attention Is All You Need for Training-Free Visual Token Pruning: Make VLM Inference Faster”; Wen et al., 2025a, “Token Pruning in Multimodal Large Language Models: Are We Solving the Right Problem?”)

⚠ attention-based 方法不兼容高效实现如 FlashAttention

此外,attention-based 方法还依赖显式 attention 权重,不兼容高效实现如 FlashAttention。

因为attention-based 方法依赖显式的 attention 权重矩阵(attention map) 来评估每个视觉 token 的重要性,需要在推理过程中访问完整的 Softmax(QKᵀ) 结果或其中的行向量。但高效注意力实现(如 FlashAttention)的核心思想正是避免显式构建与存储整个注意力矩阵

FlashAttention 将注意力计算分块(block-wise)执行,通过在 GPU 的高速寄存器和片上 SRAM 中即时计算 Softmax(QKᵀ)V,并在每个块结束后立刻丢弃中间的 QKᵀ 结果与注意力权重,仅保留最终输出。这种方法极大降低了显存读写和带宽占用,使得注意力计算的复杂度从内存瓶颈(memory-bound)变为计算受限(compute-bound),从而实现高效推理。

然而,这种实现方式带来三个关键后果,使 attention-based pruning 与 FlashAttention 不兼容

  1. 不可访问性:FlashAttention 不显式存储或返回完整的 attention map,因此无法直接提取每个 token 的注意力权重;而 attention-based 方法恰恰需要这些分数来判断保留与删除。
  2. 存储与性能冲突:若强行修改 FlashAttention 以输出 attention map,就必须重新显式计算并缓存 QKᵀ 和 Softmax 结果,这会破坏其内存复用机制,重新引入大规模内存访问与显存占用,性能急剧下降。
  3. 多头与层次不稳定性:不同层、不同头的注意力权重分布差异显著,且 FlashAttention 内部按块累积 Softmax,会进一步导致 attention 值在不同分块间不可直接比较,增加了基于 attention 值进行统一排序和剪枝的难度。

因此,在采用 FlashAttention 或其他高效注意力优化(如 xFormers、PagedAttention)的现代 MLLM 中,attention-based pruning 方法无法直接使用或会破坏推理加速效果

因此,作者提出从 “beyond attention or similarity” 的角度重新思考 token pruning。


Contributions

  1. 提出 CDPruner:一种 plug-and-play、model-agnostic 的视觉 token 剪枝方案,通过最大化条件多样性(conditional diversity)实现高效动态剪枝;
  2. 将 token pruning 问题重构为 DPP(Determinantal Point Process),联合考虑 feature similarity 与 instruction relevance;
  3. 在多种视觉语言基准上实验验证,CDPruner 在不同压缩率下均取得 SOTA。

Methods

Determinantal Point Process (DPP)

DPP 最初用于刻画费米子系统的“排斥效应”,在机器学习中被广泛用于建模集合选择的全局多样性

与 Max-Min Diversity Problem (MMDP) 仅关注极端样本不同,DPP 强调全局平衡和代表性。传统 DPP 仅考虑样本间相似度,而本论文在此基础上引入 instruction relevance,使剪枝过程同时考虑“token 相关性”与“多样性”。


DPP with Token Similarity

核心思想: 将视觉 token 的 pairwise similarity 建模为一个核矩阵 ( L ),并通过最大化其行列式(determinant)来选择最具代表性的子集。

定义每个视觉 token 的特征向量为 ( $H^v_i \in \mathbb{R}^d$ ),则相似核矩阵为:

$$ L_{ij} = \frac{H^v_i \cdot H^v_j}{|H^v_i| , |H^v_j|} $$

目标是选择一个包含 ( m ) 个 token 的子集 ( $S \subset Z$ ),使得:

$$ S^* = \arg\max_{S \subset Z, |S|=m} \det(L_S) $$

这里 ( $L_S$ ) 是对应子集的子矩阵。行列式越大,代表该子集在特征空间中“覆盖的方向”越多,信息冗余越低。

直观解释: 如果两个 token 的特征非常相似(线性相关),行列式会减小。因此 DPP 天然倾向保留“互补”信息而非重复 token,从而实现高效且全局均衡的剪枝。


Instruction Relevance

传统 DPP 仅基于视觉特征构建核矩阵,无法体现视觉 token 与文本指令的相关性。CDPruner 通过以下步骤引入 条件相关性(conditional relevance)

  1. 获取视觉 token 表示 ( $H_v \in \mathbb{R}^{n \times d}$ );

  2. 获取文本嵌入(instruction embedding) ( $\bar{H}_q \in \mathbb{R}^d $),其来源可以是:

    • CLIP-like text encoder(若模型具备双编码结构);
    • 或通过 multimodal projector 与 LLM 的指令 token 平均表示。
  3. 计算每个视觉 token 与 instruction 的余弦相似度:

    $$ r_i = \frac{H^v_i \cdot \bar{H}_q}{|H^v_i| , |\bar{H}_q|} $$

  4. 对相关性进行 min–max 归一化:

    $$ \tilde{r}_i = \frac{r_i - \min(r)}{\max(r) - \min(r)} $$

得到的向量 ( $\tilde{r} \in [0,1]^n$ ) 反映了各视觉 token 对当前指令的重要程度。

直观理解:

这使得剪枝可以根据用户问题动态调整保留区域。例如,同一张图片在不同问题下,关注区域(即高 ( $\tilde{r}_i$ ))完全不同。论文在 Figure 3 中展示了这种可视化结果。


CDPruner

核心机制:条件 DPP(Conditional DPP)

为了联合考虑视觉特征多样性与指令相关性,CDPruner 构建了条件核矩阵:

$$ \tilde{L} = \operatorname{diag}(\tilde{r}) , L , \operatorname{diag}(\tilde{r}) $$

DPP 的目标变为最大化该核矩阵的行列式:

$$ S^* = \arg\max_{S \subset Z, |S| = m} \det(\tilde{L}_S) $$

通过对数形式可以分解为:

$$ \log\det(\tilde{L}S) = \sum{i \in S} \log(\tilde{r}_i^2) + \log\det(L_S) $$

这清楚地表明,CDPruner 同时优化两项:

  • relevance term(指令相关性)
  • diversity term(全局多样性)

MAP 推断与高效近似

DPP 的 MAP inference 是 NP-hard 的,因此论文采用了 贪心近似(Fast Greedy MAP) 算法:

  1. 初始化空集合 ( S = \emptyset );
  2. 每次迭代选取能最大化当前增益(即行列式增加量)的 token;
  3. 使用 Cholesky 分解快速计算增益,从而实现高效近似。

该算法的时间复杂度为:

$$ O(nm^2) $$

在实践中,当保留 token 数 ( m \ll n ) 时,额外延迟仅约 <10ms/sample,可直接应用于 MLLM 推理流程中。


可调权重(平衡项)

论文还提出可选平衡因子 ( \theta ),用于控制相关性与多样性之间的权重:

$$ \log\det(\tilde{L}S) = \theta \sum{i \in S} \tilde{r}_i + (1 - \theta) \log\det(L_S) $$

实验显示在不同任务上最优的 ( $\theta$ ) 不同,为模型提供了更好的灵活性。


Implementation & Model-agnostic 特性

  • 对不同架构均可用(LLaVA、Qwen2.5-VL 等);
  • 若模型具备 CLIP-like 双编码结构,直接使用 text encoder 的输出;
  • 否则通过 multimodal projector + LLM 的 instruction token 平均得到文本表示;
  • 完全 training-free,可直接嵌入推理流程作为模块。

Evaluations

Main Results

在多种视觉语言基准上,CDPruner 在不同 token 削减率下均超过已有方法。


CDPruner for High-resolution Inputs

对高分辨率输入(如 LLaVA-NeXT-7B, 2880→320 tokens),CDPruner 显著减少 FLOPs 与延迟,同时保持性能。


CDPruner for Video Understanding & Advanced Architectures

在视频任务与先进架构(如 Qwen2.5-VL)上,CDPruner 仍能在不同压缩率下取得一致提升,验证了方法的通用性与可迁移性。


Efficiency Analysis & Ablation Study

论文的效率实验(以 LLaVA-NeXT-7B 为例)显示:

  • FLOPs 减少约 10×;
  • Prefill latencyDecode latency 显著下降;
  • GPU memoryKV cache size 大幅降低;
  • 仅在极高剪枝率下性能略有下降。

Ablation 研究表明:

  • 同时考虑 conditional relevance 与 DPP 多样性带来最佳性能;
  • 仅使用 attention 或 similarity 的变体均低于 CDPruner。

Conclusions

本文提出了一个 训练无关、模型无关 的视觉 token 剪枝方法 CDPruner,通过定义基于指令的条件相似性,并以 DPP 形式最大化选中 token 的条件多样性,实现了推理加速与性能保持的统一。

CDPruner 在多种 MLLM 架构(如 LLaVA 系列与 Qwen2.5-VL)上实现 SOTA 性能,同时显著降低延迟与显存使用,展现出在真实应用中部署多模态大模型的潜力。