BeamSearch的原理和实现

本文详细介绍了BeamSearch的概念,作为一种优化seq2seq模型输出的算法,它介于贪心搜索和全概率计算之间。通过设定束宽(Beamwidth)来平衡搜索效率和结果质量。文章讲解了如何使用堆数据结构实现BeamSearch,通过不断保留概率最高的若干条路径,直至找到最优解。此外,还展示了 Beamwidth=3 的例子以加深理解。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


目标

1.知道beam search的概念和原理
2. 能够在代码中使用Beam search完成预测过程


提示:以下是本篇文章正文内容,下面案例可供参考

一、Beam Search的介绍

在进行模型评估的过程中,每次我们选择概率最大的token id作为输出,那么整个输出的句子的概率就是最大的嘛?
在这里插入图片描述
Beam search有被称为束集搜索,是一种在seq2seq中用来优化输出结果的算法(不在训练过程中使用)。
例如:传统的获取解码器输出的过程中,每次只选择概率最大的那个结果,作为当前时间步的输出,等到输出结束,我们会发现,整个句子可能并不通顺,虽然在每一个时间步上的输出确实是概率最大的,但是整体的概率却不一定最大,我们经常把它叫做greedy search[贪心算法]
为了解决上述的问题,可以考虑计算全部的输出概率乘积,选择最大的那一个,但是这样的话,意味着如果句子很长,候选词很多,那么需要保存的数据就会很大,需要计算的数据量就很大。
那么Beam Search就是介于上述两种方法的一种方法,假设Beam width = 2,表示每次保存的最大的概率的个数,这里每次保存两个,在下一个时间步骤同样保留两个,这样就可以达到约束搜索空间大小的目的,从而提高算法的效率。
beam width = 1时,就是贪心算法,beam width=候选词的时候,就是计算全部的概率。beam width是一个超参数。
比如在下图中:
使用一个树状图来表示每个time step的可能输出,其中的数字表示是条件概率
黄色箭头表示的是一种greed search,概率并不是最大的
如果把beam width设置为2,刚开始的时候保存(A,B),在第二次分类的时候从(AA,AB,BA,BB)中选择两个最大的为(AB,BB),在第三层的中从(ABA,ABB,BBA,BBB)中选择两个最大的为(ABB,BBB)。那么后续就可以找到绿色路径的结果,这个结果是最大的。
在这里插入图片描述
下图在给出beam width=3的例子

  1. 首先输出start token <s>,然后得到四个输出(这里假设一个就是四个输出:x,y,z,</s>),选择概率最大的三个,x,y,w
  2. 然后分别把x,y,z放到下一个time step中作为输入,分别得到三个不同的输出,找到三个输出中概率最大的三个,x,y,y
  3. 继续重复上述步骤,知道获得结束符(概率最大)或者是达到句子的最大长度,那么此时选择概率乘积最大的一个。
  4. 拼接整个路径上概率最大的所有结果,比如这里可能是<s>,y,y,x,w,</s>
    在这里插入图片描述

二、Beam search的实现

在上述描述的思路中,我们需要注意以下几个内容:

  1. 数据该如何保存,每一次的输出的最大的beam width个结果,和之后之前的结果该如何保存
  2. 保存了之后的概率应该如何比较大小,保留下概率最大三个
  3. 不能够仅仅只保存当前概率最大的信息,还需要有当前概率最大的三个中前面路径的输出结果

2.1.数据结构-堆-的认识

对于上面所说的,保留有限个数据,同时需要根据大小来保留,可以使用一种带有优先级的数据结构来实现,这里我们可以使用堆这种数据结构。
是一种优先级的队列,但是它其实并不是队列,我们常说的队列都是先进先出或者先进后出,但是堆只根据优先级的高低来去取出数据。和堆在一起的另外一种数据结构叫做,它有入栈和出栈的操作,可以理解为是一种先进后出的数据结构。
在python自带的模块中,有一个焦作heapq的模块,提供了堆所有的方法。
在这里插入图片描述
可以发现,输出的顺序并不是数据插入的顺序,而是根据其优先级,从小往大pop(False<True)。

2.2. 使用堆来实现beam search

为了实现数据的保存,我们可以把beam search中的数据保存到堆中,同时再往这个堆中添加数据的同时,判断数据的个数,仅仅中保存beam width个数据。
代码如下(示例):
在这里插入图片描述

实现方法,完成模型eval过程中的beam search搜索。
思路:

  1. 构造<sos>开始符号等第一次输入的信息,保存到堆中。
  2. 取出堆中的数据,进行forward step的操作,获得当前时间步的output,hidden
  3. 从output中选择topk(k=beam width)个输出,作为下一次的input
  4. 把下一个时间步需要的输入等数据保存在一个新的堆中
  5. 获取新的堆中的优先级最高(概率最大)的数据,判断数据是否是EOS结尾或者是否达到最大长度,如果是,停止迭代。如果不是,则重新遍历新的堆中的数据。
    在这里插入图片描述
    在这里插入图片描述

<think>好的,我需要回答用户关于beam search算法代码实现的问题。首先,我得回忆一下beam search的基本概念用途。Beam search是一种启发式搜索算法,常用于序列生成任务,比如机器翻译、文本生成等,尤其是在解码阶段。它通过维护一个固定大小的集(称为beam宽度)来平衡计算效率结果质量,避免穷举所有可能的序列。 接下来,用户需要的是代码实现,所以我应该提供一个简单易懂的示例。考虑到用户可能使用Python,选择Python作为示例语言比较合适。我需要确保代码结构清晰,注释明确,方便用户理解每一步的操作。 首先,定义beam search函数的结构。参数方面,需要包括当前序列、当前得分、beam宽度等。然后,生成下一步的可能候选,计算每个候选的得分,并选择得分最高的前k个作为新的beam。这个过程需要循环直到达到终止条件,比如生成结束符号或达到最大长度。 在编写代码时,需要注意处理每一步的扩展得分更新。可能需要使用优先队列或堆结构来维护当前最优的候选序列。同时,要考虑如何记录每一步的路径,避免重复计算。 另外,用户可能对代码中的某些部分不太理解,比如得分函数的实现。这时候需要说明得分函数可以根据具体任务设计,比如语言模型中的概率对数。同时,示例中的占位函数需要明确指出,方便用户替换为自己的实现。 最后,生成相关问题,帮助用户进一步理解beam search的应用优化。例如,调整beam宽度的影响,处理重复的方法,与其他解码策略的比较等。 检查代码是否符合格式要求,数学表达式是否正确使用$...$,独立公式用$$...$$。例如,在解释得分时,可以提到对数概率的,如$\log P(y_t | y_{1:t-1}, x)$,并用$$包裹。同时,确保引用标识正确添加,参考用户提供的引用内容,但可能不需要直接引用,除非有相关内容。 总结步骤:解释算法原理,提供Python代码示例,说明关键部分,生成相关问题,确保格式正确。</think>### Beam Search算法代码实现 Beam search是序列生成任务中常用的解码算法,通过维护固定宽度的候选序列集合来平衡搜索效率与结果质量。算法核心步骤包括: 1. 初始化包含空序列的候选集 2. 每步扩展所有候选序列 3. 保留top-k最高分候选 4. 重复直到达到终止条件 以下Python实现演示了beam search的基本逻辑: ```python import heapq def beam_search_decode(initial_score=0.0, beam_width=3, max_len=10): # 初始化候选堆,使用负分便于最小堆实现最大堆 heap = [ (-initial_score, []) ] for _ in range(max_len): candidates = [] # 扩展当前所有候选 while heap: score, seq = heapq.heappop(heap) current_score = -score # 生成下一步候选(此处为示例,需替换实际生成逻辑) next_tokens = generate_candidates(seq) # 需实现候选生成函数 for token, prob in next_tokens: new_seq = seq + [token] new_score = current_score + log(prob) # 使用对数概率求 # 将新候选加入临时列表 heapq.heappush(candidates, (-new_score, new_seq)) # 保留top-k候选 heap = [] for _ in range(min(beam_width, len(candidates))): if candidates: heapq.heappush(heap, heapq.heappop(candidates)) # 返回最佳结果 best = heapq.heappop(heap) return (-best[0], best[1]) def generate_candidates(current_seq): """示例候选生成函数,需根据具体任务实现 返回格式:[(token, probability), ...]""" # 示例:假设有三个候选,概率为1 return [('A', 0.6), ('B', 0.3), ('<EOS>', 0.1)] ``` 关键组件说明: 1. `heapq`模块实现优先队列,维护当前最佳候选 2. 得分计算使用对数概率求:$$ \sum_{t=1}^T \log P(y_t | y_{1:t-1}, x) $$ 3. `beam_width`控制计算复杂度与结果质量的平衡 相关问题
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值