Conference: ICLR'25

Github: https://github.com/FYYFU/HeadKV.


My Thoughts

这篇工作和 DuoAttention 的关注点类似,都是关注不同 attention head 对模型不同能力的贡献不同,这篇工作更关注 attention head 对模型 RetrievalReasoning 的 importance score, 来实现 KV Cache 的 nonuniform budget allocation 。

1. Motivation

现代 LLM 越来越支持极长上下文(例如 GPT-4、Llama-3、Qwen-2、Claude 等),但随着输入长度增长,Transformer 的 self-attention 导致 KV cache(attention 的 key/value 状态)占用内存线性增长,成为推理阶段的主要瓶颈。已有工作通过 token eviction / 层级缓存压缩来缓解,但几乎没有研究在“头(head)级别”上对 KV cache 大小进行差异化分配。作者观察到 attention heads 在功能上高度异质(如 retrieval heads、reasoning heads 等),因此提出基于头重要性的 head-level KV cache 压缩方法(HeadKV),并在此基础上提出结合检索与推理能力评估的 HeadKV-R2。


2. Relative Work

2.1 Attention heads

回顾了对多头注意力中 head 功能的研究(Voita et al., Olsson et al., Wu et al., Zheng et al. 等),并指出不同 head 在词法、结构、复制(induction)、检索等方面扮演不同角色。这些观察为按 head 分配 KV cache 提供理论基础。

2.2 KV cache compression

StreamingLLM(attention sink)、Heavy-Hitter/PyramidKV、SnapKV、Ada-KV 等;并指出现有方法通常以 layer 为单位分配预算或在 layer 内部动态分配,但并未完全摆脱 layer 约束(因此难以实现真正的 head-level 压缩)。本工作选择完全基于 head 的重要性分布独立分配缓存预算,以期更精准保留关键信息。


3. Method

3.1 Head-Level Importance Score Estimation

目标:为每个 attention head 计算一个反映其在“检索 + 推理(retrieval + reasoning)”任务中重要性的分数 $S_h$,用于后续按比例分配 KV cache。

动机细化

  • 直接使用 Wu et al. (2024) 的 Needle-in-a-Haystack(精确匹配检索)测试会产生极为稀疏的分布(≈70% 头得分为 0),因为其依赖于 exact-match 的 argmax 规则,过于苛刻,不利于分布式预算分配。
  • 为了同时捕捉 检索能力(能把答案片段 k 从上下文中找到)和 推理能力(需要按给定推理步骤选择正确答案),作者将 needle 改造成三段式 $k=(r,c_1,c_2)$:其中 $r$ 是显式的 reasoning step,$c_1$ 是诱导的错误答案,$c_2$ 是正确答案;模型必须借助 $r$ 来推理并输出 $c_2$。这种构造能激活既用于定位也用于逻辑推理的头。

公式与计算步骤:

论文给出两类评分:原始 Retrieval(Eq.1)与新的 Retrieval-Reasoning(Eq.2)。

Retrieval(Wu et al.)

$$ S_h = \sum_{t=1}^{N} N_t, \quad N_t = \begin{cases} \dfrac{1}{N}, & \text{if } \arg\max(a^t_h) \in k [2mm] 0, & \text{otherwise} \end{cases} \tag{1} $$

其中 $a^t_h$ 是 head $h$ 在 decoding step $t$ 对合并输入(needle + haystack)的 attention 分数向量,$N$ 是 needle 长度。此法依赖 argmax 是否落在 needle 上(即是否 exact match)。


Retrieval-Reasoning(作者提出)

为了不丢失 token 以外的贡献,作者把“只看 argmax”扩展为对 top-i attention 的加权求和,并以正确答案 $c_2$ 的所有 tokens 为关注对象:

$$ S_h = \sum_{t=1}^N \sum_{i=1}^N M^t_i, \quad M^t_i = \begin{cases} \dfrac{a_i}{N}, & \text{if } \text{top-i}(a^t_h)\in c_2 [1mm] 0, & \text{otherwise} \end{cases} \tag{2} $$

其中 $a_i$ 是 head $h$ 在 step $t$ 的第 $i$ 大 attention 值;$\text{top-i}(a^t_h)$ 是对应的 token(第 $i$ 高 attention 指向的 token)。 直观上,若一个 head 在多个 decoding step 对正确答案的 token 给出高 attention,则该 head 得分较高;通过对 $a_i$ 的加权,评分更平滑且能反映“部分注意力”贡献。


实现细节(补充)

  • 在实际实现中,作者将 needle 插入不同位置 $p_i$(不同的 haystack 位置)以保证分布稳健。对每个插入位置跑若干样本,再对同一 head 的得分取平均(论文 pseudo-code 中有读取并平均的实现)。
  • 对 $S_h$ 做 $L_1$-归一化(sum 到 1),得到 head 重要性分布向量 $\mathbf{S}$,以便用于按比例分配动态池预算(见 3.2)。
  • 设计取 top-i 的上限 $i_\text{max}$,以及是否对不同 decoding step 赋不同权重,这些在论文实现中通过实验选择(默认扩大了 i 的数量,相比 Wu et al. 更宽松)。详见附录 pseudo-code。

直观总结:R2(Retrieval-Reasoning)评分把 head 对于**整段正确答案(而非单 token)**的持续注意力都考虑进来,因此评分密度更高、更能区分“负责检索”与“负责推理”的头,并且更适合驱动资源有限时的预算分配。论文可视化也显示 R2 分布更 dense,而原始 retrieval 分布较 sparse。


3.2 Head-Level KV Cache Allocation

核心思想:基于得到的 head 重要性分布 ${S_h}$ 对每个 head 分配不同的 KV cache 大小 $b_h$,同时保留一个「共享动态池」 $B$ 用于把预算从弱头收集并按权重给强头。


公式:

$$ B = \frac{b}{\beta} \cdot L \cdot H, \qquad b_h = \Big(b - \frac{b}{\beta}\Big) + S_h \cdot B \tag{4} $$

解释:

  • $b$:每个 head 的初始固定预算(baseline budget);
  • $\beta$:超参数控制共享池大小;较小的 $\beta$ ⇒ 更大的共享池(更多预算可被重新分配);
  • $L$、$H$:模型的层数与每层 head 数(因此 $L\times H$ 是总 head 数);
  • 最终 $b_h$ 由 “保证的最小预算” $\big(b - b/\beta\big)$ 与按重要性分配的动态预算 $S_h \cdot B$ 之和构成。

实现细节(Pseudo code)

  • 先保留最后 $\alpha$ 个 instruction token(local window),这些 token 在形成动态池前必须保证入池前被保留,用于指导 selection(见 3.3)。论文默认 $\alpha=8$。

  • 论文实现(pseudo-code)中的步骤:

    1. 读取并平均保存好的 head 得分列表,归一化得到 $\mathbf{S}$。
    2. 计算 $$ \text{total_pool_capacity} = \frac{\text{base_capacity}}{\beta} \cdot \text{num_hidden_layers} \cdot \text{num_attention_heads}, \quad \text{min_num} = \text{base_capacity} - \frac{\text{base_capacity}}{\beta} $$ 3.

    $$ \text{head_capacity} = \text{round}(\text{total_attention} \cdot \text{total_pool_capacity} + \text{min_num}) $$ (向最近整数取整得到每个 head 的最终条目数,用于 gather 操作)。


3.3 KV Cache Selection

目标:在确定每个 head 保留的条目数 (b_h) 后,如何从该 head 的历史 key/value 列表中选取具体的索引(tokens)以构成 compressed KV cache。

方法(基于 SnapKV)

  • 保留最后 (\alpha) 个 instruction tokens(local observation window);
  • 计算 local window(query)到每个候选 token(key)的 attention score(通常为点积或 softmax 归一化后的分数);
  • 对每个 head,将这些 attention 分数进行 pooling/聚合(论文中提到 pooling 层用于聚合来自 local window 的 attention),并将得分排序;
  • 选取得分最高的前 (b_h) 个 token 索引作为该 head 的缓存条目。最后将这些被选出的 key/value 与 local window 的最近 tokens 合并(concatenate),构成最终每个 head 的 compressed KV。

Pseudo-PyTorch 样例(论文 Listing 改写并注释要点)

# 伪代码要点(基于论文 Listing)
# 假设:origin_heads_key_states: [num_heads, batch, seq_len, head_dim]
#         origin_heads_value_states: 类似
#         head_capacity[layer_idx][head_idx] 已由 obtain_head_budget 得到

heads_key_states = []
heads_value_states = []

# 1. 计算 local window -> 所有 tokens 的 attention score(同 SnapKV)
attn_score = calc_attn_score(query_states, key_states)  # shape: [num_heads, seq_len]

# 2. 排序得到索引(按得分降序)
_, indices = attn_score.sort(dim=-1, descending=True)

# 3. 对每个 head 按 head_capacity 取 top-k 索引并 gather key/value
for head_idx in range(num_heads):
    k = head_capacity[layer_idx][head_idx]  # 该 head 最终保留数
    cache_index = indices[head_idx, :k]     # 取 top-k 索引
    # 扩展索引到 head_dim 用于 gather
    cache_index = cache_index.view(1,1,-1,1).expand(-1,-1,-1,head_dim)
    top_Kcache = origin_heads_key_states[head_idx].gather(dim=2, index=cache_index)
    top_Vcache = origin_heads_value_states[head_idx].gather(dim=2, index=cache_index)
    # 与 local window 的 last self.window_size tokens 合并
    selected_k = torch.cat([top_Kcache, origin_heads_key_states[head_idx][:,:, -self.window_size:, :]], dim=2)
    selected_v = torch.cat([top_Vcache, origin_heads_value_states[head_idx][:,:, -self.window_size:, :]], dim=2)
    heads_key_states.append(selected_k.view(-1, head_dim))
    heads_value_states.append(selected_v.view(-1, head_dim))

# 最终合并所有 head 的 key/value
heads_key_states = torch.cat(heads_key_states, dim=0)
heads_value_states = torch.cat(heads_value_states, dim=0)

该流程在论文实现(Listing)中给出,实际实现还包含保存与加载 importance distribution、layer-by-layer 的处理与张量对齐细节。

补充说明

  • 合并 local window 的目的是保证最近的指令信息(instruction tokens)一定被保留,这对生成质量和 selection 指导很重要;
  • selection 使用 attention pooling 而不是简单的 token TF/IDF 等启发式,是因为 attention 能反映模型当前上下文相关性的内部信号(即“模型知道自己在看什么”),这也是 SnapKV 的核心思想。

4. Experiments and Analysis

4.1 Experiment settings

  • Backbone Models:Llama-3-8B-Instruct、Mistral-7B-Instruct。

  • Benchmarks / Datasets:LongBench(Single-Doc QA、Multi-Doc QA 类别)与 LooGLE(Long Dependency QA 类别)。论文附录给出各数据集的具体样本与平均上下文长度等(Appendix Table 5)。

  • Baselines

    1. SnapKV:以最后 (\alpha) 个 tokens 指导 selection(attention pooling)。
    2. PyramidKV:按层金字塔分配,更低层给更多缓存。
    3. Ada-KV / Ada-SnapKV:在层内基于 concentration 动态分配(论文采用 Ada-SnapKV 作为最强 baseline)。
  • 统一设置:local window size (\alpha = 8);评估 KV size(保留条目)集合:({64,128,256,512,1024});(\beta) 在 ({1.005,1.01,1.1,1.2,1.5,2,5,10}) 取最优报告;选择 SnapKV 作为 per-head selection 方法(即 HeadKV 在分配之后用 SnapKV 做具体条目挑选)。


4.2 Main results

  • 总体结论:Head-level 分配(HeadKV)在各种 KV size 下普遍优于 layer-level baselines,尤其在低资源(KV size=64 或 128)下收益最显著;而基于 Retrieval-Reasoning(HeadKV-R2)的分布进一步优于只基于检索的分布(HeadKV-R)。在某些设置下(Llama-3-8B, KV=1024),HeadKV-R2 的平均分甚至略超 FullKV(32.95 vs 32.90),说明合理压缩在噪声/冗余信息过多时还能抑制负面影响。表 1 给出完整数值。

补充解释 / 直观解读

  • layer-level 方法(SnapKV、PyramidKV)对所有头一视同仁或按层粗粒度分配,无法在“哪些 head 真正负责找答案或推理”上做微调,因此在 head 数目多、功能异质性大的模型上会浪费宝贵的缓存资源。HeadKV 将预算集中到少数关键头,从而在极小的缓存预算下保存关键信息(论文提出在 1.5% KV 保留下仍能达到 FullKV 的 97% 性能)。

4.3 Retrieval-Reasoning Heads

  • Ablation:作者比较了三种 head 分布:原始 Retrieval(Wu et al. 的 exact-match;HeadKV-R)、Enhanced-Retrieval(保持 retrieval 示例但采用新的得分估计方法;HeadKV-ER)、Retrieval-Reasoning(作者提出的 R2;HeadKV-R2)。结果显示:HeadKV-ER 较 HeadKV-R 有小幅提升(因为更关注整个 needle 而非 argmax),但仍不如 HeadKV-R2(因为 ER 仍缺少推理步骤的激活信号)。表 2 给出具体数值。

深入补充

  • R2 分布的密度更大、有更好的区分度(less zero mass),适合在分配共享池时进行稳定的比例分配;实验显示 R2 在多个任务上都带来一致提升,尤其是在 multi-doc / long dependency 的 QA 场景(这些场景同时需要检索与多步推理)。

4.4 Long-context retrieval and reasoning

作者使用两类专门设计的 stress-tests:

  1. Needle-in-a-Haystack(检索测试):在海量无关文本中插入一个 needle(答案片段),检测是否能检索并直接 paste 出答案(偏向纯检索能力)。结果中 HeadKV(尤其 HeadKV-R2)在 KV=128 时比其他压缩方法表现更好(图 5)。

  2. Reasoning-in-a-Haystack(推理测试):基于 bAbI 风格的 reasoning 问题,将 reasoning-needles 插入到大 haystack 中,模型需先检索到相关 needles,再基于它们进行多步推理(偏向检索+推理)。在这个测试上,HeadKV-R2 的优势更加明显(Table 3),说明 R2 的评分确实捕捉到了推理相关的 head。

定量要点(摘自论文表格与图):

  • 在 Llama-3-8B, KV=128 的 Reasoning-in-a-Haystack 平均分:FullKV 约 57.04,HeadKV-R2 约 56.84(接近 FullKV),而 SnapKV、PyramidKV 等明显落后(见 Table 3)。这表明在受限 KV 下 HeadKV-R2 能显著保留 reasoning 能力。


4.5 Memory & Latency

设置与结论

  • 使用 Mistral-7B-Instruct、最大序列长度 32K、FlashAttention 实现;评估 decoding latency(包含 prefill 与 decoding)与 peak memory(在不同 context len 与 generation len 下)。论文图 6 显示 HeadKV 在解码延迟上与其他压缩方法基本持平,同时在 peak memory 上较 FullKV 有显著降低。换言之:HeadKV 在不增加运行时开销的前提下,达到了更好的性能/记忆率折中。

补充说明

  • decoding latency 的曲线在 generation length=1 时接近,说明 prefill 的额外开销(包括计算 importance scores 的离线开销)被设计为初始化阶段或一次性离线计算(论文中 importance distribution 为静态并在模型初始化时载入),因此对在线推理开销影响微小。Pseudo-code 明确指出 obtain_head_budget 在初始化时运行一次。

G. Hyper-parameter analysis

  • 论文通过扫描 (\beta) 值集合展示了 (\beta) 对最终平均分的影响(Figure 10/11)。总体结论:较小 (\beta)(更大的动态池)能让 HeadKV-R2 进一步获益,说明 R2 分布在“把更多预算动态分配给重要头”时更可靠。论文同时保持 (\alpha) 与其他参数与 PyramidKV 的实现一致以保证公平比较。

5. Conclusion & Future Work

论文结论:提出 HeadKV-R2(head-level KV cache 压缩 + retrieval-reasoning 重要性估计),在 LongBench、LooGLE 等多数据集与两种 backbone 模型上证明了在极限缓存预算下能保存大部分性能,同时保持低内存与延迟开销。