1. Load and preprocess samples from JSON-format data for NER training and inference
import copy
import json

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class JsonNerDataset(Dataset):
    """
    Dataset that loads raw named-entity-recognition (NER) data in JSON format.
    One sample per line (a JSON string) with the fields: originalText, entities.
    """
    def __init__(self,
                 data_path, tokenizer, categories, target_padding_idx=-100,
                 add_special_token=False, first_special_token='[CLS]', last_special_token='[SEP]'
                 ):
        super(JsonNerDataset, self).__init__()
        self.tokenizer: BertTokenizer = tokenizer
        # self.sentence_max_len = self.tokenizer.max_len_single_sentence  # 510
        self.sentence_max_len = 126  # cap each split sentence at 126 characters (128 with [CLS]/[SEP])
        self.categories = categories  # mapping from BMES label string to integer id
        self.token_padding_idx = self.tokenizer.convert_tokens_to_ids('[PAD]')
        self.target_padding_idx = target_padding_idx
        self.add_special_token = add_special_token
        self.first_token_id = self.tokenizer.convert_tokens_to_ids(first_special_token)
        self.last_token_id = self.tokenizer.convert_tokens_to_ids(last_special_token)
        self.records = self.load_records(data_path)

    def load_records(self, data_path) -> list:
        records = []
        with open(data_path, "r", encoding="utf-8") as reader:
            for line in reader:
                # 1. Parse the raw JSON record.
                record = json.loads(line.strip())
                # 2. Convert the raw text into a character sequence.
                entities = record['entities']
                # special_token_processor is a project helper assumed to be defined elsewhere;
                # it normalizes characters that would clash with BERT special tokens.
                text = special_token_processor(record['originalText'])
                chars = list(text)  # character-level split: one character per token
                # 3. Convert the entity spans into BMES labels.
                labels = ['O'] * len(chars)
                for entity in entities:
                    label_type = entity['label_type']
                    start_pos = entity['start_pos']  # inclusive start
                    end_pos = entity['end_pos']      # exclusive end
                    if end_pos - start_pos == 1:
                        # single-character entity
                        labels[start_pos] = f'S-{label_type}'
                    elif end_pos - start_pos > 1:
                        # multi-character entity
                        labels[start_pos] = f'B-{label_type}'
                        labels[end_pos - 1] = f'E-{label_type}'
                        for i in range(start_pos + 1, end_pos - 1):
                            labels[i] = f'M-{label_type}'
                    else:
                        raise ValueError(f"Invalid entity span in record: {record}")
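                # Illustrative example (values made up): for "患者头痛三天" with a "symptom"
                # entity at start_pos=2, end_pos=4, the labels become:
                #   ['O', 'O', 'B-symptom', 'E-symptom', 'O', 'O']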
                if self.add_special_token:
                    # chars/labels must be split into chunks so a single sample never exceeds the
                    # 510-token limit (capped here at self.sentence_max_len); split_chars is a
                    # project helper assumed to be defined elsewhere.
                    for sub_chars, sub_labels in split_chars(chars, labels, self.sentence_max_len):
                        x = self.tokenizer.convert_tokens_to_ids(sub_chars)  # map each character to its token id
                        assert len(sub_chars) == len(x), "BERT token-id conversion changed the sequence length."
                        y = [self.categories[c] for c in sub_labels]
                        x.insert(0, self.first_token_id)
                        x.append(self.last_token_id)
                        y.insert(0, self.categories['O'])
                        y.append(self.categories['O'])
                        assert len(x) == len(y), f"Token and label sequences must have the same length, got: {len(x)} - {len(y)} - {record}"
                        records.append((x, y, len(x)))
                else:
                    x = self.tokenizer.convert_tokens_to_ids(chars)  # map each character to its token id
                    assert len(chars) == len(x), "BERT token-id conversion changed the sequence length."
                    y = [self.categories[c] for c in labels]
                    assert len(x) == len(y), f"Token and label sequences must have the same length, got: {len(x)} - {len(y)} - {record}"
                    records.append((x, y, len(x)))
        return records

    def __getitem__(self, index):
        """
        Return the sample stored at the given index.
        :param index: sample index
        :return: (x, y, num_tokens) - token ids, label ids, and the sequence length
        """
        x, y, num_tokens = self.records[index]
        return copy.deepcopy(x), copy.deepcopy(y), num_tokens

    def __len__(self):
        return len(self.records)

    def collate(self, batch):
        # Pad every sample to the longest sequence in the batch and build the attention mask.
        max_length = max(t[2] for t in batch)
        x, y, mask = [], [], []
        for bx, by, num_tokens in batch:
            pad = max_length - num_tokens
            x.append(bx + [self.token_padding_idx] * pad)
            y.append(by + [self.target_padding_idx] * pad)
            mask.append([1] * num_tokens + [0] * pad)
        return torch.tensor(x), torch.tensor(y), torch.tensor(mask)
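
Below is a minimal usage sketch, not the project's actual training script. The vocabulary name, data path, and entity types are placeholder assumptions, and it presumes that the helper functions special_token_processor and split_chars referenced above are defined elsewhere in the project; the DataLoader simply reuses the dataset's own collate method as collate_fn.

from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # assumed pretrained Chinese BERT vocab
entity_types = ['disease', 'symptom']                           # hypothetical entity types
categories = {'O': 0}
for t in entity_types:
    for prefix in ('B', 'M', 'E', 'S'):
        categories[f'{prefix}-{t}'] = len(categories)           # map every BMES label to an integer id

dataset = JsonNerDataset(
    data_path='train.json',   # placeholder path: one JSON sample per line
    tokenizer=tokenizer,
    categories=categories,
    add_special_token=True,   # wrap every chunk with [CLS]/[SEP]
)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=dataset.collate)

for token_ids, label_ids, mask in loader:
    # each is a [batch_size, max_seq_len_in_batch] LongTensor
    print(token_ids.shape, label_ids.shape, mask.shape)
    break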