从正确到更快:PagedAttention 的优化路线

上一篇文章里,我们实现了一个教学版 PagedAttention。

那个版本的目标不是快,而是正确和清晰。我们让 Sequence 通过 block table 找到自己的 physical KV cache blocks,然后把历史 K 和 V 按照逻辑 token 顺序 gather 出来,再用普通 attention 公式计算输出。

这个实现非常适合学习。

因为它把 PagedAttention 最核心的映射关系暴露了出来:

logical token index
  -> logical block id
  -> physical block id
  -> block offset
  -> K/V cache address

只要这条链路正确,PagedAttention 在语义上就和连续 KV cache 的普通 attention 等价。

但这个教学版也有一个显而易见的问题:它很慢。

如果我们真的在推理服务里为每个 Sequence 先 gather 出完整历史 K/V,再调用普通 attention,那么就会多做很多内存拷贝,也会产生大量中间张量。batch 里每个请求的上下文长度不同,还会引入很多 Python 层循环和不规则张量处理。

所以这一篇要讨论的问题是:

当 PagedAttention 已经正确之后,如何一步步把它变快?

这篇不会直接实现完整 CUDA kernel。我们先从系统视角看清楚:教学版慢在哪里,真正的高性能 attention kernel 要优化什么,Triton、FlashAttention、FlashInfer 这类方案分别适合解决什么问题,以及 mini vLLM 应该采用什么样的演进路线。

先明确优化目标:我们到底在优化什么

很多人一提到优化 attention,第一反应是“用更快的 kernel”。

这个方向当然没错,但如果没有先明确优化目标,就容易把问题想得太窄。

在 LLM 推理系统里,attention 优化至少包含三个层次。

第一层是计算更快。

也就是让 Q、K、V 的矩阵计算、softmax、value 聚合更高效。

第二层是访存更省。

decode 阶段需要频繁读取历史 KV cache。上下文越长,读取越多。此时瓶颈经常不只是算力,而是显存带宽和 cache layout。

第三层是系统更稳。

在线 serving 里,请求长度不同,batch 每轮都变,KV cache 分布在不同 blocks 中。一个 attention 实现如果只能处理整齐的固定形状,在真实调度场景里就不够用。

PagedAttention 的优化重点主要在第二层和第三层。

它不只是让 attention 算得更快,而是让大量动态请求的 KV cache 能够被高效访问。

这也是为什么我们要从 block manager 开始,而不是一上来写 kernel。

因为如果没有 block table,没有 slot mapping,没有 context length,没有物理 KV cache pool,那么所谓的 PagedAttention kernel 就没有输入基础。

现在我们有了这些结构,才可以讨论如何优化。

教学版为什么慢

上一篇文章的教学版大概是这样的思路:

遍历每个 Sequence
根据 block table 读取历史 K/V
把历史 K/V gather 成连续张量
再调用普通 attention

这个实现的问题主要有四个。

第一个问题是显式 gather 带来的额外内存拷贝。

历史 K/V 本来已经在 physical KV cache pool 里。如果我们先把它们复制到一个新的连续张量里,再做 attention,就会多一次读写。对于长上下文来说,这个额外开销非常明显。

更糟糕的是,decode 每一轮都要做一次。请求越多、上下文越长,这个额外拷贝越不可接受。

第二个问题是 Python 循环。

教学版通常会在 Python 层遍历 batch 中的每个 Sequence,再遍历每个历史 token。这样的实现可读性很好,但性能很差。

GPU 擅长大规模并行。Python 循环会让大量工作停留在解释器层面,无法充分利用 GPU。

第三个问题是中间张量太多。

如果每个 Sequence 都构造自己的连续 K/V,再做 attention,就会产生很多临时张量。这些张量不仅占显存,还会增加内存分配和释放的开销。

在线推理服务非常讨厌频繁的小张量分配。它会让性能抖动,也会增加显存碎片。

第四个问题是不规则 shape 难以高效处理。

batch 里不同请求的 context length 不同。

有的请求历史长度是 128。

有的是 4096。

有的是 16000。

如果每个请求都单独构造 attention,就很难形成高效的大 batch 计算。

所以教学版的价值在于正确性,而不是性能。

下一步优化的方向很明确:不要显式 gather,不要在 Python 里逐个 token 循环,不要创建大量中间张量,而是让 GPU kernel 直接基于 block table 读取 KV cache 并完成 attention。

高性能 PagedAttention 的核心思想

高性能 PagedAttention 的理想形态是:在 attention kernel 内部直接完成 block table 寻址。

也就是说,kernel 不再先把 K/V gather 出来,而是在计算 attention score 时直接读取对应的 physical block。

对于每个历史 token,kernel 内部做这样的映射:

token_index
  -> logical_block_id
  -> physical_block_id
  -> block_offset
  -> load K/V

然后直接参与 attention 计算。

这样做的好处是显而易见的。

第一,少了一次显式 gather。

K/V 不需要先拷贝到连续中间张量。

第二,减少中间显存占用。

历史 K/V 直接从 cache pool 读取,不再额外保存一份逻辑连续副本。

第三,寻址和计算可以融合。

读取 K、计算 score、softmax 累积、读取 V、聚合 output,这些步骤可以放进一个或少数几个 kernel 中完成。

第四,更适合动态 batch。

每个 Sequence 可以有自己的 block table 和 context length。kernel 根据这些元信息处理每个请求,而不是强迫所有请求拥有相同长度的连续 cache。

这就是 PagedAttention 从“正确版本”走向“高性能版本”的关键。

但这里也带来了新的问题:直接从不连续 blocks 读取 K/V,会让内存访问更复杂。

普通连续 attention 的访问模式比较规则。

PagedAttention 的访问路径多了 block table lookup。

如果 kernel 写得不好,不连续访问会降低显存带宽利用率。

所以高性能 PagedAttention 的难点就在于:既要保留分页式 KV cache 的灵活性,又要尽量让 GPU 访问模式高效。

decode attention 和 prefill attention 不一样

讨论优化时,一定要区分 prefill 和 decode。

prefill 阶段一次处理完整 prompt。它的形态更像传统的 causal attention。输入长度可能很长,Q、K、V 都来自当前 prompt。这个阶段更适合使用 FlashAttention 这类针对大矩阵 attention 优化的方案。

decode 阶段每个 Sequence 通常只有一个 query token,但它要访问很长的历史 KV cache。batch 中可能有很多 Sequence,每个 Sequence 的 context length 不同,block table 也不同。

所以 decode attention 的形态是:

many queries
每个请求一个 query
每个 query 对应一段不同长度的历史 KV
历史 KV 分布在不同 physical blocks 中

这和 prefill 的大矩阵 attention 很不一样。

prefill 关注的是如何高效处理一大段 prompt。

decode 关注的是如何为大量请求高效读取历史 cache,并生成下一个 token。

因此,mini vLLM 后续可以采用两条不同路径:

prefill path:
  优先使用普通 attention 或 FlashAttention 类方案

decode path:
  使用 PagedAttention 读取 block 化 KV cache

这也是为什么前几篇一直强调 run_prefill 和 run_decode 要分开。

它们不只是调度阶段不同,底层 attention backend 也可能不同。

如果系统结构一开始没有拆开,后面想单独优化 decode 会非常痛苦。

attention kernel 真正在做什么

为了理解优化,我们先把 decode attention 的计算过程拆开。

对于一个 Sequence 当前的 query,它需要看到 context_len 个历史 token。

每个历史 token 都有 K 和 V。

计算过程可以概括为两遍。

第一遍,计算 attention score,并求 softmax 的归一化信息。

score_i = dot(query, key_i) / sqrt(head_dim)

然后对所有 score 做 softmax。

第二遍,用 softmax 权重聚合 value。

output = sum(weight_i * value_i)

朴素实现会先把所有 scores 存下来,再 softmax,再聚合 V。

高性能实现通常不会这么简单。

因为如果 context_len 很长,把所有 score 都写到显存再读回来,会产生大量内存流量。

更好的方式是做分块计算和在线 softmax。

也就是说,kernel 每次处理一段历史 KV,维护当前最大值和归一化因子,逐块更新 softmax 结果。这样可以避免保存完整 score 矩阵。

这个思想和 FlashAttention 的核心方向相似:减少中间结果的显存读写,把更多工作留在片上存储和寄存器里完成。

对于 PagedAttention 来说,还要额外处理一件事:每一段历史 KV 可能来自不同 physical block。

所以 kernel 在处理某段 token 时,还要根据 block table 找到对应 block。

这就是 PagedAttention kernel 的复杂性来源。

它要同时做三件事:

根据 block table 找 K/V
执行 attention score 和 softmax
聚合 V 得到 output

如果这三件事分开做,代码好写但慢。

如果融合起来做,性能好但实现复杂。

Triton 适合做什么

对于 mini vLLM 这样的教学项目,Triton 是一个很适合的中间阶段。

它比纯 PyTorch 更接近 GPU kernel,可以让我们控制 block、program、mask 和内存访问。

同时它又比手写 CUDA 更容易上手。

使用 Triton 实现 PagedAttention,可以帮助我们逐步理解高性能 kernel 的几个关键点。

比如,如何把一个 head 或一组 head 映射到一个 Triton program。

如何让一个 program 处理一段历史 tokens。

如何根据 block table 计算 physical KV cache 的地址。

如何处理 context length 不是 block size 整数倍的情况。

如何在 kernel 内部做 mask,避免读取超出有效上下文的 token。

如何做分块 softmax,避免保存完整 score。

这些都是从教学版走向高性能版必须面对的问题。

不过,Triton 版本也不应该一上来追求极致。

第一版 Triton kernel 可以只支持比较简单的形态。

例如:

decode only
batch 内每个请求一个 query
固定 head_dim
固定 block_size
不支持复杂量化
不支持多 GPU

先把这个版本跑通,再逐步加能力。

这样读者能看到优化是如何演进的,而不是直接面对一个复杂到无法理解的最终 kernel。

FlashAttention、FlashInfer 和工程复用

如果目标是生产性能,完全从零写所有 kernel 通常不是最优选择。

更现实的做法是理解原理,同时复用成熟 backend。

FlashAttention 主要解决的是高效 attention 计算,尤其是减少 attention 中间矩阵带来的显存读写。在 prefill 阶段,它非常有价值,因为 prefill 是大段序列的 causal attention。

FlashInfer 更偏向推理场景,提供了面向 LLM serving 的高性能 attention 和采样相关能力。它关注的问题更接近在线推理中的 decode、paged KV cache、动态 batch 等场景。

对于 mini vLLM 系列来说,我们不需要一开始绑定某个 backend。更好的写法是保留 backend 抽象。

例如 ModelRunner 不直接写死某个 attention 实现,而是通过 AttentionBackend 调用:

attention_backend.forward_decode(
    query,
    kv_cache,
    block_tables,
    context_lens,
    slot_mapping,
)

这样系统可以有多个实现:

PyTorchPagedAttentionBackend:
  用于教学和正确性验证

TritonPagedAttentionBackend:
  用于学习 kernel 优化

FlashAttentionBackend:
  用于 prefill 或连续 attention 场景

FlashInferBackend:
  用于更接近生产的 decode 推理

这个抽象的意义很大。

它让 mini vLLM 的上层结构保持稳定。

Scheduler 不关心底层 attention 怎么算。

Block Manager 不关心 kernel 怎么加载 K/V。

Sequence 不关心用的是 PyTorch、Triton 还是别的 backend。

只有 ModelRunner 和 AttentionBackend 之间需要适配。

这就是前面一直强调模块边界的原因。

好的系统设计不是每个模块都很复杂,而是复杂度出现在应该出现的位置。

从 PyTorch 到 Triton 的演进路径

对于这个系列,我建议采用四阶段优化路线。

第一阶段是 PyTorch 教学版。

它先 gather,再 attention。这个版本慢,但最好读,也最适合验证 block table 和 slot mapping 的正确性。

第二阶段是 PyTorch 向量化版本。

尽量减少 Python 循环,把部分 gather 和 attention 用张量操作组织起来。它仍然不会特别快,但可以帮助读者理解 batch 化和形状处理。

第三阶段是 Triton decode kernel。

把 block table lookup、K/V 读取、score 计算、softmax、V 聚合放到 GPU kernel 中。这个阶段重点学习 kernel 如何处理不连续 KV cache。

第四阶段是接入成熟 backend。

当我们已经理解原理后,可以接入 FlashAttention 或 FlashInfer 这类方案,用于更接近实际工程的性能表现。

这条路线的好处是,每一步都有意义。

PyTorch 版本负责正确性。

向量化版本负责理解 batch。

Triton 版本负责理解 kernel。

成熟 backend 负责工程性能。

如果直接从成熟 backend 开始,读者很难知道里面到底解决了什么问题。

如果一直停留在 PyTorch 版本,又无法体现推理系统的性能价值。

这四步结合起来,比较适合技术博客系列。

性能优化不能只看 tokens per second

做 attention 优化时,很容易只看吞吐,比如 tokens per second。

但这不是唯一指标。

在线推理系统至少要同时看几类指标。

第一类是吞吐。

单位时间生成多少 output tokens。这能反映系统整体产能。

第二类是首 token 延迟。

用户从发出请求到看到第一个 token,需要多久。这主要受排队和 prefill 影响。

第三类是每 token 延迟。

流式输出过程中,相邻 token 间隔多久。这主要受 decode 调度和 decode attention 性能影响。

第四类是显存利用率。

同样一张 GPU,能同时承载多少请求,KV cache 浪费多少,free blocks 是否足够。

第五类是尾部延迟。

P90、P99 延迟往往比平均值更能反映服务体验。一个调度策略平均表现不错,但如果长请求经常拖垮部分用户,它仍然不适合线上服务。

attention kernel 优化主要会改善 decode 阶段的每 token 延迟和整体吞吐。

Paged KV cache 管理则会改善显存利用率和可承载并发数。

Scheduler 策略会影响首 token 延迟、尾部延迟和输出流畅度。

所以 benchmark 时要把这些指标分开看。

否则很容易出现一个问题:某个优化让总吞吐变高了,但首 token 延迟变差了,用户体验反而下降。

这也是系统优化和单点算子优化的区别。

算子优化追求局部更快。

系统优化要看整体体验。

实现高性能版本前要先守住正确性

优化 PagedAttention 时,最容易出现的问题不是程序报错,而是生成结果悄悄变错。

因为 shape 可能都对,kernel 也能跑,但读取的 KV cache 位置错了。

一旦 block table、slot mapping、context length 有任何不一致,模型就会看到错误的历史上下文。

这种错误会表现为生成质量下降、长文本变乱、随机出现重复或不合逻辑内容。

所以优化之前一定要有正确性测试。

最基本的测试是和连续 KV cache 对齐。

同样的 query,同样的 K/V 内容,一份放在连续 cache,一份放在 paged cache。两个 attention 输出必须接近。

接着测试不同 context length。

比如 1、block_size 减 1、block_size、block_size 加 1、多个 block、很长上下文。

这些边界最容易出错。

还要测试 batch 中不同请求拥有不同 block table 和不同 context length。

因为 PagedAttention 的价值就在于动态 batch。如果只测试整齐输入,很多 bug 暴露不出来。

最后再接真实模型测试。

真实模型测试可以观察生成结果是否和连续 KV cache 版本基本一致。采样模式下有随机性,可以先用 greedy decoding 方便对齐。

这类测试不是附属品,而是优化的前提。

没有这些测试,后面写 Triton 或 CUDA kernel 时会非常痛苦。

mini vLLM 在这一阶段应该做到什么

这一篇之后,mini vLLM 不一定马上拥有生产级 attention 性能。

但它应该形成清晰的优化方向。

当前阶段可以做到:

有一个 PyTorch 教学版 PagedAttention
有连续 KV cache 对齐测试
有 block table 和 slot mapping 的边界测试
有 AttentionBackend 抽象
ModelRunner 可以切换不同 backend
Scheduler 和 Block Manager 不依赖具体 backend

这已经足够重要。

因为从这一刻开始,系统的结构是可演进的。

以后无论是手写 Triton,接入 FlashAttention,还是接入 FlashInfer,都不需要重写整个 Engine。

只需要替换 attention backend。

这就是工程抽象的价值。

一个好的 mini vLLM 不应该只是把一堆功能硬编码在一起,而应该让核心路径可以逐步替换。

先正确,再快。

先清晰,再复杂。

先抽象边界,再接高性能实现。

这是整个系列应该坚持的路线。

小结

这一篇我们从教学版 PagedAttention 走向了性能优化视角。

教学版实现先 gather 再 attention,适合理解和验证,但不适合高性能推理。它慢的主要原因是显式拷贝历史 K/V、Python 循环多、中间张量多,并且很难高效处理动态 batch。

高性能 PagedAttention 的方向是:在 kernel 内部直接根据 block table 读取 physical KV cache,把寻址、score 计算、softmax 和 value 聚合尽量融合起来,避免显式 gather 和大量中间结果。

同时,我们也看到 prefill 和 decode 的优化路径不同。prefill 更适合使用大矩阵 attention 优化,例如 FlashAttention 类方案。decode 更关注动态 batch、paged KV cache 和不规则上下文访问,因此需要专门的 PagedAttention decode kernel 或推理后端。

对 mini vLLM 来说,最合理的路线不是一上来写复杂 CUDA,而是分阶段演进:

PyTorch 教学版保证正确
PyTorch 向量化版本理解 batch
Triton 版本学习 kernel 优化
成熟 backend 提供工程性能

更重要的是,我们应该抽象出 AttentionBackend,让 Scheduler、Block Manager、Sequence 和 Engine 不依赖具体 attention 实现。这样,系统才能随着后端能力逐步演进。

下一篇文章,我们会暂时从 attention kernel 回到生成结果本身,开始实现 Sampling 系统。

到目前为止,我们一直把 sampler 当成一个简单模块。但真正的推理服务需要支持 temperature、top p、top k、repetition penalty、stop words、多个返回序列,以及更复杂的输出控制。

Sampling 看起来只是从 logits 里选 token,但它会直接影响模型行为、服务接口和请求状态管理。