I recently needed to train a BERT-CRF model for entity extraction on English data. Most BERT-CRF repositories on GitHub do NER on Chinese data, so this post records the problems I hit and how I solved them. Without further ado, here is the code. The template is the BERT-CRF model code from the CLUENER2020 project; the main modification is in collate_fn, i.e. the padding and aligning of batch data.
First, the main problem I ran into:
The padding and aligning preprocessing in the original model was written for Chinese data. When I debugged training on English data, the padding/aligning step kept throwing index-out-of-range errors. After some digging, the cause turned out to be the difference between Chinese and English tokenization: the Chinese pipeline simply splits a sentence character by character (or into word chunks), while the English BERT tokenizer splits each word into sub-tokens based on roots and stems (WordPiece), so the tokenized sequence is longer than the original word count. The original padding and aligning logic therefore no longer fits English data and has to be adapted to English subword tokenization.
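To make the mismatch concrete, here is a minimal sketch (the bert-base-uncased checkpoint and the sample words are only illustrative, not part of the original project) showing how one English word can be split into several WordPiece sub-tokens, so the sub-token sequence ends up longer than the word-level label sequence:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)

words = ["BERT", "handles", "tokenization"]        # 3 words -> 3 word-level NER labels
subwords = [tokenizer.tokenize(w) for w in words]
print(subwords)
# e.g. [['bert'], ['handles'], ['token', '##ization']] -> 4 sub-tokens for only 3 labels,
# so word-level label positions no longer line up with the tokenized sequence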
import os
import json
import torch
import numpy as np
from transformers import BertTokenizer, BertTokenizerFast
from transformers import RobertaTokenizer, RobertaModel
from torch.utils.data import Dataset
class NERDataset(Dataset):
    def __init__(self, words, labels, config, word_pad_idx=0, label_pad_idx=-1):
        self.tokenizer = BertTokenizerFast.from_pretrained(config.bert_model, do_lower_case=True, add_special_tokens=True)
        # self.tokenizer = BertTokenizerFast.from_pretrained(config.bert_model, do_lower_case=False, add_special_tokens=True)
        self.label2id = config.label2id
        self.id2label = {_id: _label for _label, _id in list(config.label2id.items())}
        self.dataset = self.preprocess(words, labels)
        self.word_pad_idx = word_pad_idx
        self.label_pad_idx = label_pad_idx
        self.device = config.device
    def preprocess(self, origin_sentences, origin_labels):
        """
        Maps tokens and tags to their indices and stores them in the dict data.
        examples:
            word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
            sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956], array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
            label:[3, 13, 13, 13, 13, 0, 0, 0, 0, 0]
        """
        data = []
        sentences = []
        labels = []
        for line in origin_sentences:
            # replace each token by its index
            # we can not use encode_plus because our sentences are aligned to labels in list type
            words = []
            word_lens = []
            for token in line:
                subword = self.tokenizer.tokenize(token)
                words.append(subword)
                # record the number of sub-tokens (not characters), so the cumulative
                # sum below indexes correctly into the sub-token sequence
                word_lens.append(len(subword))
            # flatten the sub-token lists and add [CLS] / [SEP]
            words = ['[CLS]'] + [item for token in words for item in token] + ['[SEP]']
            # start position of each original word in the sub-token sequence (offset by 1 for [CLS])
            token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
            sentences.append((self.tokenizer.convert_tokens_to_ids(words), token_start_idxs))
        for tag in origin_labels:
            label_id = [self.label2id.get(t) for t in tag]
            labels.append(label_id)
        for sentence, label in zip(sentences, labels):
            data.append((sentence, label))
        return data
    def __getitem__(self, idx):
        """sample data to get batch"""
        word = self.dataset[idx][0]
        label = self.dataset[idx][1]
        return [word, label]

    def __len__(self):
        """get dataset size"""
        return len(self.dataset)
    def collate_fn(self, batch):
        """Pad sub-token ids, build the word-start mask and align word-level labels for a batch."""
        sentences = [x[0] for x in batch]
        labels = [x[1] for x in batch]
        # batch length
        batch_len = len(sentences)  # batch size
        batch_max_subwords_len = max([len(s[0]) for s in sentences])
        # cap at BERT's maximum input length of 512 sub-tokens
        max_subword_len = min(batch_max_subwords_len, 512)
        max_token_len = 0
        # initialize padded data
        batch_data = self.word_pad_idx * np.ones((batch_len, max_subword_len))  # sub-token ids, padded with word_pad_idx (0)
        batch_token_starts = []
        # padding and aligning
        for j in range(batch_len):
            cur_subwords_len = len(sentences[j][0])  # length of the sub-token id list
            if cur_subwords_len <= max_subword_len:
                batch_data[j][:cur_subwords_len] = sentences[j][0]
            else:
                batch_data[j] = sentences[j][0][:max_subword_len]
            token_start_ids = sentences[j][-1]
            token_starts = np.zeros(max_subword_len)
            # only keep word-start positions that survive the truncation, so indexing never overflows
            token_starts[[idx for idx in token_start_ids if idx < max_subword_len]] = 1
            batch_token_starts.append(token_starts)
            max_token_len = max(int(sum(token_starts)), max_token_len)
        # word-level labels are padded/truncated to the number of surviving word starts
        batch_labels = self.label_pad_idx * np.ones((batch_len, max_token_len))
        for j in range(batch_len):
            cur_labels_len = len(labels[j])
            if cur_labels_len <= max_token_len:
                batch_labels[j][:cur_labels_len] = labels[j]
            else:
                batch_labels[j] = labels[j][:max_token_len]
        # convert data to torch LongTensors
        batch_data = torch.tensor(batch_data, dtype=torch.long)
        batch_token_starts = torch.tensor(batch_token_starts, dtype=torch.long)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        # print(batch_data.size())
        # print(batch_token_starts.size())
        # print(batch_labels.size())
        # shift tensors to GPU if available
        batch_data, batch_label_starts = batch_data.to(self.device), batch_token_starts.to(self.device)
        batch_labels = batch_labels.to(self.device)
        return [batch_data, batch_label_starts, batch_labels]
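
For reference, here is a minimal usage sketch of the dataset and collate_fn above. The Config class, checkpoint name and toy sentences are made up purely to show the resulting tensor shapes; in the real project these would come from the CLUENER2020-style config and data files.

from torch.utils.data import DataLoader

class Config:
    bert_model = "bert-base-uncased"   # any English BERT checkpoint (example value)
    label2id = {"O": 0, "B-PER": 1, "I-PER": 2}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
words = [["John", "lives", "in", "Liverpool"], ["Anna", "hired", "Mary"]]
labels = [["B-PER", "O", "O", "O"], ["B-PER", "O", "B-PER"]]

dataset = NERDataset(words, labels, config)
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

batch_data, batch_token_starts, batch_labels = next(iter(loader))
print(batch_data.shape)          # (batch_size, max_subword_len)  padded sub-token ids
print(batch_token_starts.shape)  # (batch_size, max_subword_len)  1 marks each word's first sub-token
print(batch_labels.shape)        # (batch_size, max_token_len)    one label id per original word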

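Downstream, the word-start mask returned by collate_fn is typically used to pick out, for every original word, the hidden state of its first sub-token before the CRF layer (this is the usual pattern in CLUENER-style BERT-NER code; the tensors below are random and only illustrate the indexing):

batch_size, max_subword_len, hidden_size = 2, 8, 768
sequence_output = torch.randn(batch_size, max_subword_len, hidden_size)  # stands in for the BERT encoder output
batch_token_starts = torch.zeros(batch_size, max_subword_len, dtype=torch.long)
batch_token_starts[0, [1, 2, 3, 5]] = 1   # sentence 1: 4 words
batch_token_starts[1, [1, 3, 4]] = 1      # sentence 2: 3 words

# keep only the hidden state of each word's first sub-token
word_level_output = [hidden[starts == 1] for hidden, starts in zip(sequence_output, batch_token_starts)]
padded_output = torch.nn.utils.rnn.pad_sequence(word_level_output, batch_first=True)
print(padded_output.shape)   # (2, 4, 768): one vector per original word, aligned with batch_labels for the CRF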
To sum up, this post walked through entity extraction on English data with a BERT-CRF model, adapting the data padding and alignment to the differences between English subword tokenization and Chinese tokenization.