Beam Search(集束搜索/束搜索)

作者:Fyuocuk
链接:https://www.zhihu.com/question/54356960/answer/293804923
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

首先需要确定一个`Beam Size`,这里设置为2,意思是每个`word`后面的分支考虑概率最大的那两个`words`。比如下面的例子,从下往上首先分成A、B两个words,然后继续往上传播,句子变成是AA/AB/BA/BB这四种情况(绿色虚线)。考虑到`Beam Size=2`,选择概率最大的两个,假设是AB/BA(橙色大箭头)。然后以选择的AB/BA继续向上传播,又出现了四种情况ABA/ABB/BBA/BBB,依然是选择综合概率最大的两个ABB/BBB。以此类推,直至句子结束。只要可以调整好`Beam Size`,就能够使用最小的计算量,得到最优的结果。

 

### 集束搜索 (Beam Search) 算法实现 集束搜索是一种用于优化解码过程的算法,在自然语言处理(NLP)、机器翻译等领域广泛应用。它通过维护一组可能的最佳路径来扩展贪婪搜索的思想,从而提高生成序列的质量。 以下是基于 Python 的简单集束搜索代码示例: ```python import numpy as np def beam_search(scores_fn, start_token, end_token, max_len=10, beam_width=3): """ 使用集束搜索生成序列 参数: scores_fn: 计算给定前缀后的下一个 token 得分函数 输入为当前序列,返回为形状为(vocab_size,)的得分数组 start_token: 起始标记 end_token: 结束标记 max_len: 序列最大长度 beam_width: 束宽 返回: best_sequence: 最优序列及其对应的分数 """ beams = [(start_token, 0)] # 初始束列表,包含起始标记和初始得分为零 sequences = [] # 存储完成的序列 for _ in range(max_len): new_beams = [] for seq, score in beams: if seq[-1] == end_token: # 如果遇到结束标记,则保存该序列并跳过后续计算 sequences.append((seq, score)) continue # 获取下一时刻所有可能token的得分 next_scores = scores_fn(seq) # 将每个候选加入到新束中 for i in range(len(next_scores)): candidate_seq = list(seq) + [i] candidate_score = score + next_scores[i] new_beams.append((candidate_seq, candidate_score)) # 按照得分排序,并截取前beam_width个最佳候选项 sorted_new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width] # 更新束列表 beams = sorted_new_beams # 添加剩余未完成的序列至sequences sequences += [(b[0], b[1]) for b in beams if b[0][-1] != end_token] # 找到最优序列 best_sequence = max(sequences, key=lambda x: x[1]) return best_sequence # 假设的一个打分函数scores_fn模拟器 vocab_size = 5 np.random.seed(42) def mock_scores_fn(sequence): """模拟一个随机打分函数""" return np.random.rand(vocab_size) best_seq, best_score = beam_search(mock_scores_fn, start_token=[0], end_token=4, max_len=8, beam_width=2) print(f"Best Sequence: {best_seq}, Score: {best_score}") ``` 上述代码定义了一个通用的 `beam_search` 函数,其中输入是一个评分函数 `scores_fn` 和其他必要的参数如起始标记、终止标记以及束宽等。此函数会逐步构建序列直到达到指定的最大长度或者找到足够的有效序列为止[^4]。 #### 关键点解释 - **束宽 (`beam_width`) 控制着每一步保留多少条最有可能的路径**,较大的束宽可以提升准确性但也增加了计算成本。 - 当某个序列到达结束标志时会被立即存储起来不再参与进一步扩展。 - 在每一时间步上都会重新评估所有的可能性并将它们按总累积概率降序排列后选取顶部若干项作为新的候选集合[^5]。 ### 注意事项 实际应用中的 `scores_fn` 可能来自复杂的神经网络模型输出层经过 softmax 处理之后的结果向量;因此需要依据具体场景调整相应部分逻辑以适配不同类型的底层架构需求。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值