The Trainer class in Transformers makes BERT fine-tuning quite convenient. Recently, while running text-classification experiments with Text-CNN to speed up inference, I reused Trainer's training loop to train the Text-CNN model directly, which simplifies the training steps.
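Before the full script, it helps to see what a Trainer-compatible model looks like. Trainer does not require a Transformers PreTrainedModel: any torch.nn.Module works, as long as its forward accepts the tokenized batch (input_ids and labels) and returns the loss together with the logits. The sketch below is a minimal, assumption-laden version of such a Text-CNN; the vocabulary size, embedding dimension, filter sizes, and default hyperparameters are placeholders rather than the exact values used in this experiment.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    # Minimal Text-CNN classifier; the hyperparameters below are illustrative only
    def __init__(self, vocab_size, num_labels, embed_dim=128,
                 num_filters=64, filter_sizes=(2, 3, 4)):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, num_filters, k) for k in filter_sizes
        )
        self.classifier = nn.Linear(num_filters * len(filter_sizes), num_labels)

    def forward(self, input_ids, labels=None, **kwargs):
        # (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len) for Conv1d
        x = self.embedding(input_ids).transpose(1, 2)
        # Convolve with each filter size, then max-pool over the sequence dimension
        pooled = [F.relu(conv(x)).max(dim=-1).values for conv in self.convs]
        logits = self.classifier(torch.cat(pooled, dim=-1))
        if labels is not None:
            # Trainer reads the loss from the "loss" key of the output dict
            loss = F.cross_entropy(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}
```

The `**kwargs` lets the forward silently ignore extra tokenizer outputs such as attention_mask, so the model can consume the same batches that a BERT model would.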
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn import metrics
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments, BertTokenizer
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
# Load the data and build a datasets.Dataset from it
def create_dataset(data_file, tokenizer):
    print('create dataset')
    # Each line of the data file is expected to be "text\tlabel"
    with open(data_file, 'r', encoding='utf-8') as f:
        data = [_.strip().split('\t') for _ in f.readlines()]
    x = [_[0] for _ in data]
    y = [int(_[1]) for _ in data]
    data_dict = {'text': x, 'label': y}
    dataset = Dataset.from_dict(data_dict)
    def tokenize(examples):
        # Tokenize the raw text; these padding/truncation settings are one reasonable choice
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)
    dataset = dataset.map(tokenize, batched=True)
    return dataset
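With the dataset prepared, the remaining wiring is the standard Trainer setup. The sketch below shows one way it might look, reusing the metrics imports above and the TextCNN sketch from earlier; the file paths, number of labels, and training hyperparameters are assumptions, not the exact values from this experiment.

```python
def compute_metrics(eval_pred):
    # eval_pred carries the stacked logits and the gold labels
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    return {'accuracy': accuracy_score(labels, preds),
            'precision': precision, 'recall': recall, 'f1': f1}

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
train_dataset = create_dataset('data/train.txt', tokenizer)   # paths are placeholders
eval_dataset = create_dataset('data/dev.txt', tokenizer)

model = TextCNN(vocab_size=tokenizer.vocab_size, num_labels=2)  # num_labels is a placeholder

training_args = TrainingArguments(
    output_dir='./text_cnn_output',
    num_train_epochs=10,
    per_device_train_batch_size=64,
    evaluation_strategy='epoch',
    logging_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()
```

Because Trainer removes dataset columns that the model's forward cannot accept, the raw 'text' column is dropped automatically before batching, so only the tokenized tensors and labels reach the Text-CNN.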