一、任务目标
在之前的实战项目中,我们都是直接使用预训练模型,如Bert进行文本表示,进而再训练下游模型用于预测。尽管预训练模型具备强大的zero-shot能力和领域知识迁移能力,但在一些特定领域中,预训练模型的表现依旧是不足的。那么,为了让预训练模型能够更好地适应我们当前的任务,微调就成为了一个必要途径。这里,我们将展示如何进行预训练模型的微调,并给出详细的python建模框架解读。
二、微调流程
为了更加方便地对预训练模型进行微调,我们通过pytorch来搭建一个数据集类和模型类:
1、数据集类构建
为了提高torch训练模型的效率,我们往往会自定义一个dataset类,用于对原始数据的批量处理。在数据量小的时候可以直接写循环,但是为了规范我们的代码,且为了让程序具备处理大批量数据的能力,使用Dataset和DataLoader是必要的。
from torch.utils.data import Dataset
# 自定义数据集类,并继承Dataset
class taskDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_len):
# 数据和标签
self.texts = texts
self.labels = labels
# 分词器和文本最大长度
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self,idx):
text = self.texts[idx]
label = self.labels[idx]
encoding = self.tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt',
)
return {
'text': text,
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
taskDataset类用于后续的DataLoader数据加载,DataLoader会根据我们在taskDataset中定义的方法批量处理数据并返回相应格式的结果。其中,taskDataset必须包含__init__,__len__和__getitem__方法。除了__len__方法直接返回数据长度之外,其余两个方法的输入和输出由我们自定义,只需要确保输出的结果能够输入到预训练模型中即可(即符合预训练模型输入要求)。
2、模型类构建
定义好数据集类之后,模型类也是必不可少的。模型类必须包括__init__和forward两个类方法,而类方法的参数输入和输出则由我们定义。一般而言,__init__方法中会定义好预训练模型层以及其后的附加层(对于不同的任务和模型结构,更改此处模型层定义即可),而forward方法则获取预训练模型必要的输入以及其他可选参数,并定义好输入数据的处理步骤,例如先经过bert层,再经过fc1,最后fc2等,输出一般是logits,即最后一层的神经元输出值(对应任务输出要求,例如这里是三分类,那么forward输出的结果应当是三个神值,代表属于对应类的概率)。
import torch.nn as nn
class BertMLPClassifier(nn.Module):
def __init__(self):
super(BertMLPClassifier, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.dropout = nn.Dropout(p=0.3)
self.fc1 = nn.Linear(self.bert.config.hidden_size, 256)
self.fc2 = nn.Linear(256, 3)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs[1] # BERT的池化输出
x = self.dropout(pooled_output)
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
3、数据准备
完成了数据集类和模型类的定义之后,我们就可以开始准备数据了。首先读取数据集,并划分训练集和测试集,然后taskDataset对数据进行处理,最后调用DataLoader实现批量数据的输出。这里,我们使用的是kaggle的