1. 算法介绍
1.1 参考文献
复旦大学邱锡鹏老师课题组的研究论文《How to Fine-Tune BERT for Text Classification?》。
论文: https://arxiv.org/pdf/1905.05583.pdf
https://mp.weixin.qq.com/s/9MrgIz2bchiCjUGpz6MbGQ
1.2 论文思路
旨在文本分类任务上探索不同的BERT微调方法并提供一种通用的BERT微调解决方法。这篇论文从三种路线进行了探索:
-
(1) BERT自身的微调策略,包括长文本处理、学习率、不同层的选择等方法;
-
(2) 目标任务内、领域内及跨领域的进一步预训练BERT;
-
(3) 多任务学习。微调后的BERT在七个英文数据集及搜狗中文数据集上取得了当前最优的结果。
1.3 代码来源
作者的实现代码: https://github.com/xuyige/BERT4doc-Classification
数据集来源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv
项目地址:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset
该数据集包含 6 个不同的标签(计算机科学、物理、数学、统计学、生物学、金融),根据摘要和标题对研究论文进行分类。标签列中的值 1 表示标签属于该标签。每个论文有多个标签为 1。
2. 代码实践
2.1 Import
#2.1 Import
#关于torch的安装可以参考https://blog.youkuaiyun.com/Checkmate9949/article/details/119494673?spm=1001.2014.3001.5501
import torch
from transformers import BertTokenizerFast as BertTokenizer
from utils.plot_results import plot_results
from resources.train_val_model import train_model
from resources.get_data import get_data
from resources.build_model import BertClassifier
from resources.test_model import test_model
from resources.build_dataloader import build_dataloader
2.2 Get data: 分割样本
2.2 Get data
##################################
# get data
##################################
#该函数见2.2.1
train_df, val_df, test_df = get_data()
# fixed parameters
#Columns: 第三行到倒数第二行
label_columns = train_df.columns.tolist()[3:-1]
num_labels = len(label_columns)
max_token_len = 30
# BERT_MODEL_NAME = "bert-base-uncased"
# bert-base-uncased: for English. bert-base-Chinese
BERT_MODEL_NAME = "model/bert-base-uncased"
#分词
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
def get_data():
df = pd.read_csv("data/train.csv")
#把标题和摘要合并作为TEXT
df["TEXT"] = df["TITLE"] + df["ABSTRACT"]
label_columns = df.columns.tolist()[3:-1]
print(df[label_columns].sum().sort_values())
#Split data in to train and test: 训练集占比80%
test_df, train_df = train_test_split(df, test_size=0.8, random_state=42)
#Split data in to valid and test: 分别占比50%
test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=42)
#输出数据集
return train_df, val_df, test_df
2.3 Build data loaders
###########

最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



