【Bert】文本多标签分类

1. 算法介绍

1.1 参考文献

复旦大学邱锡鹏老师课题组的研究论文《How to Fine-Tune BERT for Text Classification?》。

论文: https://arxiv.org/pdf/1905.05583.pdf

https://mp.weixin.qq.com/s/9MrgIz2bchiCjUGpz6MbGQ

1.2 论文思路

旨在文本分类任务上探索不同的BERT微调方法并提供一种通用的BERT微调解决方法。这篇论文从三种路线进行了探索:

  • (1) BERT自身的微调策略,包括长文本处理、学习率、不同层的选择等方法;

  • (2) 目标任务内、领域内及跨领域的进一步预训练BERT;

  • (3) 多任务学习。微调后的BERT在七个英文数据集及搜狗中文数据集上取得了当前最优的结果。

1.3 代码来源

作者的实现代码: https://github.com/xuyige/BERT4doc-Classification

数据集来源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv

项目地址:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset

该数据集包含 6 个不同的标签(计算机科学、物理、数学、统计学、生物学、金融),根据摘要和标题对研究论文进行分类。标签列中的值 1 表示标签属于该标签。每个论文有多个标签为 1。

2. 代码实践

2.1 Import

#2.1 Import

#关于torch的安装可以参考https://blog.youkuaiyun.com/Checkmate9949/article/details/119494673?spm=1001.2014.3001.5501
import torch
from transformers import BertTokenizerFast as BertTokenizer
from utils.plot_results import plot_results
from resources.train_val_model import train_model
from resources.get_data import get_data
from resources.build_model import BertClassifier
from resources.test_model import test_model
from resources.build_dataloader import build_dataloader

2.2 Get data: 分割样本

2.2 Get data

##################################
#            get data
##################################

#该函数见2.2.1
train_df, val_df, test_df = get_data()

# fixed parameters
#Columns: 第三行到倒数第二行
label_columns = train_df.columns.tolist()[3:-1]

num_labels = len(label_columns)
max_token_len = 30


# BERT_MODEL_NAME = "bert-base-uncased"
# bert-base-uncased: for English. bert-base-Chinese
BERT_MODEL_NAME = "model/bert-base-uncased"
#分词
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
def get_data():
    df = pd.read_csv("data/train.csv")
#把标题和摘要合并作为TEXT
    df["TEXT"] = df["TITLE"] + df["ABSTRACT"]

    label_columns = df.columns.tolist()[3:-1]
    print(df[label_columns].sum().sort_values())
#Split data in to train and test: 训练集占比80%
    test_df, train_df = train_test_split(df, test_size=0.8, random_state=42)
#Split data in to valid and test: 分别占比50%
    test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=42)
#输出数据集
    return train_df, val_df, test_df

2.3 Build data loaders 

###########
### 使用 BERT 进行文本多标签分类的方法 #### 数据准备 为了使用 BERT 模型进行多标签文本分类,首先需要准备好数据集并将其转换为适合模型输入的形式。定义 `TextDataset` 类可以有效地将 tokenized 数据和标签封装成 PyTorch 的数据集格式[^3]。 ```python from torch.utils.data import Dataset class TextDataset(Dataset): def __init__(self, texts, labels, tokenizer, max_len): self.texts = texts self.labels = labels self.tokenizer = tokenizer self.max_len = max_len def __len__(self): return len(self.texts) def __getitem__(self, idx): text = str(self.texts[idx]) label = self.labels[idx] encoding = self.tokenizer.encode_plus( text, add_special_tokens=True, max_length=self.max_len, return_token_type_ids=False, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' ) return { 'text': text, 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'labels': torch.FloatTensor(label) } ``` #### 模型构建 接下来是创建基于 BERT多标签分类器。这里采用 Hugging Face 提供的预训练 BERT 模型,并在其基础上添加一层线性变换层用于输出多个类别概率值。 ```python import torch.nn as nn from transformers import BertModel class BertForMultiLabelClassification(nn.Module): def __init__(self, model_name, num_labels): super(BertForMultiLabelClassification, self).__init__() self.bert = BertModel.from_pretrained(model_name) self.dropout = nn.Dropout(0.3) self.out = nn.Linear(self.bert.config.hidden_size, num_labels) def forward(self, input_ids, attention_mask): outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) pooled_output = outputs.pooler_output output = self.dropout(pooled_output) return self.out(output) ``` #### 训练过程 完成上述准备工作之后,便可以通过 Trainer API 或者自定义循环来进行模型训练。下面展示的是一个简单的训练脚本片段: ```python from transformers import Trainer, TrainingArguments training_args = TrainingArguments( output_dir='./results', num_train_epochs=3, per_device_train_batch_size=8, per_device_eval_batch_size=8, warmup_steps=500, weight_decay=0.01, logging_dir='./logs', logging_steps=10, ) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset ) trainer.train() ``` 此段代码展示了使用预训练的 BERT 模型在一个多标签文本分类任务上的训练、保存、加载和预测的完整过程[^2]。 #### 预测阶段 当模型训练完成后,在测试集中应用该模型以获得最终的结果。对于每条记录,会得到一组对应于各个类别的置信度分数;通常情况下会选择那些超过设定阈值的概率作为正样本标记。 ```python predictions = trainer.predict(test_dataset).predictions sigmoid = nn.Sigmoid() probs = sigmoid(torch.tensor(predictions)).numpy() threshold = 0.5 predicted_labels = (probs >= threshold).astype(int) ``` 通过这种方式,能够高效地利用 BERT 对复杂语境下的文本内容进行理解和分析,从而实现精准的多标签分类效果[^5]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值