HF beam search code walkthrough, Chinese sentence segmentation, PrefixConstrainedLogitsProcessor, BeamHypotheses

For background on beam search, top-k, and top-p, see the companion article on beam search, top-k sampling, nucleus sampling, temperature sampling, and combined sampling.

Rough beam_search skeleton (not runnable, for illustration only)

In short: at the start, the input_ids of shape (bs, cur_seqlen) are repeat_interleave-d into shape (bs×beam_size, cur_seqlen). The model then produces next-token predictions of shape (bs×beam_size, vocab_size). These are reshaped to (bs, beam_size×vocab_size) and the top beam_size entries are taken, giving scores and indices of shape (bs, beam_size). Integer division and the remainder by vocab_size then turn each of the (bs, beam_size) picks into a beam_id and a token index: the beam_id is the row of input_ids that gets carried into the next round, and finally the (bs×beam_size, cur_seqlen) input_ids are concatenated with the (bs×beam_size, 1) new tokens. For example, with bs=4 and beam_size=3, every step each batch element selects its best candidates out of 3×vocab_size possible next tokens (the top beam_size in this sketch; the HF implementation keeps 2×beam_size so that finished beams can be replaced).
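
As a concrete (made-up) sanity check of the bookkeeping above, the snippet below uses toy sizes bs=2, beam_size=3, vocab_size=5 and only exercises the reshape / topk / div-mod step that recovers beam_id and token index:

import torch

bs, beam_size, vocab_size = 2, 3, 5                 # toy sizes, for illustration only
scores = torch.randn(bs * beam_size, vocab_size)    # stand-in for next-token log-probs

# fold the beam dimension into the vocab dimension and keep the best beam_size candidates per batch element
flat = scores.view(bs, beam_size * vocab_size)
top_scores, top_idx = flat.topk(beam_size, dim=1)

beam_ids = top_idx // vocab_size    # which beam each winner came from, in [0, beam_size)
token_ids = top_idx % vocab_size    # which vocab entry it is, in [0, vocab_size)
print(beam_ids.shape, token_ids.shape)   # torch.Size([2, 3]) torch.Size([2, 3])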

import torch
import torch.nn as nn

class MyGPTModel(nn.Module):
    def __init__(self, model, tokenizer):
        super(MyGPTModel, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.full_str = '哈尔滨是一个美丽的城市'
        ...

    def prefix_allowed_tokens_fn(self, input_ids_tensor):
        # input_ids_tensor.shape = (cur_seq_length)
        prefix_str = "".join(self.tokenizer.batch_decode(input_ids_tensor))
        suffix_str = self.full_str.replace(prefix_str, "")
        allow_token = []
        if self.full_str in prefix_str:
            return [self.tokenizer.eos_token_id]
        for i in range(1, len(suffix_str)+1):
            token_arr = self.tokenizer.encode(suffix_str[:i])
            if len(token_arr) == 1: 
                # if this prefix of the remaining suffix can be represented by a single token
                allow_token.extend(token_arr)
        return allow_token
    
    def logits_processor(self, input_ids, next_token_scores):
        # input_ids: (bs*beam_size, cur_seq_len)
        # next_token_scores: (bs*beam_size, vocab_size)
        mask = torch.full_like(next_token_scores, float('-inf'))
        for batch_beam_id, sent in enumerate(input_ids):
            prefix_allowed_tokens = self.prefix_allowed_tokens_fn(sent)
            mask[batch_beam_id, prefix_allowed_tokens] = 0
        scores_processed = next_token_scores + mask
        return scores_processed
    
    def generate(self, prompt, beam_size):
        tokenizer_output = self.tokenizer(prompt, return_tensors="pt")
        input_ids = tokenizer_output["input_ids"].repeat_interleave(beam_size, dim=0)
        attention_mask = tokenizer_output["attention_mask"].repeat_interleave(beam_size, dim=0)

        bs, cur_seq_len = input_ids.shape # note: this bs is actually bs*beam_size
        vocab_size = self.tokenizer.vocab_size
        bs = bs // beam_size
        
        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        beam_scores = torch.zeros(bs, beam_size)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.reshape(bs * beam_size, 1)
        past_key_values = DynamicCache()  # from transformers import DynamicCache
        # each layer caches a key of shape [bs, num_heads, cur_seq_len, head_dim]; all n transformer layers need their own cache, which DynamicCache manages

        while self._has_unfinished_sequences():  # placeholder for HF's stopping logic (see _has_unfinished_sequences below)
            model_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
            next_token_logits = model_outputs['logits'][:, -1, :]  # logits at the last position
            past_key_values = model_outputs['past_key_values']
            # next_token_logits.shape: [bs * beam_size, vocab_size]
            next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (bs * beam_size, vocab_size)
            next_token_scores_processed = self.logits_processor(input_ids, next_token_scores) # next_token_scores_processed is still (bs * beam_size, vocab_size)
            next_token_scores = next_token_scores_processed + beam_scores
            
            next_token_scores = next_token_scores.view(bs, beam_size * vocab_size)
            next_token_scores, next_tokens_indices = next_token_scores.topk(beam_size, dim=1, largest=True, sorted=True)
            # after topk, next_token_scores and next_tokens_indices both have shape [bs, beam_size]
            # but next_tokens_indices were chosen out of beam_size*vocab_size candidates, so split them into a beam_id and the remainder modulo vocab_size
            beam_ids = next_tokens_indices // vocab_size
            next_tokens_indices = next_tokens_indices % vocab_size

            # beam_ids are per-batch indices in [0, beam_size); add the batch offset to index the bs*beam_size rows of input_ids
            batch_beam_ids = (beam_ids + torch.arange(bs).unsqueeze(-1) * beam_size).reshape(bs * beam_size)
            # (a full implementation would also reorder past_key_values by batch_beam_ids and extend attention_mask)
            input_ids = torch.cat([input_ids[batch_beam_ids, :], next_tokens_indices.reshape(bs*beam_size, 1)], dim=-1)
            beam_scores = next_token_scores.reshape(bs*beam_size, 1)  # don't forget this last line

        return input_ids

Simple beam_search skeleton (runnable)

The awkward part of this LM_prob is that the entire generated sequence length is packed into the single tensor LM_prob: each position along seqlen stands for one decoding step's output, and note that the vocab_size dimension is not an embedding, it indexes which word of the vocabulary.

 
import torch
import torch.nn.functional as F
 
def beam_search(LM_prob,beam_size = 3):
    batch,seqlen,vocab_size = LM_prob.shape
    # take the log of LM_prob; LM_prob: [batch,seqlen,vocab_size]
    # note: LM_prob packs every decoding step's distribution into one tensor; each position along seqlen is one step's output, and the vocab_size axis indexes vocabulary entries (it is not an embedding)
    log_LM_prob = LM_prob.log()
    # first pick the top beam_size tokens at position 0; log_beam_prob and indices have shape [batch,beam_size]
    log_beam_prob, indices = log_LM_prob[:,0,:].topk(beam_size,sorted = True)
    # topk here picks the top beam_size over the vocab dimension for each batch element; log_beam_prob has shape [batch,beam_size]
    indices = indices.unsqueeze(-1) # indices now has shape [batch,beam_size,1]
    # run beam search over the remaining positions
    for i in range(1,seqlen):
        # accumulate each beam's probability; since these are log probs we can simply add
        # log_beam_prob.unsqueeze(-1) has shape [batch,beam_size,1]
        # log_LM_prob[:,i,:].unsqueeze(1).repeat(1,beam_size,1) has shape [batch,beam_size,vocab_size]
        log_beam_prob = log_beam_prob.unsqueeze(-1) + log_LM_prob[:,i,:].unsqueeze(1).repeat(1,beam_size,1)
        # pick the highest-probability tokens at this step; log_beam_prob currently has shape [batch,beam_size,vocab_size]
        log_beam_prob, index = log_beam_prob.view(batch,-1).topk(beam_size,sorted = True)
        # below: beam_id says which previous beam each new beam comes from; index is the actual token id
        # beam_id and index have shape [batch,beam_size]
        beam_id = index//vocab_size
        index = index%vocab_size
        mid = torch.empty(0, dtype=torch.long)  # keep integer dtype so token ids are not promoted to float
        # loop over samples in the batch: gather the selected beams and append the newly generated token ids
        for j,bid,idx in zip(range(batch),beam_id,index):
            x = torch.cat([indices[j][bid],idx.unsqueeze(-1)], -1) # x.shape is [beam_size, cur_seq_len]
            mid = torch.cat([mid,x.unsqueeze(0)], 0) # mid is built up to [cur_batch, beam_size, cur_seq_len]
        indices = mid # indices.shape is now [batch, beam_size, cur_seq_len]
        print('indices.shape={}'.format(indices.shape))
    return indices, log_beam_prob
 
if __name__=='__main__':
    # build a fake language model LM_prob of shape (batch,seqlen,vocab_size); this is purely lazy and saves us from running a decode at every step
    # LM_prob[batch_i, seqlen_i, vocab_id] is the probability of vocab entry vocab_id at step seqlen_i of batch element batch_i
    LM_prob = F.softmax(torch.randn([32,20,1000]),dim = -1)
    # the final return is every candidate plus its log_prob, with shape (batch,beam_size,seqlen)
    indices,log_prob = beam_search(LM_prob,beam_size = 3)
    # print(indices)

An interesting exercise: Chinese sentence segmentation (runnable)

Use Qwen2 to segment "成都是一个美丽的城市" ("Chengdu is a beautiful city"). The demo code above is essentially doing this, but it cannot run; below is a version built on Qwen2 that runs fine even on a laptop CPU. The tokenizer alone segments the sentence as ['成', '都是', '一个', '美丽的', '城市'], splitting the city name 成都 apart, and the LLM can correct this.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer.add_special_tokens({'bos_token': '<BOS>'}) # required: the prompt must contain a BOS token and cannot be empty; the BOS can simply be stripped out at the end
# the tokenizer returned by from_pretrained is type(tokenizer) = <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>

full_str = "<BOS>成都是一个美丽的城市"

# inputs = tokenizer("<BOS>成都是一个美丽的城市", return_tensors="pt")
# returns: {'input_ids': tensor([[151646,  12857, 100132,  46944, 105664,  994907]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
# i.e. the string is tokenized as ['<BOS>', '成', '都是', '一个', '美丽的', '城市']
# the tokenizer's __call__ returns a dict {'input_ids': input_ids_tensor, 'attention_mask': attention_mask_tensor}; return_tensors='pt' means PyTorch tensors, and 'tf' / 'jax' are also possible
def prefix_allowed_tokens_fn(batch_id, input_ids_tensor):
    prefix_str = "".join(tokenizer.batch_decode(input_ids_tensor))
    suffix_str = full_str.replace(prefix_str, "")
    allow_token = []
    if full_str in prefix_str:
        return [tokenizer.eos_token_id]
    # print('#####input_ids_tensor[-1]={} tokenizer.eos_token_id={}'.format(input_ids_tensor[-1], tokenizer.eos_token_id))
    for i in range(1, len(suffix_str)+1):
        token_arr = tokenizer.encode(suffix_str[:i])
        if len(token_arr) == 1: 
            # print('{}={}'.format(token_arr[0], suffix_str[:i]))
            # if this prefix of the remaining suffix can be represented by a single token
            # print('sent={} prefix_str={} suffix_str[:{}]={} token_arr={}'.format(input_ids_tensor, prefix_str, i, suffix_str[:i], token_arr))
            allow_token.extend(token_arr)
    return allow_token

inputs1 = tokenizer("<BOS>", return_tensors="pt")
outputs = model.generate(
              **inputs1, 
              num_beams=3, 
              prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 
              pad_token_id=tokenizer.eos_token_id, # without this you get the warning: Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
              num_return_sequences = 3, # return the top num_return_sequences of the num_beams sequences
              eos_token_id = tokenizer.eos_token_id # stop generating once EOS is produced
          )
bs, seq_len = outputs.shape
for batch_id, seq_tensor in enumerate(outputs):
    res = []
    for word_id, word in enumerate(seq_tensor):
        res.append(tokenizer.decode(word))
    print(res)

'''
Output:
['<BOS>', '成都', '是一个', '美丽的', '城市', '<|endoftext|>', '<|endoftext|>']
['<BOS>', '成都', '是一个', '美丽', '的城市', '<|endoftext|>', '<|endoftext|>']
['<BOS>', '成都', '是一个', '美丽的', '城', '市', '<|endoftext|>']
'''
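
For comparison, the raw tokenizer segmentation (no LLM involved) can be checked with a one-liner like the following, reusing the tokenizer loaded above; it should print the uncorrected split in which the city name 成都 is cut apart:

ids = tokenizer("成都是一个美丽的城市")["input_ids"]
print([tokenizer.decode(i) for i in ids])
# expected: ['成', '都是', '一个', '美丽的', '城市']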

Code walkthrough

The main beam search code lives under site-packages/transformers/generation/utils.py; search for the self._beam_search function.

    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:

Both logits_processor and logits_warper essentially modify the next-token prediction logits, but logits_warper is dedicated to top-k/top-p sampling, so plain beam search can ignore it. Below, logits_processor is illustrated with two examples:
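
To make the interface concrete, here is a minimal, hypothetical sketch (not taken from the HF source) of how a LogitsProcessorList is assembled and applied; the processor arguments and tensors are dummy values, and the shapes are the same (bs×beam_size, ...) ones used throughout this post:

import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor, NoRepeatNGramLogitsProcessor

# a LogitsProcessorList behaves like a single processor: called with (input_ids, scores), it applies each processor in turn
processors = LogitsProcessorList([
    MinLengthLogitsProcessor(min_length=5, eos_token_id=2),  # eos_token_id=2 is a made-up id
    NoRepeatNGramLogitsProcessor(ngram_size=2),
])

input_ids = torch.tensor([[0, 11, 7]])   # (bs*beam_size, cur_seq_len), dummy token ids
scores = torch.randn(1, 100)             # (bs*beam_size, vocab_size), dummy logits
scores = processors(input_ids, scores)   # rewritten scores, same shape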

MinNewTokensLengthLogitsProcessor

The min_new_tokens argument of generate is implemented with MinNewTokensLengthLogitsProcessor; for example, the code below forces at least 2 newly generated tokens.

    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
    >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

    >>> inputs = tokenizer(["A number:"], return_tensors="pt")
    >>> gen_out = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one

    >>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
    >>> # necessarily incorrect
    >>> gen_out = model.generate(**inputs, min_new_tokens=2)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one thousand

The code above is implemented by the MinNewTokensLengthLogitsProcessor. input_ids is the sequence of ids generated so far, with shape [bs×beam_size, cur_seq_length]; scores are the next-token scores, with shape [bs×beam_size, vocab_size]. The processor below sets the EOS token's score to float('-inf'), so EOS cannot be selected when ranking the top-k after softmax.

class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        prompt_length_to_skip: int,
        min_new_tokens: int,
        eos_token_id: Union[int, List[int], torch.Tensor],
        device: str = "cpu",
    ):
        for arg_name, arg_value in [
            ("prompt_length_to_skip", prompt_length_to_skip),
            ("min_new_tokens", min_new_tokens),
        ]:
            if not isinstance(arg_value, int) or arg_value < 0:
                raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, device=device)

        self.prompt_length_to_skip = prompt_length_to_skip
        self.min_new_tokens = min_new_tokens
        self.eos_token_id = eos_token_id

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
        scores_processed = scores.clone()
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
        if new_tokens_length < self.min_new_tokens:
            scores_processed = torch.where(eos_token_mask, -math.inf, scores)

        return scores_processed
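
As a quick sanity check, the processor can also be called directly; the sketch below uses made-up token ids and a tiny dummy vocab just to show the EOS score being masked while fewer than min_new_tokens tokens have been generated:

import torch
from transformers import MinNewTokensLengthLogitsProcessor

eos_id = 2   # made-up EOS id, for illustration only
proc = MinNewTokensLengthLogitsProcessor(prompt_length_to_skip=3, min_new_tokens=2, eos_token_id=eos_id)

input_ids = torch.tensor([[5, 6, 7, 8]])  # a prompt of length 3 plus 1 generated token
scores = torch.zeros(1, 10)               # dummy scores over a vocab of size 10
print(proc(input_ids, scores)[0, eos_id]) # -inf: only 1 < 2 new tokens so far, so EOS is blocked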

PrefixConstrainedLogitsProcessor

This processor takes the input_ids generated so far (shape [bs×beam_size, cur_seq_length]) and the scores (shape [bs×beam_size, vocab_size]) and turns the scores into scores_processed: given the sequence so far, the next-token scores are preprocessed via the function _prefix_allowed_tokens_fn. That function only ever sees a single sent of shape [cur_seq_length] and returns the set of allowed next tokens.

class PrefixConstrainedLogitsProcessor(LogitsProcessor):

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        mask = torch.full_like(scores, -math.inf)
        # import pdb; pdb.set_trace();
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            for beam_id, sent in enumerate(beam_sent):
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
                print('beam_id={} sent={} prefix_allowed_tokens={}'.format(beam_id, sent, prefix_allowed_tokens))
                if len(prefix_allowed_tokens) == 0:
                    raise ValueError(
                        f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                        f"This means that the constraint is unsatisfiable. Please check your implementation"
                        f"of `prefix_allowed_tokens_fn` "
                    )
                mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0

        scores_processed = scores + mask
        return scores_processed
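
The processor can also be exercised in isolation. The following sketch uses a hard-coded allow-list (the token ids and vocab size are invented) just to show how the mask is built per batch_id * num_beams + beam_id row:

import torch
from transformers import PrefixConstrainedLogitsProcessor

def allow_fn(batch_id, sent):
    # toy constraint: whatever the prefix is, only token ids 3 and 4 are allowed (made-up ids)
    return [3, 4]

proc = PrefixConstrainedLogitsProcessor(allow_fn, num_beams=2)
input_ids = torch.tensor([[0, 1], [0, 2]])  # (bs*num_beams, cur_seq_len) with bs=1, num_beams=2
scores = torch.zeros(2, 10)                 # dummy scores over a vocab of size 10
out = proc(input_ids, scores)
print(out[0])                               # -inf everywhere except positions 3 and 4, which stay 0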

stopping_criteria

The following code all lives in lib/python3.8/site-packages/transformers/generation/utils.py.
When all stopping_criteria are satisfied, this_peer_finished is set:

            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                this_peer_finished = True

When running on multiple GPUs, this_peer_finished and synced_gpus together decide whether to stop:

    def _has_unfinished_sequences(
        self,
        this_peer_finished: bool,
        synced_gpus: bool,
        device: torch.device,
        cur_len: Optional[int] = None,
        max_length: Optional[int] = None,
    ) -> bool:
        """
        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
        fed through `this_peer_finished`. ZeRO stage 3-friendly.
        """
        # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
        # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
        # TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
        if is_torchdynamo_compiling():
            return cur_len < max_length
        else:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    return False
            elif this_peer_finished:
                return False
            return True

The available stopping_criteria are:

from .stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    StopStringCriteria,
)
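
A minimal sketch of passing stopping criteria explicitly rather than through keyword arguments such as max_length / max_time; MaxLengthCriteria and MaxTimeCriteria are two of the classes listed above, and the numbers are arbitrary:

from transformers import StoppingCriteriaList, MaxLengthCriteria, MaxTimeCriteria

stopping_criteria = StoppingCriteriaList([
    MaxLengthCriteria(max_length=20),  # stop once the full sequence reaches 20 tokens
    MaxTimeCriteria(max_time=5.0),     # or once generation has run for 5 seconds
])

# generate() accepts the list directly, e.g. reusing inputs1/model from the segmentation demo above:
# outputs = model.generate(**inputs1, num_beams=3, stopping_criteria=stopping_criteria)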

The role of BeamHypotheses: if one batch_beam_idx hits EOS while the other batch_beam_idx rows have not finished, does the finished batch_beam_idx keep being run through forward passes?

The answer is yes, it does; adding a few logs to the example above makes this visible. The next question, then, is: if forward passes continue after EOS, how does the EOS get filled back in once the whole batch has finished? The answer is the code below from BeamSearchScorer.process. Each batch_group_idx (which is simply batch_idx when no beam grouping is used) gets its own entry in self._beam_hyps, an array whose elements are of type BeamHypotheses. A BeamHypotheses maintains an n-best list of hypotheses: whenever EOS is hit the hypothesis is add-ed, and is_done decides the termination condition (a simplified sketch of this bookkeeping follows the excerpt below).

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    self._beam_hyps[batch_group_idx].add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                        generated_len=cur_len - decoder_prompt_len,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break
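
To make the role of BeamHypotheses concrete, here is a heavily simplified, hypothetical re-implementation of its core idea (a length-penalised n-best list with add and is_done); the real class in transformers has extra parameters (early_stopping modes, max_length handling) that are omitted here:

class TinyBeamHypotheses:
    # toy n-best container mirroring the add/is_done idea of HF's BeamHypotheses
    def __init__(self, num_beams, length_penalty=1.0):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []          # list of (length-normalised score, token id list)
        self.worst_score = 1e9

    def add(self, hyp, sum_logprobs):
        # called when a beam emits EOS: store the finished hypothesis if it is good enough
        score = sum_logprobs / (len(hyp) ** self.length_penalty)
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self.beams) > self.num_beams:
                self.beams.sort(key=lambda x: x[0])
                self.beams.pop(0)            # drop the worst finished hypothesis
            self.worst_score = min(s for s, _ in self.beams)

    def is_done(self, best_running_sum_logprobs, cur_len):
        # done when even the best still-running beam can no longer beat the worst finished one
        if len(self.beams) < self.num_beams:
            return False
        return best_running_sum_logprobs / (cur_len ** self.length_penalty) <= self.worst_score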