实现教学版 PagedAttention:从 block table 读取历史 KV cache

前面两篇文章,我们已经把问题推进到了 vLLM 最核心的部分。

第九篇里,我们解释了为什么 KV cache 需要分页式管理。在线 LLM serving 中,请求会不断进入、增长和结束。如果每个请求都持有一整段连续 cache,显存浪费会非常严重。如果按需申请连续空间,又会遇到碎片化。PagedAttention 的核心思想是:让 Sequence 在逻辑上拥有连续上下文,但物理上使用不连续的 KV cache blocks。

第十篇里,我们实现了 Block Manager。它负责维护全局 free blocks,为 prefill 分配 block,为 decode 追加 slot,在请求结束后释放 block。到这一步,我们已经有了 block table,也有了 physical KV cache pool 的概念。

但是,还有一个关键问题没有解决:

模型做 attention 时,怎么根据 block table 找到正确的 K 和 V?

这就是这一篇要讲的内容。

我们会实现一个教学版 PagedAttention。它不会追求高性能,也不会一上来写 CUDA kernel。它的目标只有一个:把 PagedAttention 的寻址过程讲清楚,让读者真正理解从 logical token index 到 physical KV cache address 的映射。

高性能可以以后再做,但正确性必须先看懂。

PagedAttention 改变的是 KV cache 的访问方式

先强调一个很重要的点。

PagedAttention 没有改变 attention 的数学含义。

对于一个正在 decode 的 token,它仍然会拿自己的 Q 去和历史所有 K 做匹配,然后用 softmax 得到 attention weights,再用这些 weights 加权求和历史 V。

概念上还是:

scores = Q @ K.T
weights = softmax(scores)
output = weights @ V

变化发生在 K 和 V 从哪里来。

普通 attention 假设历史 K 和 V 是连续存放的。

K cache:
token 0, token 1, token 2, token 3, ...

V cache:
token 0, token 1, token 2, token 3, ...

所以访问第 i 个历史 token 的 K 和 V,直接读第 i 个位置就行。

PagedAttention 下,历史 K 和 V 被切成多个 physical block。Sequence 的逻辑上下文仍然是连续的,但它对应的 physical block 可能是离散的。

例如 block size 是 4,某个 Sequence 的 block table 是:

[7, 2, 9]

这表示:

logical block 0 -> physical block 7
logical block 1 -> physical block 2
logical block 2 -> physical block 9

那么这个 Sequence 的逻辑 token 到物理位置的映射就是:

token 0 -> physical block 7, offset 0
token 1 -> physical block 7, offset 1
token 2 -> physical block 7, offset 2
token 3 -> physical block 7, offset 3

token 4 -> physical block 2, offset 0
token 5 -> physical block 2, offset 1
token 6 -> physical block 2, offset 2
token 7 -> physical block 2, offset 3

token 8 -> physical block 9, offset 0
...

所以 PagedAttention 的第一步不是计算 attention,而是做地址映射。

它要把逻辑上的历史 token 序列,还原成 attention 可以使用的 K 和 V。

从逻辑 token 到物理 block

PagedAttention 的核心寻址关系非常简单。

给定一个历史 token 的逻辑位置 token_index,以及 block size,可以先算出它属于哪个 logical block:

logical_block_id = token_index // block_size

再算出它在 block 内部的偏移:

block_offset = token_index % block_size

然后通过 block table 找到对应的 physical block:

physical_block_id = block_table[logical_block_id]

最终,K 和 V 的物理位置就是:

physical_block_id, block_offset

这条映射链路就是 PagedAttention 的核心:

logical token index
  -> logical block id
  -> physical block id
  -> block offset
  -> K/V cache address

只要这条链路是正确的,PagedAttention 在语义上就能得到和连续 KV cache 一样的结果。

区别只是底层存储不再连续。

这也是我们实现教学版 PagedAttention 的目标。

先不要考虑 kernel fusion,也不要考虑 shared memory、warp、memory coalescing 这些底层优化。先保证逻辑 token 序列能正确读回历史 K 和 V。

先定义一个简单的 KV cache 布局

为了方便理解,我们先定义一个简单的 physical KV cache pool。

假设我们只关注单层 attention,先忽略多层结构。

可以把 key cache 和 value cache 看成两个张量:

key_cache:
[num_blocks, block_size, num_heads, head_dim]

value_cache:
[num_blocks, block_size, num_heads, head_dim]

其中:

num_blocks 是全局 physical block 数量
block_size 是每个 block 能存多少 token
num_heads 是 attention head 数
head_dim 是每个 head 的维度

在真实模型里,每一层都需要自己的 K 和 V,所以通常还会多一个 layer 维度。

例如可以组织成:

[num_layers, num_blocks, block_size, num_heads, head_dim]

或者把每层的 cache 分开存。

教学版本里,为了减少干扰,我们可以先只讲单层。多层只是对每一层重复同样的过程。

现在,假设某个 Sequence 的 block table 是:

[7, 2, 9]

如果我们要读取逻辑 token 5 的 key,那么:

logical_block_id = 5 // 4 = 1
block_offset = 5 % 4 = 1
physical_block_id = block_table[1] = 2

所以它对应:

key_cache[2, 1]
value_cache[2, 1]

这个位置里保存的就是逻辑 token 5 的 K 和 V。

如果 block table 没错,Block Manager 写入 slot 没错,那么这里读出来的值就应该等价于连续 cache 里第 5 个位置的值。

教学版实现的思路:先 gather,再 attention

高性能 PagedAttention 不会真的把所有历史 K 和 V gather 成一个连续张量再算 attention。那样会增加额外拷贝,效率不高。

但教学版可以这么做。

因为我们的目标不是最高性能,而是让逻辑更清楚。

对于一个 Sequence,我们可以先根据 block table 把它的历史 K 和 V 收集出来,恢复成逻辑上连续的形式:

logical K:
[token 0, token 1, token 2, ..., token context_len - 1]

logical V:
[token 0, token 1, token 2, ..., token context_len - 1]

然后再用普通 attention 公式计算。

这个过程可以拆成两步。

第一步,paged gather。

根据 block table 从 physical KV cache 中取出历史 K 和 V

第二步,standard attention。

用当前 query 对 gather 后的 K 和 V 做普通 attention

这种实现虽然慢,但非常适合验证正确性。

我们甚至可以写两个版本做对比。

一个版本使用普通连续 KV cache。

一个版本使用 paged KV cache,先 gather 再 attention。

只要输入相同、cache 内容相同、mask 相同,两者输出应该一致。

这就是教学版 PagedAttention 最重要的验证方式。

单个 Sequence 的 paged gather

先考虑最简单情况:batch size 等于 1,只处理一个 Sequence。

输入包括:

query: 当前 token 的 Q
key_cache: 全局 physical key cache
value_cache: 全局 physical value cache
block_table: 当前 Sequence 的 block table
context_len: 当前可见历史长度
block_size: 每个 block 的 token 数

我们要做的是,把逻辑位置从 0 到 context_len 减 1 的 K 和 V 读出来。

伪代码大概是:

keys = []
values = []

for token_index in range(context_len):
    logical_block_id = token_index // block_size
    block_offset = token_index % block_size

    physical_block_id = block_table[logical_block_id]

    k = key_cache[physical_block_id, block_offset]
    v = value_cache[physical_block_id, block_offset]

    keys.append(k)
    values.append(v)

最后把 keys 和 values 拼成连续张量:

keys:   [context_len, num_heads, head_dim]
values: [context_len, num_heads, head_dim]

这个过程就是最直观的 paged gather。

它把分散在 physical blocks 中的 K 和 V,按照逻辑 token 顺序重新收集起来。

接下来,就可以做普通 attention。

从 gather 后的 K/V 做 attention

假设当前 decode token 的 query 形状是:

query: [num_heads, head_dim]

gather 后的 keys 和 values 是:

keys:   [context_len, num_heads, head_dim]
values: [context_len, num_heads, head_dim]

为了方便计算,可以把维度调整成按 head 组织:

query:  [num_heads, head_dim]
keys:   [num_heads, context_len, head_dim]
values: [num_heads, context_len, head_dim]

然后对每个 head 做 attention。

scores = query @ keys.T
weights = softmax(scores)
output = weights @ values

输出形状是:

[num_heads, head_dim]

最后再把多个 head 合并回 hidden size。

这个过程和普通 attention 没有本质区别。

区别只在于 keys 和 values 是通过 block table gather 出来的。

所以可以这样理解:

PagedAttention = paged KV lookup + standard attention

高性能实现会把这两个阶段融合起来,不显式生成中间连续 K/V。

但教学版本可以先把它拆开,这样最容易看懂。

扩展到 batch:每个 Sequence 都有自己的 block table

真实推理里,decode batch 通常会同时处理多个 running Sequence。

每个 Sequence 都有自己的 block table,也有自己的 context length。

例如:

Sequence A:
  block_table = [7, 2, 9]
  context_len = 10

Sequence B:
  block_table = [4, 1]
  context_len = 6

Sequence C:
  block_table = [8, 5, 3, 6]
  context_len = 15

decode batch 中每个 Sequence 都输入一个 token,每个 token 都有自己的 query。

queries:
[batch_size, num_heads, head_dim]

PagedAttention 需要分别为每个 Sequence 读取自己的历史 K 和 V。

伪代码会变成:

outputs = []

for seq_idx in range(batch_size):
    query = queries[seq_idx]
    block_table = block_tables[seq_idx]
    context_len = context_lens[seq_idx]

    keys, values = gather_kv(
        key_cache,
        value_cache,
        block_table,
        context_len,
    )

    output = attention(query, keys, values)
    outputs.append(output)

最后 outputs 拼成:

[batch_size, num_heads, head_dim]

这个实现非常直观,但性能很差。

原因是 Python 循环很多,gather 很慢,而且每个 Sequence 的 context length 不同,难以形成高效的大矩阵计算。

但作为教学版,它足够清楚。

它能帮助我们验证 block table 和 cache 写入逻辑是否正确。

只有当这个版本正确,我们才有资格去写更复杂的 Triton 或 CUDA kernel。

attention mask 在哪里

在 decode 阶段,如果 context_len 表示当前 token 可以看到的历史长度,那么 gather 出来的 K 和 V 本身就已经只包含可见 token。

此时不一定需要额外的 causal mask。

因为当前 token只对过去的 token 做 attention,不存在未来 token。

但在 prefill 阶段,情况不同。

prefill 一次性处理完整 prompt。prompt 内部每个 token 都只能看自己和之前的 token,不能看未来 token。这时需要 causal mask。

教学版 PagedAttention 可以先主要关注 decode。

原因是 PagedAttention 在 vLLM 中最核心的场景就是 decode 阶段高效访问历史 KV cache。

prefill 阶段也可以使用 block 化 cache 写入,但计算完整 prompt attention 时,很多实现会采用不同路径,例如普通 attention 或其他优化 backend。

所以这一篇我们主要实现 decode 侧的 PagedAttention。

这不是说 prefill 不重要,而是为了让主线更清晰。

prefill 负责把 prompt tokens 的 K 和 V 写入 blocks。

decode 负责通过 block table 读取这些 blocks,并生成后续 token。

写入 slot 和读取 block table 必须一致

这里有一个极其重要的正确性条件:

Block Manager 写入 cache 的位置,必须和 PagedAttention 读取 cache 的位置完全一致。

decode 时,Block Manager 会为当前输入 token 分配一个 slot。

physical_block_id, block_offset

ModelRunner 执行当前 token 的 forward 后,会把这个 token 的 K 和 V 写入这个 slot。

下一轮 decode 时,PagedAttention 会把这个 token 当作历史 token 读取出来。

如果写入时使用的 token index 和读取时使用的 token index 不一致,模型就会读错历史。

这类 bug 很难排查。

因为程序可能不会报 shape 错误,也不会报显存错误。它只是生成结果变差,或者在长输出时逐渐跑偏。

所以实现时必须明确两件事。

第一,Sequence 的 cached_token_count 表示已经写入 KV cache 的 token 数。

第二,Sequence 的 generated_token_ids 表示已经采样出来、可以返回给用户的 token 数。

这两个数并不总是同步。

为什么?

因为 prefill 处理 prompt 后,会采样出第一个 generated token。但这个 generated token 的 K 和 V 还没有写入 cache。它会作为下一轮 decode 的输入。只有下一轮 decode 执行完,它的 K 和 V 才被写入 cache。

所以一种清晰的语义是:

prefill:
  输入 prompt tokens
  写入 prompt tokens 的 KV
  采样 token_1
  generated_token_ids = [token_1]
  cached_token_count = prompt_len

decode 第 1 轮:
  输入 token_1
  写入 token_1 的 KV
  采样 token_2
  generated_token_ids = [token_1, token_2]
  cached_token_count = prompt_len + 1

decode 第 2 轮:
  输入 token_2
  写入 token_2 的 KV
  采样 token_3
  generated_token_ids = [token_1, token_2, token_3]
  cached_token_count = prompt_len + 2

这里的 cached_token_count 是 PagedAttention 读取历史 KV 的重要依据。

decode 第 1 轮输入 token_1 时,它应该能看到 prompt tokens 的 KV。

decode 第 2 轮输入 token_2 时,它应该能看到 prompt tokens 加 token_1 的 KV。

如果 cached_token_count 提前加了,attention 可能读到还没写入的 token_2 的位置。

如果 cached_token_count 晚加了,attention 可能看不到 token_1。

这就是为什么状态语义比代码本身更重要。

和连续 KV cache 对齐验证

实现教学版 PagedAttention 时,我建议一定要做一个对齐测试。

思路是这样的:

先构造一份连续 KV cache。

contiguous_key_cache:
[context_len, num_heads, head_dim]

contiguous_value_cache:
[context_len, num_heads, head_dim]

再构造一份 paged KV cache。

把同样的 K 和 V 按照 block table 写入 physical blocks。

然后用同一个 query 分别计算两种 attention。

连续版本:

output_contiguous = attention(query, contiguous_k, contiguous_v)

Paged 版本:

paged_k, paged_v = gather_by_block_table(...)
output_paged = attention(query, paged_k, paged_v)

最后比较:

output_contiguous ≈ output_paged

如果这一步不一致,说明问题一定出在 block table、slot 写入、gather 顺序或 context_len 上。

这个测试非常重要。

因为 PagedAttention 的数学结果应该和普通 attention 一致。

它只是改变 KV cache 存储和访问方式,不应该改变模型输出。

在教学实现中,先用随机张量验证,再接入真实模型验证。

随机张量测试更容易定位问题,因为没有 tokenizer、采样、模型层数这些干扰。

为什么教学版会慢

教学版 PagedAttention 先 gather 再 attention,清晰但慢。

慢在哪里?

第一,它会显式构造连续 K/V。

这相当于多做了一次内存拷贝。

第二,它对每个 Sequence 单独循环。

batch 中每个请求 context_len 不同,Python 层循环会很慢。

第三,它没有利用 GPU 的高效内存访问模式。

真实 kernel 会尽量让线程以合适的方式读取 block,减少随机访问带来的开销。

第四,它没有融合 softmax、value 聚合和 block 读取。

高性能 attention kernel 通常会把多个步骤融合,避免中间结果反复写回显存。

所以教学版不能代表最终性能。

但它有一个巨大优势:可读。

它能让我们清楚地看到 PagedAttention 到底在做什么。

工程学习里,这一步很重要。

如果直接从 CUDA kernel 开始看,很容易被线程块、warp、shared memory、向量化加载这些细节淹没,反而看不清系统设计。

先写一个慢但正确的版本,再逐步优化,是理解复杂系统的更好路径。

从教学版到高性能版的方向

教学版正确之后,优化方向就清楚了。

第一个方向是去掉显式 gather。

不要先把 K/V 收集成连续张量,而是在 attention 计算时直接通过 block table 读取 physical cache。

也就是说,kernel 在遍历历史 token 时,直接完成:

token_index -> block_id -> offset -> load K/V

第二个方向是减少 Python 循环。

batch 内多个 Sequence 的 attention 应该交给 GPU kernel 并行处理,而不是在 Python 里逐个处理。

第三个方向是优化内存访问。

block size、cache layout、head 维度排列都会影响读取效率。高性能实现会尽量让相邻线程读取相邻内存,提高带宽利用率。

第四个方向是融合计算。

softmax、score 计算、value 聚合可以在 kernel 内部融合,减少中间张量的读写。

第五个方向是针对 decode 场景优化。

decode 每个 Sequence 通常只有一个 query,但 context_len 可能很长。这和 prefill 阶段的大矩阵 attention 不一样。kernel 设计需要围绕这个形态优化。

这些优化都很重要,但它们不改变 PagedAttention 的核心抽象。

核心仍然是:

每个 Sequence 通过 block table 访问自己的逻辑上下文

PagedAttention 如何接入 mini vLLM

接入 mini vLLM 后,几个模块的关系会变得更完整。

Block Manager 负责分配 blocks 和 slot。

Sequence 保存 block table、cached token 数和生成状态。

Scheduler 在调度前检查 token budget 和 block budget。

ModelRunner 在执行 decode 时,拿到当前 token、block table、context length、slot mapping 和 physical KV cache pool。

PagedAttention 根据 block table 读取历史 K/V,根据 slot mapping 写入当前 token 的 K/V。

Engine 在采样后更新 Sequence 状态,并在请求结束时释放 block。

这条链路可以概括成:

Scheduler 选择谁运行
Block Manager 决定 cache 写到哪里
ModelRunner 执行模型
PagedAttention 读取历史 KV
Sampler 生成新 token
Sequence 更新状态
Block Manager 回收 finished 请求的 blocks

到这里,mini vLLM 的核心形态就已经非常清楚了。

它不再只是一个调用 past_key_values 的外壳,而是开始拥有自己的 cache 管理和 attention 访问方式。

这正是从“使用模型推理”走向“实现推理系统”的分界线。

小结

这一篇我们实现了教学版 PagedAttention 的核心思想。

PagedAttention 不改变 attention 的数学语义。它改变的是 KV cache 的存储和读取方式。

在普通 attention 中,历史 K 和 V 通常被看作连续数组。

在 PagedAttention 中,Sequence 的上下文逻辑上连续,但物理上分散在多个 KV cache blocks 中。Sequence 通过 block table 记录逻辑 block 到 physical block 的映射。

读取某个历史 token 的 KV 时,需要经过这样的映射链路:

logical token index
  -> logical block id
  -> physical block id
  -> block offset
  -> K/V cache address

教学版实现可以先根据 block table 把历史 K/V gather 成连续张量,再调用普通 attention。这个版本性能不好,但非常适合理解和验证正确性。

实现时最需要注意的是写入 slot 和读取 block table 必须一致,尤其要区分 generated tokens 和已经写入 KV cache 的 cached tokens。很多 PagedAttention 的隐蔽 bug 都来自这个状态语义不清。

到这里,我们已经完成了一个正确性优先的 PagedAttention 版本。

下一篇文章,我们会讨论如何从正确走向更快。

我们会分析教学版为什么慢,高性能 attention kernel 要解决哪些问题,以及 Triton、FlashAttention、FlashInfer 这类 backend 在推理系统中分别扮演什么角色。