【Bert】文本多标签分类

1. 算法介绍

1.1 参考文献

复旦大学邱锡鹏老师课题组的研究论文《How to Fine-Tune BERT for Text Classification?》。

论文: https://arxiv.org/pdf/1905.05583.pdf

https://mp.weixin.qq.com/s/9MrgIz2bchiCjUGpz6MbGQ

1.2 论文思路

旨在文本分类任务上探索不同的BERT微调方法并提供一种通用的BERT微调解决方法。这篇论文从三种路线进行了探索:

  • (1) BERT自身的微调策略,包括长文本处理、学习率、不同层的选择等方法;

  • (2) 目标任务内、领域内及跨领域的进一步预训练BERT;

  • (3) 多任务学习。微调后的BERT在七个英文数据集及搜狗中文数据集上取得了当前最优的结果。

1.3 代码来源

作者的实现代码: https://github.com/xuyige/BERT4doc-Classification

数据集来源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv

项目地址:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset

该数据集包含 6 个不同的标签(计算机科学、物理、数学、统计学、生物学、金融),根据摘要和标题对研究论文进行分类。标签列中的值 1 表示标签属于该标签。每个论文有多个标签为 1。

2. 代码实践

2.1 Import

#2.1 Import

#关于torch的安装可以参考https://blog.youkuaiyun.com/Checkmate9949/article/details/119494673?spm=1001.2014.3001.5501
import torch
from transformers import BertTokenizerFast as BertTokenizer
from utils.plot_results import plot_results
from resources.train_val_model import train_model
from resources.get_data import get_data
from resources.build_model import BertClassifier
from resources.test_model import test_model
from resources.build_dataloader import build_dataloader

2.2 Get data: 分割样本

2.2 Get data

##################################
#            get data
##################################

#该函数见2.2.1
train_df, val_df, test_df = get_data()

# fixed parameters
#Columns: 第三行到倒数第二行
label_columns = train_df.columns.tolist()[3:-1]

num_labels = len(label_columns)
max_token_len = 30


# BERT_MODEL_NAME = "bert-base-uncased"
# bert-base-uncased: for English. bert-base-Chinese
BERT_MODEL_NAME = "model/bert-base-uncased"
#分词
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
def get_data():
    df = pd.read_csv("data/train.csv")
#把标题和摘要合并作为TEXT
    df["TEXT"] = df["TITLE"] + df["ABSTRACT"]

    label_columns = df.columns.tolist()[3:-1]
    print(df[label_columns].sum().sort_values())
#Split data in to train and test: 训练集占比80%
    test_df, train_df = train_test_split(df, test_size=0.8, random_state=42)
#Split data in to valid and test: 分别占比50%
    test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=42)
#输出数据集
    return train_df, val_df, test_df

2.3 Build data loaders 

###########
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值