#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from six.moves import xrange


def mask(batch_tokens,
         seg_labels,
         mask_word_tags,
         total_token_num,
         vocab_size,
         CLS=1,
         SEP=2,
         MASK=3):
    """
    Add mask for batch_tokens; return out, mask_label, mask_pos.
    Note: mask_pos refers to positions in batch_tokens after padding,
    i.e. sent_index * max_len + token_index.
"""
    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    prob_mask = np.random.rand(total_token_num)
    # Note: the first token is [CLS], so [low=1]
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    pre_sent_len = 0
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):
        mask_flag = False
        mask_word = mask_word_tags[sent_index]
        prob_index += pre_sent_len
        if mask_word:
            # Whole-word masking branch. seg_labels encodes word boundaries:
            # 0 = first sub-token of a word, 1 = continuation sub-token,
            # -1 = special token such as [CLS]/[SEP].
            beg = 0
            for token_index, token in enumerate(sent):
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:
                    continue
                if beg == 0:
                    if seg_label != -1:
                        beg = token_index
                    continue

                prob = prob_mask[prob_index + beg]
                if prob > 0.15:
                    pass
                else:
                    for index in xrange(beg, token_index):
                        prob = prob_mask[prob_index + index]
                        base_prob = 1.0
                        if index == beg:
                            base_prob = 0.15
                        if base_prob * 0.2 < prob <= base_prob:
                            # mask
                            mask_label.append(sent[index])
                            sent[index] = MASK
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            # random replace
                            mask_label.append(sent[index])
                            sent[index] = replace_ids[prob_index + index]
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        else:
                            # keep the original token
                            mask_label.append(sent[index])
                            mask_pos.append(sent_index * max_len + index)

                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index +
                                                        token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
        pre_sent_len = len(sent)

    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos


def prepare_batch_data(insts,
                       total_token_num,
                       voc_size=0,
                       pad_id=None,
                       cls_id=None,
                       sep_id=None,
                       mask_id=None,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False):
    """
    Convert the id lists of a batch into numpy arrays: mask the source ids
    first, then pad every field to the max sequence length in the batch.
"""
    batch_src_ids = [inst[0] for inst in insts]
    batch_sent_ids = [inst[1] for inst in insts]
    batch_pos_ids = [inst[2] for inst in insts]
    labels = [inst[3] for inst in insts]
    labels = np.array(labels).astype("int64").reshape([-1, 1])
    seg_labels = [inst[4] for inst in insts]
    mask_word_tags = [inst[5] for inst in insts]

    # First step: do mask without padding
    assert mask_id >= 0, "[FATAL] mask_id must >= 0"
    out, mask_label, mask_pos = mask(
        batch_src_ids,
        seg_labels,
        mask_word_tags,
        total_token_num,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id)

    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out, pad_idx=pad_id, return_input_mask=True)
    pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
    sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
    padded_task_ids = np.ones_like(src_id, dtype="int64") * task_id

    return_list = [
        src_id, sent_id, pos_id, padded_task_ids, self_input_mask, mask_label,
        mask_pos
    ]
    return return_list


def pad_batch_data(insts,
                   pad_idx=0,
                   return_pos=False,
                   return_input_mask=False,
                   return_max_len=False,
                   return_num_token=False,
                   return_seq_lens=False):
    """
    Pad the instances to the max sequence length in the batch, and generate
    the corresponding position data and input mask.
"""
    return_list = []
    max_len = max(len(inst) for inst in insts)

    # Any token included in the dict can be used to pad, since the paddings'
    # loss will be masked out by weights and have no effect on parameter
    # gradients.
    inst_data = np.array(
        [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]

    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]

    if return_input_mask:
        # This is used to avoid attention on paddings.
        input_mask_data = np.array(
            [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts])
        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
        return_list += [input_mask_data.astype("float32")]

    if return_max_len:
        return_list += [max_len]

    if return_num_token:
        num_token = 0
        for inst in insts:
            num_token += len(inst)
        return_list += [num_token]

    if return_seq_lens:
        seq_lens = np.array([len(inst) for inst in insts])
        return_list += [seq_lens.astype("int64").reshape([-1, 1])]

    return return_list if len(return_list) > 1 else return_list[0]


if __name__ == "__main__":
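    # A minimal smoke test of pad_batch_data (illustrative only; the token
    # ids below are toy values, not real vocabulary or special-token ids).
    toy_batch = [[1, 11, 12, 2], [1, 21, 2]]
    padded, input_mask = pad_batch_data(
        toy_batch, pad_idx=0, return_input_mask=True)
    print(padded.shape)      # (2, 4, 1)
    print(input_mask.shape)  # (2, 4, 1)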