KV cache 到底缓存了什么:让 decode 不再重复计算历史 token
1. 这一篇要解决什么问题
上一篇文章里,我们手写了一个最小版本的 generate 函数。
它的核心循环大概是这样的:
for step in range(max_new_tokens):
outputs = model(input_ids=input_ids)
logits = outputs.logits[:, -1, :]
next_token = sampler.sample(logits)
input_ids = append(input_ids, next_token)
这个版本可以工作,但它有一个很明显的问题。
每生成一个新 token,我们都会把完整上下文重新送进模型。
假设 prompt 有 1000 个 token,我们要生成 100 个 token。
第 1 轮,模型处理 1000 个 token。
第 2 轮,模型处理 1001 个 token。
第 3 轮,模型处理 1002 个 token。
一直到第 100 轮,模型处理 1099 个 token。
这意味着大量历史 token 被重复计算了一遍又一遍。
而在自回归生成里,历史 token 对应的很多中间结果其实不会再变化。既然不会变化,就应该缓存起来复用。
这就是 KV cache 要解决的问题。
这一篇我们要回答几个问题:
KV cache 中的 K 和 V 分别是什么
为什么只缓存 K 和 V,而不是缓存 Q
prefill 阶段和 decode 阶段如何使用 KV cache
Hugging Face 里的
past_key_values是什么KV cache 为什么会成为推理系统里的显存管理核心问题
这件事和后面的 vLLM、PagedAttention 有什么关系
这一篇之后,我们的生成逻辑会从“每次重算完整上下文”,变成“prefill 算完整 prompt,decode 每次只算新 token”。
这是 mini vLLM 的第一个关键优化。
2. 先回顾一下 attention 做了什么
在 Transformer 里,每一层都会有 self attention。
给定输入 hidden states,模型会通过线性变换得到三组向量:
Q: query
K: key
V: value
可以粗略理解为:
Q 表示当前 token 想找什么信息
K 表示每个历史 token 能提供什么索引
V 表示每个历史 token 真正携带的内容
attention 的计算可以简化写成:
attention_scores = Q @ K.T
attention_weights = softmax(attention_scores)
output = attention_weights @ V
对于 causal language model 来说,第 t 个 token 只能看到自己和它之前的 token,不能看到未来 token。
因此,在生成第 t 个 token 时,模型需要访问前面所有 token 的 K 和 V。
问题就在这里。
当我们生成新 token 时,历史 token 的 K 和 V 已经在之前算过了,而且它们不会改变。
所以没有必要每一轮都重新计算历史 token 的 K 和 V。
这就是 KV cache 的基本动机。
3. 为什么缓存 K 和 V,而不是缓存 Q
很多人第一次接触 KV cache 时会问:
既然 attention 里有 Q、K、V,为什么只缓存 K 和 V?
原因在于,decode 阶段每一轮只关心“当前新 token”要如何关注历史上下文。
当前新 token 会产生新的 Q、K、V。
其中:
当前 token 的 Q 用来查询历史上下文
当前 token 的 K 会被保存起来,供未来 token 查询
当前 token 的 V 也会被保存起来,供未来 token 读取内容
历史 token 的 Q 没有必要缓存,因为未来 token 不会再用历史 token 的 Q。
未来 token 只需要拿自己的 Q 去和所有历史 K 做匹配,然后用 attention 权重去聚合历史 V。
换句话说:
Q 是一次性的查询
K 和 V 是会被未来反复访问的记忆
所以缓存 K 和 V 就够了。
4. prefill 和 decode 的区别
理解 KV cache,必须先区分两个阶段。
第一个阶段是 prefill。
prefill 处理完整 prompt。
假设 prompt 长度是 1024,那么 prefill 阶段会一次性处理这 1024 个 token,并为每一层生成对应的 K 和 V。
这个阶段的输入形状大致是:
input_ids: [batch_size, prompt_length]
模型输出 logits,同时也会输出每一层的 KV cache。
第二个阶段是 decode。
decode 每一轮只处理一个新 token。
此时输入形状大致是:
input_ids: [batch_size, 1]
但模型内部 attention 不能只看这一个 token。它仍然需要看完整历史上下文。
历史上下文从哪里来?
来自 prefill 阶段和前面 decode 阶段保存下来的 KV cache。
所以 decode 阶段可以理解成:
新 token 负责产生新的 Q、K、V
历史 KV cache 提供过去所有 token 的 K、V
attention 用当前 Q 查询完整历史 K、V
最后把当前 token 的 K、V 追加进 cache
这就是 KV cache 的核心数据流。
5. 不使用 KV cache 的生成过程
先看没有 KV cache 的版本。
每一轮生成时,我们把完整上下文都送进模型:
outputs = model(input_ids=all_token_ids)
logits = outputs.logits[:, -1, :]
next_token = sampler.sample(logits)
all_token_ids = append(all_token_ids, next_token)
这个版本的问题是,每一轮都重复处理历史 token。
可以把它想象成这样:
第 1 轮:处理 prompt
第 2 轮:重新处理 prompt + token_1
第 3 轮:重新处理 prompt + token_1 + token_2
第 4 轮:重新处理 prompt + token_1 + token_2 + token_3
历史部分被反复计算。
如果生成越长,浪费越严重。
6. 使用 KV cache 的生成过程
有了 KV cache 之后,流程变成两段。
首先做 prefill:
outputs = model(
input_ids=prompt_ids,
use_cache=True,
)
past_key_values = outputs.past_key_values
prefill 之后,我们拿到 prompt 对应的 KV cache。
接下来进入 decode。
decode 阶段只输入最新 token,并传入历史 KV cache:
outputs = model(
input_ids=last_token_id,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
注意,decode 时传入的 input_ids 不再是完整上下文,而只是最新 token。
历史上下文则通过 past_key_values 传进去。
这时生成过程变成:
prefill:处理完整 prompt,得到 KV cache
decode 第 1 轮:只处理 token_1,读取历史 cache,追加新 cache
decode 第 2 轮:只处理 token_2,读取历史 cache,追加新 cache
decode 第 3 轮:只处理 token_3,读取历史 cache,追加新 cache
这就是从重复计算转向缓存复用。
7. past_key_values 的形状
不同模型的实现细节会略有差异,但概念上,past_key_values 通常按层组织。
可以把它理解成:
past_key_values = [
layer_0_kv,
layer_1_kv,
layer_2_kv,
...
]
每一层里有 key cache 和 value cache:
layer_kv = (key_cache, value_cache)
典型形状大致是:
key_cache: [batch_size, num_heads, sequence_length, head_dim]
value_cache: [batch_size, num_heads, sequence_length, head_dim]
这里的 sequence_length 表示已经缓存了多少个 token。
在 prefill 之后,sequence_length 等于 prompt 长度。
在每一轮 decode 之后,sequence_length 会增加 1。
例如 prompt 长度是 1024,已经生成了 10 个 token,那么 cache 里的 sequence_length 就是 1034。
这也是为什么长上下文和长输出会显著增加显存占用。
KV cache 不是一个小东西。它会随着请求数量和上下文长度线性增长。
8. KV cache 的显存开销
我们可以粗略估算一下 KV cache 的大小。
每一层都要缓存 K 和 V。
每个 token 的 KV cache 大小大致和下面这些因素有关:
层数
KV head 数
head_dim
数据类型字节数
K 和 V 两份
可以写成一个简化公式:
KV cache memory
≈ num_layers × num_tokens × num_kv_heads × head_dim × 2 × dtype_size
其中 2 表示 K 和 V 两份。
如果是 FP16 或 BF16,dtype_size 通常是 2 字节。
这个公式非常重要。
它告诉我们,KV cache 显存占用会随着 token 数线性增长,也会随着并发请求数线性增长。
对于单请求 demo 来说,这可能不是大问题。
但对于在线服务来说,这会变成核心问题。
假设同时有很多用户请求,每个请求都有自己的 prompt 和生成过程。系统就需要为每个请求维护一份不断增长的 KV cache。
有的请求很短,很快结束。
有的请求很长,持续占用大量 cache。
有的请求提前停止,释放资源。
有的新请求又进来,申请资源。
这就是为什么 vLLM 这类系统要把 KV cache 管理做成核心模块。
9. 只用 past_key_values 还不够
到这里,我们已经可以用 past_key_values 避免重复计算历史 token。
但是,这还不是 vLLM 真正厉害的地方。
因为 Hugging Face 风格的 KV cache 更适合单请求或简单 batch。
对于一个高并发推理系统,我们还需要解决更多问题。
第一个问题是 cache 如何分配。
如果每个请求都提前分配一大块连续 KV cache,那么生成长度不可预测时会浪费大量显存。
第二个问题是 cache 如何释放。
请求随时结束,释放出来的显存可能变成碎片。下一个请求未必能直接复用这些空间。
第三个问题是 cache 如何在 batch 中组织。
不同请求的上下文长度不同,生成进度不同。如果把它们强行 padding 到一样长,会浪费计算和显存。
第四个问题是 cache 如何支持 continuous batching。
每一轮 decode 时,有的请求继续生成,有的请求结束,有的新请求加入。KV cache 必须能动态管理这些变化。
所以,past_key_values 解决的是“如何避免重复计算”。
而 vLLM 的 PagedAttention 进一步解决的是“如何高效管理大量请求的 KV cache”。
这两者不是同一个层级的问题。
我们可以把它们的关系理解成:
KV cache:让单个请求不重复计算历史 token
PagedAttention:让大量请求的 KV cache 更高效地分配、复用和访问
这也是我们后面几篇文章的主线。
10. 从 generate 函数看 prefill 和 decode
为了让概念更清楚,我们把上一篇的生成函数拆成两个阶段。
第一阶段是 prefill:
outputs = model(
input_ids=prompt_ids,
use_cache=True,
)
past_key_values = outputs.past_key_values
logits = outputs.logits[:, -1, :]
next_token = sampler.sample(logits)
第二阶段是 decode:
outputs = model(
input_ids=next_token[:, None],
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
logits = outputs.logits[:, -1, :]
next_token = sampler.sample(logits)
注意两者最大的区别:
prefill 输入完整 prompt
decode 输入最新 token
但两者都会输出新的 past_key_values。
这意味着 cache 是随着生成过程不断增长的。
11. 为什么 prefill 和 decode 要分开看
在推理系统里,prefill 和 decode 的性能特征完全不同。
prefill 阶段一次处理很多 token,矩阵计算规模大,更容易把 GPU 跑满。
decode 阶段每次只处理一个 token,计算粒度小,但要反复访问历史 KV cache,访存压力更明显。
这会带来一个重要结论:
prefill 更偏计算密集
decode 更偏访存密集
正因为它们的性能特征不同,推理系统通常会分别调度它们。
例如:
新请求先进入 waiting 队列
调度器选择一批请求做 prefill
prefill 完成后,请求进入 running 队列
running 请求每轮做 decode
decode 过程中如果有新请求进入,调度器可以决定是否插入新的 prefill
这就是后面 continuous batching 的基础。
如果我们不理解 prefill 和 decode 的区别,就很难理解为什么调度器需要 token budget、为什么长 prompt 会阻塞 decode、为什么 chunked prefill 有意义。
12. KV cache 在请求对象里的位置
现在我们可以回到 mini vLLM 的工程设计。
在上一篇文章里,我们的生成函数只维护了一个 input_ids。
但在真正的推理引擎里,每个请求至少要维护这些信息:
prompt tokens
generated tokens
sampling params
past key values
当前状态
停止原因
也就是说,KV cache 应该成为请求状态的一部分。
一个简化的请求对象可以理解成这样:
class Sequence:
prompt_token_ids: list[int]
generated_token_ids: list[int]
sampling_params: SamplingParams
past_key_values: Any
status: SequenceStatus
这不是最终版本。
后面我们会把 past_key_values 替换成更可控的 KV cache block 表。
但在当前阶段,把 cache 放进 Sequence 里是合理的。因为每个请求都有自己的生成历史,也就有自己的 KV cache。
13. 当前版本的 mini vLLM 会怎么变
上一篇文章里的 simple_generate 是一个完整循环。
现在我们会把它改造成更接近推理引擎的形式。
它会有两个关键动作。
第一个动作是 prefill:
sequence = engine.add_request(prompt)
engine.prefill(sequence)
这个动作会完成:
tokenize prompt
执行模型 forward
生成第一份 KV cache
采样第一个 token
更新 sequence 状态
第二个动作是 decode:
while not sequence.is_finished():
engine.decode(sequence)
这个动作会完成:
读取上一次保存的 KV cache
只输入最新 token
执行模型 forward
更新 KV cache
采样下一个 token
判断是否结束
这看起来只是简单拆分,但它会为后面做多请求调度打下基础。
因为后续我们可以把多个 sequence 的 prefill 放在一起,也可以把多个 running sequence 的 decode 放在一起。
14. 一个容易忽略的问题:第一个 token 从哪里来
很多初学者写 KV cache generate 时,容易在“第一个 decode token”这里混乱。
流程应该是这样的:
prompt 进入 prefill
prefill 输出最后一个位置的 logits
sampler 根据这个 logits 采样出第一个生成 token
第一个生成 token 作为下一轮 decode 的输入
也就是说,第一个生成 token 并不是 decode 算出来的,而是 prefill 的最后一个 logits 采样出来的。
decode 从第二个生成 token 的计算开始。
举个例子:
prompt: A B C
prefill 输入: A B C
prefill 输出: 用 C 位置的 logits 采样出 D
decode 输入: D
decode 输出: 用 D 位置的 logits 采样出 E
decode 输入: E
decode 输出: 用 E 位置的 logits 采样出 F
所以在实现时,我们需要保存“上一次生成的 token”,把它作为下一轮 decode 的输入。
这个细节后面会影响 Sequence 的字段设计。
15. 使用 KV cache 后性能为什么会好
使用 KV cache 后,每一轮 decode 不再重新处理完整上下文。
但这并不意味着 decode 的成本和上下文长度完全无关。
因为当前 token 的 Q 仍然要和历史所有 K 做 attention。
也就是说,decode 阶段仍然需要访问历史 KV cache。
上下文越长,需要访问的 K 和 V 越多。
所以 KV cache 的作用不是让长上下文免费,而是避免重复计算历史 token 的投影和中间状态。
可以这样理解:
没有 KV cache:重复计算历史 token 的 K 和 V
有 KV cache:复用历史 K 和 V,只计算新 token 的 K 和 V
它减少的是重复计算,但不会消除对历史上下文的访问。
这也是为什么 KV cache 的显存布局和访问效率会非常重要。
后面的 PagedAttention,本质上就是在解决这个问题:
当有大量请求和大量 KV cache 时,如何既省显存,又高效访问
16. 当前阶段我们先不做什么
这一篇只做单请求 KV cache。
暂时不做这些事情:
不自己实现 attention kernel
不自己管理 KV cache 内存
不支持多请求 batch
不支持 continuous batching
不实现 block manager
不实现 PagedAttention
我们先使用模型框架已经提供的 past_key_values,把“缓存历史 K 和 V”这件事跑通。
这样做的好处是,我们可以先理解语义,再替换底层实现。
后面进入 block manager 时,我们会把 past_key_values 从一个黑盒对象,逐步改造成自己维护的 cache block。
17. 和上一篇文章的对比
上一篇文章的生成方式是:
每一轮输入完整上下文
模型重复计算历史 token
实现简单,但效率低
这一篇加入 KV cache 后,生成方式变成:
prefill 输入完整 prompt
decode 每轮只输入最新 token
历史 K 和 V 从 cache 中读取
每轮生成后追加新的 K 和 V
这就是第一个版本的推理优化。
从工程角度看,我们也迈出了一步:
simple_generate
变成
prefill + decode
这个变化非常关键。
因为后面调度器并不是调度“完整 generate 函数”,而是调度很多个 sequence 的 prefill 和 decode step。
18. 小结
这一篇我们引入了 KV cache。
它解决的是 LLM 自回归生成中的一个核心问题:不要在每一轮 decode 时重复计算历史 token。
我们重点理解了这些内容:
attention 中的 K 和 V 会被未来 token 反复访问
Q 是当前 token 的查询,不需要缓存历史 Q
prefill 阶段处理完整 prompt,并生成初始 KV cache
decode 阶段每轮只输入最新 token,并复用历史 KV cache
KV cache 会随着上下文长度和并发请求数增长
KV cache 管理会成为高性能推理系统的核心问题
到目前为止,我们已经从一个最朴素的 generate 函数,走到了支持 KV cache 的单请求推理。
下一篇文章,我们会开始做工程抽象。
我们会把 prompt tokens、generated tokens、sampling params、KV cache、status 和 stop reason 封装成一个 Sequence 对象。
这一步看起来不像性能优化,但它是从单请求 demo 走向多请求推理引擎的关键。