Task4-基于深度学习的文本分类3-基于Bert预训练和微调进行文本分类
因为天池这个比赛的数据集是脱敏的,无法利用其它已经预训练好的模型,所以需要针对这个数据集自己从头预训练一个模型。
我们利用Huggingface的transformer包,按照自己的需求从头开始预训练一个模型,然后将该模型应用于下游任务。
完整代码见:NLP-hands-on/天池-零基础入门NLP at main · ifwind/NLP-hands-on (github.com)
注意:利用Huggingface做预训练需要安装wandb包,如果报错可参考:[wandb.errors.UsageError: api_key not configured (no-tty). call wandb.login(key=[your_api_key\])_](https://blog.youkuaiyun.com/hhhhhhhhhhwwwwwwwwww/article/details/116124285)
预训练模型
利用Huggingface的transformer包进行预训练主要包括以下几个步骤:
- 用数据集训练Tokenizer;
- 加载数据及数据预处理;
- 设定预训练模型参数,初始化预训练模型;
- 设定训练参数,加载训练器;
- 训练并保存模型。
用数据集训练Tokenizer
Tokenizer是分词器,分词方式有很多种,可以按照空格直接切分、也可以在按词组划分等,可以查看HuggingFace关于tokenizers的官方文档。
Huggingface中,Tokenizer的训练方式为:
- 根据
tokenizers.models
实例化一个Tokenizer
对象tokenizer
, - 从
tokenizers.trainers
中选模型相应的训练器实例化,得到trainer
; - 从
tokenizers.pre_tokenizers
选定一个预训练分词器,对tokenizer
的预训练分词器实例化; - 利用
tokenizer.train()
结合trainer
对语料(注意,语料为一行一句)进行训练; - 利用
tokenizer.save()
保存tokenizer
。
因为天池这个比赛的数据集是脱敏的,词都是用数字进行表示,没有办法训练wordpiece等复杂形式的分词器,只能用空格分隔,在wordlevel进行分词。
因此,我们利用tokenizers.models
中的WordLevel
模型,对应tokenizers.trainers
中的WordLevelTrainer
,选择预训练分词器为Whitespace
训练分词器。
另外,在训练Tokenizer时,可以利用上全部的语料(包括训练集和最终的测试集)。
完整代码如下:
import joblib
from config import *
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import os
def data_preprocess():
rawdata = pd.read_csv(data_file, sep='\t', encoding='UTF-8')
#用正则表达式按标点替换文本
import re
rawdata['words']=rawdata['text'].apply(lambda x: re.sub('3750|900|648',"",x))
del rawdata['text']
#预测
final_test_data = pd.read_csv(final_test_data_file, sep='\t', encoding='UTF-8')
final_test_data['words'] = final_test_data['text'].apply(lambda x: re.sub('3750|900|648',"",x))
del final_test_data['text']
all_value= rawdata['words'].append(final_test_data['words'])
all_value.columns=['text']
all_value.to_csv('../alldata.csv',index=False)
data_preprocess()
from tokenizers import Tokenizer
from tokenizers.models import BPE,WordLevel
tokenizer= Tokenizer(WordLevel(unk_token=