Sampling 系统设计:从 logits 到可控生成

前面几篇文章里,我们已经把 mini vLLM 的核心推理链路搭了起来。

请求会被抽象成 Sequence。Scheduler 决定每一轮哪些请求做 prefill,哪些请求做 decode。Block Manager 负责管理 KV cache blocks。PagedAttention 根据 block table 读取历史 KV。ModelRunner 执行模型 forward,最后得到 logits。

到这里,模型已经完成了它的工作。

但推理还没有结束。

因为 logits 不是最终输出。

logits 只是模型对词表中每个 token 给出的原始分数。我们还需要根据这些分数,选择下一个 token。

这个过程就是 sampling。

在最早的版本里,我们已经写过一个简单 Sampler。它支持 greedy、temperature、top k 和 top p。那时它只是 generate 函数里的一个小工具。

但到了推理服务阶段,Sampling 不再只是几行代码。

它会影响 API 参数设计,会影响 Sequence 状态,会影响 batch 内不同请求的处理方式,会影响流式输出,也会影响结果是否可复现。

这一篇我们要把 Sampler 从一个小函数,扩展成一个真正的 Sampling 系统。

目标不是把所有采样技巧一次性实现完,而是建立一个清晰的结构:从 logits 进入 Sampler,到 token 被选出,再到 Sequence 更新状态,这条链路应该怎么设计。

为什么不能把 sampling 写死在 generate 里

在单请求 demo 里,sampling 很容易写在 generate 循环里。

logits = outputs.logits[:, -1, :]
next_token = sample(logits, temperature, top_p)

这样写没有问题。

但在推理引擎里,这种写法会很快失控。

因为一个 batch 中的不同请求可能使用不同的采样参数。

请求 A 可能是 greedy decoding。

请求 B 可能使用 temperature 0.8。

请求 C 可能使用 top p 0.95。

请求 D 可能设置了 stop words。

请求 E 可能要求返回多个候选结果。

如果 sampling 逻辑散落在 Engine 或 ModelRunner 里,后续很难维护。

更合理的设计是把 sampling 做成独立模块。

ModelRunner 只负责输出 logits。

Sampler 负责根据每个 Sequence 的 SamplingParams 选出 token。

Sequence 负责接收 token,并判断是否结束。

也就是说,职责应该是:

ModelRunner: 产生 logits
Sampler: 从 logits 中选择 token
Sequence: 追加 token 并更新状态
Engine: 组织整个流程

这个边界非常重要。

因为 attention backend、KV cache backend、Scheduler 策略都不应该关心采样细节。

采样策略也不应该反过来影响模型执行结构。

一个好的推理系统,应该能在不改 ModelRunner 的情况下增加新的 sampling 方法。

SamplingParams 是用户意图的结构化表达

Sampler 的输入不只是 logits,还包括每个请求的 SamplingParams。

SamplingParams 可以理解成用户对生成行为的控制。

最常见的字段包括:

class SamplingParams:
    max_tokens: int
    temperature: float
    top_p: float
    top_k: int
    repetition_penalty: float
    stop_token_ids: list[int]
    stop_words: list[str]
    seed: int | None

这些字段背后代表不同的控制维度。

max_tokens 控制最多生成多少 token。它不是采样策略本身,但会影响停止条件。

temperature 控制分布的尖锐程度。值越低,输出越确定;值越高,随机性越强。

top_k 限制只从概率最高的 k 个 token 中采样。

top_p 限制只从累计概率达到 p 的候选集合中采样。

repetition_penalty 用来降低重复 token 的概率,减少模型反复输出相同内容。

stop_token_idsstop_words 用来控制生成何时停止。

seed 用于随机采样的可复现性。

把这些字段集中到 SamplingParams 里,有两个好处。

第一,Sequence 可以持有自己的采样配置。

这样 batch 内不同请求可以有不同策略。

第二,API 层和引擎层之间有清晰边界。

用户请求中的 JSON 参数会被解析成 SamplingParams,后续系统只处理这个结构,而不是到处传散乱字段。

从 logits 到概率分布

模型输出的 logits 形状通常是:

[batch_size, vocab_size]

每一行对应一个 Sequence。

每一列对应词表中的一个 token。

logits 不是概率,它只是未归一化的分数。

如果要做随机采样,需要先经过 softmax:

probs = softmax(logits)

然后从这个概率分布中抽取 token。

但真实采样通常不会直接对原始 logits softmax。

它会先经过一系列 logits processors。

例如:

temperature scaling
repetition penalty
top k filtering
top p filtering
bad words filtering
min tokens filtering

这些处理会修改 logits,然后再 softmax。

所以 Sampler 的内部可以拆成两步:

LogitsProcessor: 修改 logits
TokenSelector: 从处理后的 logits 中选择 token

这种拆分很有用。

因为很多采样规则本质上不是“选择 token”,而是在选择之前改变候选 token 的分数或屏蔽某些 token。

例如 top k 和 top p 是过滤候选集合。

repetition penalty 是调整已经出现过 token 的分数。

stop token 的某些控制也可能通过屏蔽 token 实现。

把这些逻辑都写进一个大 sample 函数会很乱。

更清晰的做法是让多个 processor 依次作用在 logits 上。

for processor in processors:
    logits = processor(sequence, logits)

最后再交给 selector。

Greedy 和 Sampling 是两种不同模式

最简单的生成方式是 greedy。

它直接选择 logits 最大的 token。

next_token = argmax(logits)

greedy 的优点是确定、稳定、容易调试。

在做系统正确性测试时,greedy 非常有用。

因为它没有随机性。只要模型、输入、cache 和 logits 一样,输出就应该完全一样。

这对验证 PagedAttention 很重要。

如果使用随机 sampling,输出不同不一定说明系统错了,可能只是随机性导致的。

所以在 mini vLLM 的测试里,建议优先用 greedy 对齐连续 KV cache 和 paged KV cache。

但 greedy 也有缺点。

它容易输出保守结果,也更容易陷入重复。对于开放式写作、聊天、创意生成,适当的随机性通常更好。

随机 sampling 的基本流程是:

处理 logits
softmax 得到概率
按概率抽样

这里 temperature、top k、top p 都是为了控制这个随机过程。

所以 Sampler 可以先根据 SamplingParams 判断模式。

temperature == 0:
  greedy

temperature > 0:
  sampling

这个判断看起来很小,但很重要。

因为 greedy 不需要 softmax,也不需要随机数。它可以更快,也更容易复现。

temperature 不是简单的“创造力旋钮”

temperature 经常被解释成“创造力”,但这个说法有点粗糙。

更准确地说,temperature 改变的是概率分布的尖锐程度。

处理方式是:

logits = logits / temperature

当 temperature 小于 1 时,高分 token 和低分 token 的差距会被放大,分布更尖锐。

当 temperature 大于 1 时,差距会被压缩,分布更平滑。

如果 temperature 非常接近 0,最高分 token 几乎一定会被选中,效果接近 greedy。

这带来的系统设计问题是:temperature 为 0 时,不应该真的做除法。

因为除以 0 没意义。

所以一般把 temperature 0 视为 greedy 模式。

temperature = 0:
  argmax

temperature > 0:
  logits / temperature
  sample

这也提醒我们,SamplingParams 需要做参数校验。

用户传入负数 temperature,top_p 大于 1,top_k 小于 0,这些都应该在 API 层或 SamplingParams 初始化时处理掉,而不是留到模型执行中报错。

top k 和 top p 的不同思想

top k 和 top p 都是在采样前过滤候选 token,但它们的思想不同。

top k 是固定数量。

如果 top k 等于 50,每一步只从 logits 最高的 50 个 token 中采样。

它简单、稳定,但不够自适应。

有些时候模型非常确定,前几个 token 的概率已经占了绝大部分。保留 50 个可能太多。

有些时候模型不确定,合理候选很多。只保留 50 个可能太少。

top p 是按累计概率过滤,也叫 nucleus sampling。

它会先按概率从高到低排序,然后保留累计概率达到 p 的那部分 token。

如果 top p 是 0.9,那么每一步保留的是能够覆盖 90% 概率质量的候选集合。

这让候选集合大小可以随上下文变化。

模型很确定时,候选集合小。

模型不确定时,候选集合大。

所以 top p 通常更灵活。

在实现上,top k 可以在 logits 空间直接做。

top p 通常需要排序并计算 softmax 后的累计概率。

这意味着 top p 的计算成本更高。

在小 batch、普通服务场景里,这点成本通常可以接受。

但在大规模 serving 中,sampling 也可能成为瓶颈,尤其是 vocab size 很大、batch 很大时。

所以采样系统也需要关注性能,而不是把所有注意力都放在 attention 上。

repetition penalty 为什么属于状态相关处理

repetition penalty 和 temperature、top k、top p 不一样。

temperature 只看当前 logits。

top k 和 top p 也只看当前 logits 分布。

但 repetition penalty 需要知道当前 Sequence 之前生成过哪些 token。

也就是说,它是状态相关的。

它会根据 generated_token_ids 或完整上下文 token ids 来调整 logits。

一种常见思路是:对于已经出现过的 token,降低它们再次被选中的概率。

概念上可以理解为:

如果 token 已经出现过:
  调整它的 logit

这里要注意,repetition penalty 的实现细节有不同版本。不同框架可能处理正负 logits 的方式不同。所以在 mini vLLM 里,第一版不需要追求覆盖所有细节,可以先实现一个简单版本。

真正重要的是把它放在正确的位置。

它应该是一个 LogitsProcessor。

它的输入是当前 Sequence 和当前 logits。

它读取 Sequence 的历史 tokens,然后修改 logits。

class RepetitionPenaltyProcessor:
    def __call__(self, sequence, logits):
        token_ids = sequence.get_token_ids()
        apply_penalty(logits, token_ids)
        return logits

这体现了一个重要原则:

凡是依赖生成历史的采样规则,都应该能访问 Sequence 状态

这也是为什么 Sampler 的输入不应该只有 logits,还应该有 Sequence 或至少有每个请求的历史 token 信息。

stop token 和 stop words 的位置

停止条件是 sampling 系统里很容易混乱的一部分。

有些停止条件是 token 级别的。

例如 EOS token、stop token ids、max tokens。

这些可以在每次采样出 token 后立刻判断。

sample next token
append token
check eos
check max tokens
check stop token ids

有些停止条件是文本级别的。

例如 stop words。

用户可能设置 stop words 为:

["\n\nUser:", "</answer>"]

这些不是单个 token 能稳定表示的。它们可能跨多个 token。

所以 stop words 通常要在 detokenization 之后判断。

也就是说,Sequence 追加 token 后,Detokenizer 生成新的文本片段,然后检查完整输出文本是否包含 stop words。

如果命中,就设置 stop_reason。

这里会出现一个细节:stop words 本身要不要返回给用户?

有些 API 会截断 stop words,不把它们返回。

有些实现会返回已经生成的内容,再由用户处理。

mini vLLM 需要选择一种语义,并保持一致。

我建议第一版采用比较常见的方式:stop words 用于停止,但最终输出不包含 stop words。

这意味着 detokenizer 或输出层要能在命中 stop words 时截断文本。

这个逻辑不应该散落在 Sampler 里。

Sampler 只负责 token 选择。

停止条件可以放在 Sequence 的状态更新中,但文本级 stop words 需要 Detokenizer 或 OutputProcessor 配合。

所以更清晰的链路是:

Sampler 选择 token
Sequence 追加 token
Detokenizer 生成增量文本
StopChecker 检查 token 级和文本级停止条件
Sequence 更新 finished 状态

这样结构会比把所有逻辑都塞进 sample 函数清楚得多。

batch 内不同请求有不同 SamplingParams

推理服务里,batch 是 Scheduler 动态组出来的。

它不会保证同一个 batch 内所有请求拥有相同 sampling 参数。

这会让 Sampler 比单请求版本复杂。

假设 batch size 是 4。

Sequence A: greedy
Sequence B: temperature 0.7, top_p 0.9
Sequence C: temperature 1.0, top_k 50
Sequence D: repetition_penalty 1.2

ModelRunner 输出的是一个 logits batch:

[4, vocab_size]

Sampler 需要按行处理。

每一行使用对应 Sequence 的 SamplingParams。

最简单的实现是循环每个 Sequence。

next_tokens = []

for sequence, logits in zip(sequences, logits_batch):
    token = sample_one(sequence, logits)
    next_tokens.append(token)

这个版本清楚,但可能不够快。

更高性能的实现会把具有相同参数的请求分组,或者用向量化方式批量处理部分 processor。

但第一版不需要过度优化。

这里更重要的是保证抽象正确。

Sampler 的输入应该是:

sequences
logits_batch

而不是一个全局 SamplingParams。

因为每个 Sequence 都可能有自己的参数。

这也是 serving 系统和单请求脚本的区别。

随机数和可复现性

只要涉及随机 sampling,就会涉及随机数。

对于服务系统来说,可复现性是一个值得认真对待的问题。

用户可能希望设置 seed,让相同 prompt 和相同参数得到相同输出。

调试时,我们也希望复现某个问题请求。

如果 Sampler 使用全局随机数生成器,那么多个请求在同一个 batch 中的执行顺序可能影响结果。

比如请求 A 和请求 B 同时采样。它们谁先消耗随机数,会影响另一个请求拿到的随机数。

这会让结果和调度顺序耦合。

更合理的做法是为每个 Sequence 维护自己的随机数状态。

Sequence A has generator A
Sequence B has generator B

这样每个请求的随机过程相对独立。

当然,GPU 上高性能采样的随机数管理会更复杂。第一版 mini vLLM 可以先在 PyTorch 层为每个请求创建 generator,确保语义清晰。

如果不支持 seed,也应该明确说明输出不保证复现。

但从系统设计角度看,SamplingParams 中保留 seed 字段是有价值的。

因为采样是生成行为的一部分,而生成行为最好可以被记录和复现。

多输出序列对 Sampling 的影响

很多 LLM API 支持一个 prompt 返回多个结果。

例如:

n = 4

这意味着同一个 Request 会生成 4 条不同 Sequence。

这些 Sequence 共享同一个 prompt,但采样过程不同。

这会带来几个问题。

第一,Request 和 Sequence 不再是一对一。

这就是我们前面提到 SequenceGroup 的原因。

第二,prefill 可能可以共享。

同一个 prompt 的多个输出序列,prompt KV cache 是一样的。理论上可以只 prefill 一次,然后复制或共享初始 KV cache。

第三,sampling 必须独立。

每个输出序列应该有自己的随机过程,否则多个结果可能完全一样。

第四,停止条件分别判断。

4 条 Sequence 可能在不同时间结束。

所以 Sampling 系统天然会推动我们使用 SequenceGroup。

第一版 mini vLLM 可以先不完整实现 n > 1,但文章里应该说明这个方向。

这也是为什么一开始把 Request 和 Sequence 区分开是有意义的。

一个 Request 是用户请求。

一个 Sequence 是一条具体生成路径。

Sampling 决定每条路径如何继续展开。

Sampler 和 Beam Search 的关系

到目前为止,我们讨论的是常见 sampling,不是 beam search。

beam search 和 sampling 的思路不同。

sampling 是从概率分布中选下一个 token。

beam search 是维护多个候选序列,每一步扩展并保留分数最高的若干条路径。

如果未来要支持 beam search,Sampler 就不再只是为每个 Sequence 选一个 token。

它可能要为一个 SequenceGroup 生成多个候选 token,并更新候选序列的分数。

这会影响 Scheduler、SequenceGroup、KV cache 共享和释放。

所以第一版 mini vLLM 不建议一开始支持 beam search。

但是 Sampler 的设计最好不要把路堵死。

比如不要假设每个请求永远只返回一个 token。

可以让 Sampler 输出结构稍微通用一点。

class SamplerOutput:
    sequence_id: str
    token_id: int
    logprob: float | None

如果未来支持 beam search,可以扩展成多个候选。

当然,第一版只需要每个 Sequence 一个 token。

这里的重点不是过度设计,而是知道哪些假设未来可能被打破。

logprobs 要不要支持

有些 API 会返回每个输出 token 的 logprob。

这对评估、调试、rerank 都很有用。

如果要支持 logprobs,Sampler 就不能只返回 token id。

它还需要返回被选中 token 的 log probability。

甚至可能需要返回 top logprobs。

例如:

选中 token 的 logprob
候选 top k token 及其 logprob

这会增加采样系统的输出复杂度。

对于 mini vLLM 第一版,可以先只保存 selected token 的 logprob,top logprobs 后面再加。

需要注意的是,logprob 应该基于处理后的分布还是原始模型分布?

不同 API 可能有不同语义。

一般来说,如果用户设置了 temperature、top p 等采样参数,返回的 logprob 通常需要明确是基于哪一个分布计算的。

第一版可以不展开这个复杂度,但要在设计上让 SamplerOutput 能承载 logprob。

这会让系统后续更容易扩展。

Sampling 在整个 Engine.step 中的位置

现在我们把 Sampling 放回 Engine.step 里。

一轮 step 大概是:

Scheduler 选出 prefill 和 decode sequences
ModelRunner 执行 forward
得到 logits
Sampler 根据每个 Sequence 的 SamplingParams 采样 token
Sequence 追加 token
StopChecker 判断是否结束
Detokenizer 生成增量文本
Engine 返回 RequestOutput

这里最重要的是顺序。

先 forward,再 sampling。

先 append token,再判断 token 级停止。

再 detokenize,再判断文本级停止。

最后生成 RequestOutput。

如果请求结束,Scheduler 后续不再调度它,Block Manager 释放它的 KV cache blocks。

这条链路把前面所有模块串起来了。

Sampling 不再是一个孤立函数,而是请求生命周期中的一个关键节点。

它决定 Sequence 下一步走向,也决定用户看到什么文本。

小结

这一篇我们把 Sampler 从一个简单函数,扩展成了推理服务中的 Sampling 系统。

Sampling 的输入是模型 logits 和每个 Sequence 的 SamplingParams。它的输出不是最终文本,而是下一步要追加到 Sequence 的 token。

一个清晰的 Sampling 系统应该区分几类职责:

LogitsProcessor 负责修改 logits
TokenSelector 负责选择 token
Sequence 负责维护生成历史
StopChecker 负责判断结束条件
Detokenizer 负责把 token 变成增量文本

temperature、top k、top p 主要控制当前概率分布。

repetition penalty 依赖 Sequence 历史状态。

stop token ids 是 token 级停止条件。

stop words 是文本级停止条件,需要和 detokenization 配合。

batch 内不同请求可以拥有不同 SamplingParams,所以 Sampler 不能依赖一个全局采样配置。随机 sampling 还要考虑 seed 和每个 Sequence 独立随机数状态,避免生成结果和调度顺序强耦合。

到这里,mini vLLM 的输出控制能力已经基本成型。

下一篇文章,我们会讨论 prefix caching。

很多线上请求会共享相同的系统提示词、模板或长文档前缀。如果每次都重新 prefill,就会浪费大量计算。Prefix caching 的目标是复用已经计算过的前缀 KV cache,让相同前缀的请求更快进入生成阶段。