【AI实战】基于 bert-base-chinese 预训练模型的多标签文本分类模型,BCEWithLogLoss解决样本不均衡问题
多标签文本多分类
-
文本分类
文本分类用电脑对文本集(或其他实体或物件)按照一定的分类体系或标准进行自动分类标记。 它根据一个已经被标注的训练文档集合, 找到文档特征和文档类别之间的关系模型, 然后利用这种学习得到的关系模型对 新的文档进行类别判断 。 -
其应用场景众多,包括:
情感分析(Sentiment Analyse)
主题分类(Topic Labeling)
问答任务(Question Answering)
意图识别(Dialog Act Classification)
自然语言推理(Natural Language Inference) -
多标签分类
多标签分类是样本可以同时属于多个类别(或者标签)、类别间可以相互重叠的模式识别问题,其特殊性主要体现在样本到标签的一对多映射关系和标签间的相关性。如文本可能同时涉及任何宗教,政治,金融或教育,也可能不属于任何一种。
bert-base-chinese 预训练模型
bert-base-chinese
This model has been pre-trained for Chinese, training and random input masking has been applied independently to word pieces (as in the original BERT paper).
- Developed by: HuggingFace team
- Model Type: Fill-Mask
- Language(s): Chinese
- License: [More Information needed]
- Parent Model: See the BERT base uncased model for more information about the BERT base model.
模型地址
https://huggingface.co/bert-base-chinese
模型拉取
git clone https://huggingface.co/bert-base-chinese
样本不均衡问题
解决样本不均衡问题方法
- 1.数据过采样、欠采样
对样本数量特别少的类别进行过采样;
对样本数量特别多的类别进行欠采样;
对样本数量较少的类别借助其他文本生成模型进行数据生成; - 2.修改loss函数
使用 focus loss 代替 交叉熵损失函数 Cross entropy loss
使用 GHM 代替 交叉熵损失函数 Cross entropy loss
多标签分类时,使用 BCE Loss 代替 交叉熵损失函数 Cross entropy loss; - 3.数据过采样 + loss修改
结合方法1 和方法2
BCEWithLogLoss解决样本不均衡问题
1.对样本数量特别少的类别进行过采样
比如我对样本数量少的复制了 5 次
2.设置 BCEWithLogLoss 的参数
在 pytorch 中,BCEWithLogitsLoss 可以用来计算多标签多分类问题的损失函数。定义如下:
loss_fn = torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
下面计算参数:weight,pos_weight
多标签:
labels = ['label-1', 'label-2', 'label-3', 'label-4', 'label-5', 'label-6', 'label-7', \
'label-8', 'label-9', 'label-10', 'label-11', 'label-12', 'label-13', \
'label-14', 'label-15', 'label-16']
统计样本的每个类别的数量:
label_counts = [6, 19, 5, 13, 5, 2, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2299]
完整代码:
count_labels_info.py:
import sys
# 功能:统计标签的比例信息,计算 BCEWithLogitsLoss 的 weight和pos_weight 2个参数的值
labels = ['label-1', 'label-2', 'label-3', 'label-4', 'label-5', 'label-6', 'label-7', \
'label-8', 'label-9', 'label-10', 'label-11', 'label-12', 'label-13', \
'label-14', 'label-15', 'label-16']
num_classes = len(labels)
label_counts = [6, 19, 5, 13, 5, 2, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2299]
total_samples = 2370
total_labels = sum(label_counts)
class_weight = [0 for _ in range(num_classes)]
pos_weight = [0 for _ in range(num_classes)]
for label_idx, count in enumerate(label_counts):
if count != 0:
class_weight[label_idx] = 1 - count / total_labels
pos_weight[label_idx] = total_samples / count - 1
print('class_weight:', class_weight)
print('pos_weight:', pos_weight)
输出:
class_weight = [0.9974683544303797, 0.9919831223628692, 0.9978902953586498, 0.9945147679324895, 0.99