For background on beam search, top-k, and top-p, see the earlier write-up on beam search, top-k sampling, nucleus sampling, temperature sampling, and combined sampling.
beam_search rough skeleton (not runnable, shown for illustration only)
In short: the (bs, cur_seqlen) input_ids are first repeat_interleave'd into shape (bs×beam_size, cur_seqlen). A forward pass gives next-token scores of shape (bs×beam_size, vocab_size), which are reshaped to (bs, beam_size×vocab_size); taking the top beam_size along that dimension yields (bs, beam_size) candidates. Floor-dividing the selected indices by vocab_size and, at the same time, taking them modulo vocab_size gives the beam_id and the token idx, each of shape (bs, beam_size); beam_id is the index of the input_ids row that the next round should continue from. Finally, the (bs×beam_size, cur_seqlen) input_ids, gathered by beam_id, are concatenated with the (bs×beam_size, 1) new tokens. For example, with bs=4 and beam_size=3, each batch element picks the best beam_size = 3 continuations out of beam_size×vocab_size = 3×vocab_size candidates per step.
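A tiny self-contained illustration of the reshape / topk / floor-div / modulo step described above (toy numbers only, no model involved):
import torch
bs, beam_size, vocab_size = 4, 3, 7
scores = torch.randn(bs * beam_size, vocab_size)      # next-token scores, one row per beam
scores = scores.view(bs, beam_size * vocab_size)      # (bs, beam_size*vocab_size)
top_scores, top_idx = scores.topk(beam_size, dim=1)   # both (bs, beam_size)
beam_id = top_idx // vocab_size                       # which beam each winner extends
token_id = top_idx % vocab_size                       # which vocab token it appends
print(beam_id.shape, token_id.shape)                  # torch.Size([4, 3]) torch.Size([4, 3])
With that in mind, here is the full (not-runnable) skeleton: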
import torch
import torch.nn as nn
from transformers import DynamicCache  # cache container used for past_key_values below
class MyGPTModel(nn.Module):
def __init__(self):
super(MyGPTModel, self).__init__()
self.model = model
self.tokenizer = tokenizer
self.full_str = '哈尔滨是一个美丽的城市'
...
def prefix_allowed_tokens_fn(self, input_ids_tensor):
# input_ids_tensor.shape = (cur_seq_length)
prefix_str = "".join(self.tokenizer.batch_decode(input_ids_tensor))
suffix_str = self.full_str.replace(prefix_str, "")
allow_token = []
if self.full_str in prefix_str:
return [self.tokenizer.eos_token_id]
for i in range(1, len(suffix_str)+1):
token_arr = self.tokenizer.encode(suffix_str[:i])
if len(token_arr) == 1:
                # if this Chinese span can be represented by a single token
allow_token.extend(token_arr)
return allow_token
def logits_processor(self, input_ids, next_token_scores):
# input_ids: (bs*beam_size, cur_seq_len)
# next_token_score: (bs*beam_size, vocab_size)
mask = torch.full_like(next_token_scores, float('-inf'))
for batch_beam_id, sent in enumerate(input_ids):
prefix_allowed_tokens = self.prefix_allowed_tokens_fn(sent)
mask[batch_beam_id, prefix_allowed_tokens] = 0
scores_processed = next_token_scores + mask
return scores_processed
def generate(self, prompt, beam_size):
        tokenizer_output = self.tokenizer(prompt, return_tensors="pt")
input_ids = tokenizer_output["input_ids"].repeat_interleave(beam_size, dim=0)
attention_mask = tokenizer_output["attention_mask"].repeat_interleave(beam_size, dim=0)
        bs, cur_seq_len = input_ids.shape # note: bs here is really bs * beam_size
vocab_size = self.tokenizer.vocab_size
bs = bs // beam_size
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros(bs, beam_size)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.reshape(bs * beam_size, 1)
past_key_values = DynamicCache()
        # the key cache has shape [bs, num_heads, cur_seq_len, head_dim] per layer, times num_layers; every one of the n transformer layers needs its own cache, which DynamicCache manages
while self._has_unfinished_sequences():
model_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
            next_token_logits = model_outputs['logits'][:, -1, :]  # keep only the last position
            past_key_values = model_outputs['past_key_values']
            # next_token_logits.shape: [bs * beam_size, vocab_size]
            next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1) # (bs * beam_size, vocab_size)
            next_token_scores_processed = self.logits_processor(input_ids, next_token_scores) # still (bs * beam_size, vocab_size)
next_token_scores = next_token_scores_processed + beam_scores
next_token_scores = next_token_scores.view(bs, beam_size * vocab_size)
next_token_scores, next_tokens_indices = next_token_scores.topk(beam_size, dim=1, largest=True, sorted=True)
# topk后next_token_scores和next_tokens_indices的shape变成[bs, beam_size]
            # next_tokens_indices were picked from beam_size*vocab_size candidates, so floor-divide and take the modulo to recover beam_id and the real token indices
            beam_ids = next_tokens_indices // vocab_size
            next_tokens_indices = next_tokens_indices % vocab_size
            # beam_ids index beams within one batch element; add the per-batch offset to index rows of input_ids
            # (a full implementation would also reorder past_key_values along the same beam order here)
            batch_beam_ids = (beam_ids + torch.arange(bs).unsqueeze(1) * beam_size).reshape(bs * beam_size)
            input_ids = torch.cat([input_ids[batch_beam_ids, :], next_tokens_indices.reshape(bs * beam_size, 1)], dim=-1)
            beam_scores = next_token_scores.reshape(bs * beam_size, 1) # don't forget this last line
return input_ids
beam_search simple skeleton (runnable)
The weakness of this LM_prob is that all seqlen steps of the LM's output are packed into one LM_prob tensor, with each position along seqlen standing for one step's result; note that vocab_size is not an embedding dimension, it indexes which word of the vocabulary is meant.
import torch
import torch.nn.functional as F
def beam_search(LM_prob,beam_size = 3):
batch,seqlen,vocab_size = LM_prob.shape
    # take the log of LM_prob, LM_prob: [batch, seqlen, vocab_size]
    # note: all seqlen steps are packed into this one tensor; each position along seqlen is one step's result, and vocab_size indexes the vocabulary (it is not an embedding dimension)
    log_LM_prob = LM_prob.log()
    # first pick the top beam_size tokens at position 0; log_beam_prob and indices both have shape [batch, beam_size]
    log_beam_prob, indices = log_LM_prob[:,0,:].topk(beam_size,sorted = True)
    # topk conveniently selects the top-k over the last dim for every batch element; log_beam_prob has shape [batch, beam_size]
    indices = indices.unsqueeze(-1) # indices now has shape [batch, beam_size, 1]
    # beam search over the remaining positions
for i in range(1,seqlen):
        # candidate probabilities for every beam; since these are log-probs they can simply be added
        # log_beam_prob.unsqueeze(-1) has shape [batch, beam_size, 1]
        # log_LM_prob[:,i,:].unsqueeze(1).repeat(1,beam_size,1) has shape [batch, beam_size, vocab_size]
        log_beam_prob = log_beam_prob.unsqueeze(-1) + log_LM_prob[:,i,:].unsqueeze(1).repeat(1,beam_size,1)
        # pick the highest-probability tokens for this step; log_beam_prob currently has shape [batch, beam_size, vocab_size]
        log_beam_prob, index = log_beam_prob.view(batch,-1).topk(beam_size,sorted = True)
        # beam_id says which previous beam each new beam came from; index is the real token id
        # beam_id and index both have shape [batch, beam_size]
        beam_id = index//vocab_size
        index = index%vocab_size
        mid = torch.tensor([], dtype=torch.long) # keep token ids integral
        # loop over the batch: gather the selected beams and append the newly generated token id
for j,bid,idx in zip(range(batch),beam_id,index):
            x = torch.cat([indices[j][bid],idx.unsqueeze(-1)], -1) # x has shape [beam_size, cur_seq_len]
            mid = torch.cat([mid,x.unsqueeze(0)], 0) # mid grows into [cur_batch, beam_size, cur_seq_len]
        indices = mid # indices now has shape [batch, beam_size, cur_seq_len]
print('indices.shape={}'.format(indices.shape))
return indices, log_beam_prob
if __name__=='__main__':
    # fake a language model output LM_prob of shape (batch, seqlen, vocab_size); purely a shortcut that saves running a real decode at every step
    # entry (batch_i, seqlen_i, vocab_size_i) is the probability that sample batch_i emits vocab token vocab_size_i at step seqlen_i
    LM_prob = F.softmax(torch.randn([32,20,1000]),dim = -1)
    # returns every candidate (shape (batch, beam_size, seqlen)) together with each candidate's log_prob
indices,log_prob = beam_search(LM_prob,beam_size = 3)
# print(indices)
A fun exercise: Chinese text segmentation (runnable)
Use Qwen2 to segment "成都是一个美丽的城市" ("Chengdu is a beautiful city"). The demo code above does essentially this job but cannot run; below is a Qwen2-based version that runs fine even on a laptop CPU. Note that the raw tokenizer segments the sentence as ['成', '都是', '一个', '美丽的', '城市'], and the LLM can correct this.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer.add_special_tokens({'bos_token': '<BOS>'}) # required: the prompt must contain a BOS and cannot be empty; if that turns out to be a problem, just strip the BOS at the end
# the tokenizer returned by from_pretrained is type(tokenizer) = <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
full_str = "<BOS>成都是一个美丽的城市"
# inputs = tokenizer("<BOS>成都是一个美丽的城市", return_tensors="pt")
# returns: {'input_ids': tensor([[151646, 12857, 100132, 46944, 105664, 994907]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
# i.e. the prompt is tokenized as ['<BOS>', '成', '都是', '一个', '美丽的', '城市']
# the tokenizer's __call__ returns a dict {"input_ids": input_ids_tensor, 'attention_mask': attention_mask_tensor}; return_tensors="pt" means PyTorch tensors (tf and jax are also possible)
def prefix_allowed_tokens_fn(batch_id, input_ids_tensor):
prefix_str = "".join(tokenizer.batch_decode(input_ids_tensor))
suffix_str = full_str.replace(prefix_str, "")
allow_token = []
if full_str in prefix_str:
return [tokenizer.eos_token_id]
# print('#####input_ids_tensor[-1]={} tokenizer.eos_token_id={}'.format(input_ids_tensor[-1], tokenizer.eos_token_id))
for i in range(1, len(suffix_str)+1):
token_arr = tokenizer.encode(suffix_str[:i])
if len(token_arr) == 1:
# print('{}={}'.format(token_arr[0], suffix_str[:i]))
            # if this Chinese span can be represented by a single token
# print('sent={} prefix_str={} suffix_str[:{}]={} token_arr={}'.format(input_ids_tensor, prefix_str, i, suffix_str[:i], token_arr))
allow_token.extend(token_arr)
return allow_token
inputs1 = tokenizer("<BOS>", return_tensors="pt")
outputs = model.generate(
**inputs1,
num_beams=3,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    pad_token_id=tokenizer.eos_token_id, # without this you get the warning: Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
    num_return_sequences = 3, # return the top num_return_sequences sequences out of num_beams
    eos_token_id = tokenizer.eos_token_id # stop generating once EOS is produced
)
bs, seq_len = outputs.shape
for batch_id, seq_tensor in enumerate(outputs):
res = []
for word_id, word in enumerate(seq_tensor):
res.append(tokenizer.decode(word))
print(res)
'''
Output:
['<BOS>', '成都', '是一个', '美丽的', '城市', '<|endoftext|>', '<|endoftext|>']
['<BOS>', '成都', '是一个', '美丽', '的城市', '<|endoftext|>', '<|endoftext|>']
['<BOS>', '成都', '是一个', '美丽的', '城', '市', '<|endoftext|>']
'''
Reading the source
The core beam search code lives under site-packages/transformers/generation/utils.py; search for the function self._beam_search.
def _beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
logits_processor and logits_warper both boil down to modifying the next-token prediction logits, but logits_warper exists specifically for top-k/top-p sampling and can be ignored for plain beam search (see the short warper sketch below). After that, two logits_processor examples are walked through.
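As a quick illustration of what a warper does, a minimal sketch (assuming TopKLogitsWarper is exported at the transformers top level, as in recent versions): it keeps only the top-k scores and pushes everything else to -inf.
import torch
from transformers import TopKLogitsWarper
warper = TopKLogitsWarper(top_k=2)      # keep only the 2 best-scoring tokens
dummy_input_ids = torch.tensor([[0]])   # unused by this warper, but required by the __call__ interface
scores = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
print(warper(dummy_input_ids, scores))
# tensor([[  -inf, 2.0000,   -inf, 0.5000]]) -- only the top-2 entries survive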
MinNewTokensLengthLogitsProcessor
When generate is called with min_new_tokens, it is MinNewTokensLengthLogitsProcessor doing the work; for example, the code below forces the newly generated length up to 2 tokens.
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["A number:"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_new_tokens=2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand
The code above works through MinNewTokensLengthLogitsProcessor. Here input_ids is the sequence of ids generated so far, shape [bs×beam_size, cur_seq_length], and scores holds the next-token scores, shape [bs×beam_size, vocab_size]. The processor below sets the EOS token's score to float('-inf'), so EOS can never be selected when ranking the top-k after the softmax.
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
def __init__(
self,
prompt_length_to_skip: int,
min_new_tokens: int,
eos_token_id: Union[int, List[int], torch.Tensor],
device: str = "cpu",
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
if not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id, device=device)
self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
scores_processed = scores.clone()
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
if new_tokens_length < self.min_new_tokens:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
return scores_processed
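To see the masking directly, the processor can also be applied by hand; a minimal sketch, assuming MinNewTokensLengthLogitsProcessor is importable from the transformers top level (true in recent versions):
import torch
from transformers import AutoTokenizer, MinNewTokensLengthLogitsProcessor
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
prompt_ids = tokenizer(["A number:"], return_tensors="pt")["input_ids"]   # (1, prompt_len)
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=prompt_ids.shape[-1],
    min_new_tokens=2,
    eos_token_id=tokenizer.eos_token_id,
)
scores = torch.zeros(1, len(tokenizer))             # dummy scores, shape (bs, vocab_size)
scores_processed = processor(prompt_ids, scores)    # no new tokens generated yet, so EOS is masked
print(scores_processed[0, tokenizer.eos_token_id])  # tensor(-inf)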
PrefixConstrainedLogitsProcessor
This processor takes the already generated input_ids of shape [bs×beam_size, cur_seq_length] and scores of shape [bs×beam_size, vocab_size], and turns scores into scores_processed: given the sequence so far, the probabilities of the next token are pre-filtered through the function _prefix_allowed_tokens_fn. That function looks at a single sent of shape [cur_seq_length] and returns the set of allowed next tokens.
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
# import pdb; pdb.set_trace();
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
for beam_id, sent in enumerate(beam_sent):
prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
print('beam_id={} sent={} prefix_allowed_tokens={}'.format(beam_id, sent, prefix_allowed_tokens))
if len(prefix_allowed_tokens) == 0:
raise ValueError(
f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
f"This means that the constraint is unsatisfiable. Please check your implementation"
f"of `prefix_allowed_tokens_fn` "
)
mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
scores_processed = scores + mask
return scores_processed
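The same hand-driven check works here; a sketch, assuming PrefixConstrainedLogitsProcessor is exported by transformers: a toy constraint that only ever allows one token id should leave exactly one finite column in the scores.
import torch
from transformers import PrefixConstrainedLogitsProcessor
vocab_size, num_beams = 10, 2
def only_token_7(batch_id, sent):
    # toy constraint: whatever the prefix is, only token id 7 may follow
    return [7]
processor = PrefixConstrainedLogitsProcessor(only_token_7, num_beams=num_beams)
input_ids = torch.zeros(num_beams, 3, dtype=torch.long)   # (bs*num_beams, cur_seq_len) with bs=1
scores = torch.zeros(num_beams, vocab_size)               # (bs*num_beams, vocab_size)
print(processor(input_ids, scores)[0])
# tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., -inf, -inf])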
stopping_criteria
The code below all lives in lib/python3.8/site-packages/transformers/generation/utils.py.
When every stopping_criteria is satisfied, this_peer_finished is flipped:
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
this_peer_finished = True
With multiple GPUs, this_peer_finished and synced_gpus together decide whether to stop:
def _has_unfinished_sequences(
self,
this_peer_finished: bool,
synced_gpus: bool,
device: torch.device,
cur_len: Optional[int] = None,
max_length: Optional[int] = None,
) -> bool:
"""
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
fed through `this_peer_finished`. ZeRO stage 3-friendly.
"""
# torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
# although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
# TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
if is_torchdynamo_compiling():
return cur_len < max_length
else:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
return False
return True
The built-in stopping_criteria are:
from .stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
)
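For example, an explicit criterion can be passed to generate; a minimal sketch, assuming the MaxLengthCriteria / StoppingCriteriaList top-level exports that recent transformers versions provide:
from transformers import AutoModelForCausalLM, AutoTokenizer, MaxLengthCriteria, StoppingCriteriaList
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
inputs = tokenizer("哈尔滨是", return_tensors="pt")
# stop as soon as the total length (prompt + generated tokens) reaches 12
criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=12)])
outputs = model.generate(**inputs, num_beams=3, stopping_criteria=criteria,
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.batch_decode(outputs))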
What BeamHypotheses is for: if one batch_beam_idx hits EOS but the other batch_beam_idx's have not finished yet, does the finished one still get forwarded through the model?
The answer is yes, it does; adding a few logs to the example above makes this visible. Which raises the next question: if forwarding continues after EOS, how does EOS get filled back in once the whole batch has finished? The answer lies in the code below from BeamSearchScorer.process. Every batch_group_idx (which is simply batch_idx when beams are not grouped) owns one entry of self._beam_hyps, an array whose elements are of type BeamHypotheses. A BeamHypotheses maintains an n-best list of hypotheses: whenever a beam hits EOS it is add'ed, and is_done finally decides the termination condition.
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (batch_beam_idx,)
else:
beam_index = None
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.group_size:
break
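To make the bookkeeping concrete, here is a stripped-down sketch of the same idea (not the transformers implementation, just the n-best list with a length penalty that BeamHypotheses maintains):
class SimpleBeamHypotheses:
    """Keep the num_beams best finished hypotheses, scored by length-penalised log-prob."""
    def __init__(self, num_beams, length_penalty=1.0):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []            # list of (score, token_ids)
        self.worst_score = 1e9
    def add(self, token_ids, sum_logprobs):
        score = sum_logprobs / (len(token_ids) ** self.length_penalty)
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, token_ids))
            if len(self.beams) > self.num_beams:
                # drop the currently worst finished hypothesis
                self.beams.sort(key=lambda x: x[0])
                self.beams.pop(0)
            self.worst_score = min(s for s, _ in self.beams)
    def is_done(self, best_sum_logprobs, cur_len):
        # finished once even the best still-running beam cannot beat the worst kept hypothesis
        if len(self.beams) < self.num_beams:
            return False
        return best_sum_logprobs / (cur_len ** self.length_penalty) <= self.worst_score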