(LLM decoding) beam search

一句话总结 beam search

从同一个 prompt 开始,把输入序列复制成多条 beam,第一步只允许一个 beam 扩展出全部候选 token 并选出 top-k 个作为初始分支;之后每一轮都对所有存活的 beam 进行前向计算,得到每个 beam 的下一个 token 分布,叠加历史分数后在所有候选中选出 top-k 保留,遇到 EOS 的序列交给 beam_scorer 收集,直到所有 beam 完成或达到最大长度,最后输出分数最高的若干完整序列。

记录LLM decoding的beam search 策略在transformers库中的实现

beam search 伪代码

为了搞清楚beam search在transformers库中的过程,把其中设计工程的部分删除了,留下了重要的部分写成伪代码的形式。
transformers\generation\utils.py

function BEAM_SEARCH_CORE(model, input_ids, num_beams, max_len,
                          logits_processor, stopping_criteria,
                          pad_id, eos_id, do_sample=False):

    # 形状约定:
    # input_ids: [B * nb, L],初始常由 [B, L] 扩展/重复到 [B*nb, L]
    B = batch_size_from(input_ids, num_beams)
    V = model.vocab_size
    nb = num_beams
	
    # 1) 初始化 beam 分数:首列 0,其他列 -INF
    beam_scores = zeros([B, nb]); beam_scores[:, 1:] = -INF
    beam_scores = beam_scores.view(B * nb)   # [B*nb]

	-------------------------
	假设 
	batch_size (B) = 1
	beam_size (nb) = 3
	prompt 长度 (L) = 2
	词表大小 (V) = 5
	pad = 0, eos = 4
	input_ids 是 [1, 10]
	
	input_ids =
	[
		 [1, 10],   # beam 0
		 [1, 10],   # beam 1
		 [1, 10],   # beam 2
	]   # shape = [3, 2]
	--------------------------

    past = None
    while not finished:
        # 2) 前向,取最后一步 logits(仅自回归一步输出)
        logits, past = model.forward(input_ids, past=past)    # logits: [B*nb, L, V]
        step_logits = logits[:, -1, :]                        # [B*nb, V]
		--------------------------
		模型给出每个 beam 的下一步 logits(这里举个假设概率):
		next_token_probs (log_softmax 后):
		beam 0: [0.05, 0.10, 0.70, 0.10, 0.05]
		beam 1: ...
		beam 2: ...
		--------------------------
        # 3) 转 log 概率 + 规则化(温度/禁词/重复惩罚等)
        step_scores = log_softmax(step_logits)                # [B*nb, V]
        step_scores = logits_processor(input_ids, step_scores)# [B*nb, V]

        # 4) 加上累计的 beam 分数(序列级)
        step_scores = step_scores + beam_scores[:, None]      # [B*nb, V]

        # 5) reshape 成每个 batch 独立选择 (nb*V) 的全局前 K
        step_scores = step_scores.view(B, nb * V)             # [B, nb*V]

        # 6) 选择候选(采样或确定性 top-k)
        k = max(2, 1 + count_eos(eos_id)) * nb
        if do_sample:
            probs = softmax(step_scores)                      # [B, nb*V]
            next_tokens = multinomial(probs, num_samples=k)   # [B, k](索引范围 0..nb*V-1)
            # 将抽到的候选按分数重排(降序)
            cand_scores = gather(step_scores, next_tokens)
            next_tokens = sort_by_desc(cand_scores, next_tokens)
            next_scores = sort_desc(cand_scores)
        else:
            next_scores, next_tokens = topk(step_scores, k)   # 都是 [B, k]

        # 7) 还原:来源 beam 索引 + 词表索引
        next_indices = floor_div(next_tokens, V)              # [B, k] in [0..nb-1]
        next_tokens  = next_tokens % V                        # [B, k] in [0..V-1]

        # 8) 交给 beam_scorer 进行“存活/完成/排序”裁决(核心)
        #    返回:保留的 nb 个 beam 的新分数/新 token/来自哪个旧 beam
        out = beam_scorer.process(
                input_ids, next_scores, next_tokens, next_indices,
                pad_token_id=pad_id, eos_token_id=eos_id)

        beam_scores      = out.next_beam_scores               # [B*nb]
        beam_next_tokens = out.next_beam_tokens               # [B*nb]
        beam_src_indices = out.next_beam_indices              # [B*nb](全局索引 0..B*nb-1)

        # 9) 重排并拼接 token;缓存也按相同索引重排
        input_ids = concat( input_ids[beam_src_indices], beam_next_tokens[:, None] )
        past      = reorder_cache(past, beam_src_indices)

        # 10) 停止判定(达到 max_len / beam_scorer 判定完成 / 其他条件)
        if beam_scorer.is_done() or stopping_criteria(input_ids) or length(input_ids) >= max_len:
            break

    # 11) 收尾:整合已完成与在生的 beam,输出最终序列与分数
    result = beam_scorer.finalize(
                input_ids, beam_scores,
                pad_token_id=pad_id, eos_token_id=eos_id, max_length=max_len)
    return result.sequences, result.sequence_scores

注意

初始化 beam 分数

在 Hugging Face 的 _beam_search 里,第一次迭代时确实只有第一个 beam 会被真正展开,其他 beam 的分数一开始设为 −1e9(负无穷),保证它们不会被选中。这样做的目的就是避免一开始多个 beam 完全重复 prompt,导致前几个 token 都一样,浪费算力。
流程可以这么理解:

  1. 初始化
    • beam_scores = [0, -1e9, -1e9, ...]
    • 只有 beam 0 是合法候选,其余 beam 被“屏蔽”。
  2. 第一次扩展
    • 从 beam 0 展开出 vocab_size 个 token。
    • 取 top-k(k ≥ num_beams),交给 beam_scorer
    • 这时候 beam_scorer 会选出 num_beams 条路径,分配给 beam 0、beam 1、beam 2。
    • 这样每个 beam 就都拿到了不同的前缀。
  3. 后续步骤
    • 从第二次迭代开始,所有 beam 的分数都是真实值了(不再是 −1e9),所以后续都会被同时展开。
      在这里插入图片描述

规则化(温度/禁词/重复惩罚等)

transformers\generation\utils.py
在这里插入图片描述
transformers\generation\logits_process.py

伪代码中3) 转 log 概率 + 规则化(温度/禁词/重复惩罚等)

在 transformers 库的 generate 方法中,LogitsProcessorList 被用来实现各种复杂的文本生成策略。例如:

  • 控制重复: RepetitionPenaltyLogitsProcessor 会被加入列表,以降低已生成词元的分数。
  • 采样策略: TopKLogitsWarperTopPLogitsWarperTemperatureLogitsWarper 会被加入列表,以实现 Top-K、Top-P (Nucleus) 和温度采样。
  • 强制约束: MinLengthLogitsProcessor 会被加入,以禁止在生成文本达到最小长度前产生 EOS (end-of-sequence) 标记。
    通过将这些不同的处理器组合在一个 LogitsProcessorList 中,transformers 库能够以一种模块化、可扩展的方式,灵活地控制文本生成过程。

采样 vs 确定性:

transformers\generation\utils.py
在这里插入图片描述

伪代码中 6) 选择候选(采样或确定性 top-k)
do_sample=True 时【采样解码】,它对 (batch, num_beams*vocab) 的分数做 softmax 后抽样 n_tokens_to_keep 个,再排序取前 B 个;【贪心解码】否则直接 topk。

Beam Search (束搜索) 算法的核心决策与状态更新阶段

可视化可参考 b站up”五道口纳什“:
【github】https://github.com/chunhuizhang/llm_rl/blob/main/tutorials/search-learn/llm_beam_search.ipynb
【[LLM+RL] model.generate 之 beam search decoding strategy】 https://www.bilibili.com/video/BV1SbrKYkETg/?share_source=copy_web&vd_source=ac07d48defe559c6322dff2c05c6eb10

伪代码中 8) 交给 beam_scorer 进行“存活/完成/排序”裁决(核心) 部分
在这里插入图片描述

BeamScorer
transformers\generation\beam_search.py

  • BeamSearchScorer
  • ConstrainedBeamSearchScorer
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值