【AI实战】基于 bert-base-chinese 预训练模型的多标签文本分类模型,BCEWithLogLoss解决样本不均衡问题

该博客介绍了如何使用Bert-base-Chinese预训练模型解决多标签文本分类中的样本不均衡问题。通过过采样、使用BCEWithLogLoss损失函数等方法来改善模型性能。并详细阐述了环境配置、数据预处理、模型训练和测试的步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

多标签文本多分类

  • 文本分类
    文本分类用电脑对文本集(或其他实体或物件)按照一定的分类体系或标准进行自动分类标记。 它根据一个已经被标注的训练文档集合, 找到文档特征和文档类别之间的关系模型, 然后利用这种学习得到的关系模型对 新的文档进行类别判断 。

  • 其应用场景众多,包括:
    情感分析(Sentiment Analyse)
    主题分类(Topic Labeling)
    问答任务(Question Answering)
    意图识别(Dialog Act Classification)
    自然语言推理(Natural Language Inference)

  • 多标签分类
    多标签分类是样本可以同时属于多个类别(或者标签)、类别间可以相互重叠的模式识别问题,其特殊性主要体现在样本到标签的一对多映射关系和标签间的相关性。

    如文本可能同时涉及任何宗教,政治,金融或教育,也可能不属于任何一种。

bert-base-chinese 预训练模型

bert-base-chinese

This model has been pre-trained for Chinese, training and random input masking has been applied independently to word pieces (as in the original BERT paper).

  • Developed by: HuggingFace team
  • Model Type: Fill-Mask
  • Language(s): Chinese
  • License: [More Information needed]
  • Parent Model: See the BERT base uncased model for more information about the BERT base model.

模型地址

https://huggingface.co/bert-base-chinese

模型拉取

git clone https://huggingface.co/bert-base-chinese

在这里插入图片描述

样本不均衡问题

解决样本不均衡问题方法

  • 1.数据过采样、欠采样
    对样本数量特别少的类别进行过采样;
    对样本数量特别多的类别进行欠采样;
    对样本数量较少的类别借助其他文本生成模型进行数据生成;
  • 2.修改loss函数
    使用 focus loss 代替 交叉熵损失函数 Cross entropy loss
    使用 GHM 代替 交叉熵损失函数 Cross entropy loss
    多标签分类时,使用 BCE Loss 代替 交叉熵损失函数 Cross entropy loss;
  • 3.数据过采样 + loss修改
    结合方法1 和方法2

BCEWithLogLoss解决样本不均衡问题

1.对样本数量特别少的类别进行过采样

比如我对样本数量少的复制了 5 次

2.设置 BCEWithLogLoss 的参数

在 pytorch 中,BCEWithLogitsLoss 可以用来计算多标签多分类问题的损失函数。定义如下:

loss_fn = torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)

下面计算参数:weight,pos_weight
多标签:

labels = ['label-1', 'label-2', 'label-3', 'label-4', 'label-5', 'label-6', 'label-7', \
   'label-8', 'label-9', 'label-10', 'label-11', 'label-12', 'label-13', \
   'label-14', 'label-15', 'label-16']

统计样本的每个类别的数量:

label_counts = [6, 19, 5, 13, 5, 2, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2299]

完整代码:
count_labels_info.py:

 import sys

# 功能:统计标签的比例信息,计算 BCEWithLogitsLoss 的 weight和pos_weight 2个参数的值 

labels = ['label-1', 'label-2', 'label-3', 'label-4', 'label-5', 'label-6', 'label-7', \
    'label-8', 'label-9', 'label-10', 'label-11', 'label-12', 'label-13', \
    'label-14', 'label-15', 'label-16']

num_classes = len(labels)
label_counts = [6, 19, 5, 13, 5, 2, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2299]
total_samples = 2370
total_labels = sum(label_counts) 

class_weight = [0 for _ in range(num_classes)]
pos_weight = [0 for _ in range(num_classes)]

for label_idx, count in enumerate(label_counts):
    if count != 0:
        class_weight[label_idx] = 1 - count / total_labels
        pos_weight[label_idx] = total_samples / count - 1

print('class_weight:', class_weight)
print('pos_weight:', pos_weight)

输出:

class_weight = [0.9974683544303797, 0.9919831223628692, 0.9978902953586498, 0.9945147679324895, 0.99
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

szZack

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值