Data-reading function

def read_data(file):
    with open(file, "r", encoding="utf-8") as f:
        all_data = f.read().split("\n")

    all_text = []
    all_label = []
    text = []
    label = []
    for data in all_data:
        if data == "":
            # A blank line marks the end of a sentence.
            if text:
                all_text.append(text)
                all_label.append(label)
            text = []
            label = []
        else:
            t, l = data.split(" ")
            text.append(t)
            label.append(l)
    # Keep the last sentence if the file does not end with a blank line.
    if text:
        all_text.append(text)
        all_label.append(label)
    return all_text, all_label
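The parsing above implies a CoNLL-style file: one character and one tag per line, separated by a single space, with a blank line between sentences. An illustrative train.txt (the tags here are made up):

    中 B-ORG
    国 I-ORG
    很 O
    大 O

    北 B-LOC
    京 I-LOC

    all_text, all_label = read_data("data/train.txt")
    print(all_text[0])    # ['中', '国', '很', '大']
    print(all_label[0])   # ['B-ORG', 'I-ORG', 'O', 'O']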
Label-building function

def build_label(train_label):
    # Index 0 is reserved for PAD, index 1 for unseen (UNK) labels.
    label_2_index = {"PAD": 0, "UNK": 1}
    for label in train_label:
        for l in label:
            if l not in label_2_index:
                label_2_index[l] = len(label_2_index)
    return label_2_index, list(label_2_index)
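On the toy labels above this produces a first-seen-order mapping:

    label_2_index, index_2_label = build_label([["B-ORG", "I-ORG", "O", "O"]])
    # label_2_index: {'PAD': 0, 'UNK': 1, 'B-ORG': 2, 'I-ORG': 3, 'O': 4}
    # index_2_label: ['PAD', 'UNK', 'B-ORG', 'I-ORG', 'O']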
Main script

import os
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel

train_text, train_label = read_data(os.path.join("data", "train.txt"))
label_2_index, index_2_label = build_label(train_label)
tokenizer = BertTokenizer.from_pretrained(os.path.join("..", "bert_base_chinese"))

batch_size = 10
epoch = 100
max_len = 30
lr = 0.0001

train_dataset = BertDataset(train_text, train_label, label_2_index, max_len, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

model = BertNerModel(len(label_2_index))
opt = AdamW(model.parameters(), lr)

for e in range(epoch):
    for batch_text_index, batch_label_index in train_dataloader:
        loss = model(batch_text_index, batch_label_index)
        loss.backward()
        opt.step()
        opt.zero_grad()
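After training, predicted ids can be decoded back to tag strings with index_2_label. A minimal inference sketch (the sentence is illustrative, and it relies on the forward pass below returning argmax label ids when no labels are given):

    model.eval()
    with torch.no_grad():
        text = "小明在北京上学"                        # illustrative sentence
        text_index = tokenizer.encode(text, add_special_tokens=True,
                                      max_length=max_len, padding="max_length",
                                      truncation=True, return_tensors="pt")
        pre = model(text_index)[0].tolist()            # predicted label ids, length max_len
        pre = pre[1:len(text) + 1]                     # drop [CLS], [SEP] and padding
        print([index_2_label[i] for i in pre])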
Dataset class

class BertDataset(Dataset):
    def __init__(self, all_text, all_label, label_2_index, max_len, tokenizer):
        self.all_text = all_text
        self.all_label = all_label
        self.label_2_index = label_2_index
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        text = self.all_text[index]
        # Leave room for [CLS] and [SEP] inside max_len.
        label = self.all_label[index][:self.max_len - 2]

        text_index = self.tokenizer.encode(text, add_special_tokens=True,
                                           max_length=self.max_len,
                                           padding="max_length", truncation=True,
                                           return_tensors="pt")
        # Align labels with the encoded sequence: 0 (PAD) for [CLS], then the
        # real labels (1 = UNK fallback), then 0 for [SEP] and padding.
        label_index = ([0] + [self.label_2_index.get(l, 1) for l in label]
                       + [0] * (self.max_len - len(label) - 1))
        label_index = torch.tensor(label_index)

        return text_index.reshape(-1), label_index

    def __len__(self):
        return len(self.all_text)
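A quick shape check of one sample (assuming max_len = 30 as in the main script):

    ds = BertDataset(train_text, train_label, label_2_index, 30, tokenizer)
    text_index, label_index = ds[0]
    print(text_index.shape, label_index.shape)   # torch.Size([30]) torch.Size([30])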
NER model

class BertNerModel(nn.Module):
    def __init__(self, class_num):
        super().__init__()
        self.bert = BertModel.from_pretrained(os.path.join("..", "bert_base_chinese"))
        self.classifier = nn.Linear(768, class_num)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self, batch_index, batch_label=None):
        bert_out = self.bert(batch_index)
        # bert_out0: token-level output [batch, seq_len, 768];
        # bert_out1: pooled sentence-level output [batch, 768].
        bert_out0, bert_out1 = bert_out[0], bert_out[1]
        pre = self.classifier(bert_out0)
        if batch_label is not None:
            loss = self.loss_fun(pre.reshape(-1, pre.shape[-1]),
                                 batch_label.reshape(-1))
            return loss
        # With no labels, return the predicted label ids.
        return torch.argmax(pre, dim=-1)
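Note that the loss also scores [CLS]/[SEP]/padding positions, which all carry label 0; nn.CrossEntropyLoss(ignore_index=0) would be a natural refinement, though it is not part of the original. A quick shape check on fake token ids:

    model = BertNerModel(5)                          # 5 = illustrative class count
    fake_batch = torch.randint(100, 1000, (2, 30))   # 2 sequences of fake token ids
    print(model(fake_batch).shape)                   # torch.Size([2, 30]): predicted label ids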
The local bert_base_chinese directory contains:
config.json: model configuration file
pytorch_model.bin: pretrained weights
vocab.txt: vocabulary (token-to-index table)
Related resources:
Slik-wrangler: a data preprocessing and modeling tool for machine learning and AI
transformers models: https://github.com/huggingface/transformers