Overview
The fairseq version used here is 0.6.2.
fairseq's beam search logic lives in fairseq.sequence_generator.SequenceGenerator.generate.
SequenceGenerator drives the whole search process. The rough flow is:
- At each time step, call self.search.step to get candidate tokens, move hypotheses that have reached EOS into the finalized results, update the buffers, and continue with the next step.
- self.search.step implements the concrete token-selection strategy (plain beam search, beam search with penalties, sampling, etc.); these strategies live in fairseq.search. A simplified sketch of a beam-search step follows below.
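To make the interface concrete, here is a minimal, simplified sketch loosely modeled on what search.BeamSearch.step does (an illustration, not the actual fairseq implementation): add the current step's log-probabilities to the cumulative hypothesis scores, flatten the beam and vocabulary dimensions, and keep the top 2 * beam_size candidates per sentence.

import torch

def beam_search_step(step, lprobs, scores, beam_size):
    # lprobs: (bsz, beam_size, vocab_size) log-probabilities for the current step
    # scores: (bsz, beam_size, num_steps) cumulative log-probs of each hypothesis
    # assumes vocab_size >= 2 * beam_size
    bsz, _, vocab_size = lprobs.size()
    if step == 0:
        # every beam starts from the same prefix, so look at the first beam only
        # to avoid producing beam_size copies of the same candidates
        lprobs = lprobs[:, :1, :].contiguous()
    else:
        # turn per-step log-probs into cumulative hypothesis scores
        lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
    # flatten (beam, vocab) and keep the top 2 * beam_size candidates per sentence;
    # the extra candidates leave room for hypotheses that end in EOS
    top_scores, top_indices = torch.topk(lprobs.view(bsz, -1), k=2 * beam_size)
    beams = top_indices // vocab_size   # which beam each candidate extends
    tokens = top_indices % vocab_size   # which token each candidate appends
    return top_scores, tokens, beams

In SequenceGenerator.generate the extra candidates are then filtered: those ending in EOS are finalized, and the rest refill the beam for the next step.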
Detailed code annotations
While reading, the existing comments were not enough, so the code's logic is written directly into the comments below to make it easier to follow.
class SequenceGenerator(object):
    def __init__(
        self,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        min_len=1,
        stop_early=True,
        normalize_scores=True,
        len_penalty=1.,
        unk_penalty=0.,
        retain_dropout=False,
        sampling=False,
        sampling_topk=-1,
        sampling_temperature=1.,
        diverse_beam_groups=-1,
        diverse_beam_strength=0.5,
        match_source_len=False,
        no_repeat_ngram_size=0,
    ):
        """Generates translations of a given source sentence.

        Args:
            tgt_dict (~fairseq.data.Dictionary): target dictionary
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            stop_early (bool, optional): stop generation immediately after we
                finalize beam_size hypotheses, even though longer hypotheses
                might have better normalized scores (default: True)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            sampling (bool, optional): sample outputs instead of beam search
                (default: False)
            sampling_topk (int, optional): only sample among the top-k choices
                at each step (default: -1)
            sampling_temperature (float, optional): temperature for sampling,
                where values >1.0 produces more uniform sampling and values
                <1.0 produces sharper sampling (default: 1.0)
            diverse_beam_groups/strength (float, optional): parameters for
                Diverse Beam Search sampling
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.match_source_len = match_source_len
        self.no_repeat_ngram_size = no_repeat_ngram_size
        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
        if sampling:
            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
        elif diverse_beam_groups > 0:
            self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
        elif match_source_len:
            self.search = search.LengthConstrainedBeamSearch(
                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
            )
        else:
            self.search = search.BeamSearch(tgt_dict)
    @torch.no_grad()
    def generate(
        self,
        models,
        sample,
        prefix_tokens=None,
        bos_token=None,
        **kwargs
    ):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
        """
        model = EnsembleModel(models)
        if not self.retain_dropout:
            model.eval()
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }
        src_tokens = encoder_input['src_tokens']
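        # count the non-pad, non-EOS tokens in each source sentence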
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )
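        # e.g. with the defaults max_len_a=0 and max_len_b=200, max_len is 200
        # (or less, if the decoder supports fewer positions)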
        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()  # (bsz*beam,), each entry is the index of the corresponding batch sample
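        # e.g. bsz=2, beam_size=3 -> new_order = [0, 0, 0, 1, 1, 1]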
        # reorder_encoder_out maps each position of encoder_outs to the batch
        # sentence given by new_order, so encoder_outs now holds bsz*beam entries
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
        # initialize buffers
        # scores records the accumulated log prob, bsz*beam x max_len+1
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
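        # (scores_buf appears to be a scratch buffer: later steps write the
        # reordered scores into it and then swap it with scores)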
        # max_len does not include EOS; there is one EOS at the start and one at
        # the end, hence +2
        # bsz*beam x max_len+2
        tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)