【GPT入门】第68课 分类任务: 基于BERT训练情感分类

1. 方案设计

基于BERT模型,使用ChnSentiCorp 数据对模型进行训练,最后预测相应评价数据,并检查预测的结果正确性

2. 数据源介绍

2.1 数据下载

ChnSentiCorp 是一款专为中文情感分析设计的数据资源包。它汇集了来自网络平台的多样化评论数据,主要覆盖酒店住宿体验、笔记本电脑使用评价以及书籍阅读感受三大领域。数据集中的每一条评论都经过人工标注,确保了情感标签的准确性和可靠性,对于训练情感分析模型至关重要。

ChnSentiCorp 数据集可以从以下渠道下载:

  • Hugging Face:可以在 Hugging Face 或其镜像上搜索 “ChnSentiCorp”,如直接访问https://hf-mirror.com/datasets/XiangPan/ChnSentiCorp_htl_8k,选择 “Files and versions” 下载数据集。
  • 百度网盘:链接为https://pan.baidu.com/s/1PGCIz-yub3ugXYuNivlZzw,提取码为 nuwl。
  • GitHub:https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv。
  • HyperAI 超神经:https://hyper.ai/cn/datasets/29807。

加载数据代码,直接从huggingface下载

from datasets import load_dataset,load_from_disk

#在线加载数据
dataset = load_dataset(path="lansinuote/ChnSentiCorp",cache_dir="data/")
print(dataset)
  • 数据量
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
})

2.2 数据例子

('小说是早就看过的.买来只是为了收藏,可恨的是当当没有用塑料袋包装此书,感觉不是全新一样,而且封底有一条折痕,影响心情.', 0)
('2007年9月11日256元住普通标间,临街(其它房型已无)。 我是喜欢开着窗睡觉的,总体感觉不太吵。因为下面的玉皇阁北街不算是银川的主要交通要道。 早上有一些车流声。 对面有个农贸市场,购买应季瓜果很方便。 离鼓楼不远,可以品尝一下“老毛抓肉”。', 1)
('酒店太旧了, 大堂感觉象三星级的, 房间也就是的好点的三星级的条件, 在青岛这样的酒店是绝对算不上四星标准, 早餐走了两圈也没有找到可以吃的, 太差了', 0)
('外观简洁漂亮、重量也比较轻,配置的话,就大学生而言绝对够用了。 散热性也比我自己的hp要好。', 1)
('充其量一个度假村而已,客房其实并不大,不知通过什么手段评上4星了。酒店近邻104国道,是连套别墅销售附带的设施。房间装修味道很大,其中一个房子竟然没有窗户。卫生间一股臭味,设施还不如3星级酒店。如果不是去山东路上太晚,怎么也不会到这个酒店。', 0)
('当初无意在电视上看到对于丹的采访,满怀希望地买她的书看。看后觉得与电视中的讲演相距甚远!失望!', 0)
('外观,配置,价格,三个组合起来看是绝对超值的东东 我4699入手,抢到了', 1)
('已经贴完了,又给小区的妈妈买了一套。最值得推荐', 1)
('屏幕大,本本薄。自带数字小键盘,比较少见。声音也还过得去。usb接口多,有四个。独显看高清很好。运行速度也还可以,性价比高!', 1)
('地下车库几乎没有灯,一地泥汤,下车时踩了一脚泥水。所谓海景房看不到海不说,临界巨吵,半夜被街上小贩、拖拉机吵醒再也无法入睡。有骚扰电话,宽带不能上网。没有浴缸,没有浴帽,没有放拖鞋。前台小姐居然拿着电话与男友聊情话而把客人晾一边,真是仅见观瞻。建议携程取消与这家的合作。', 0)   
('作为五星级 酒店的硬件是差了点 装修很久 电视很小 只是位置很好 楼下是DFS 对面是海港城 但性价比不高', 1)

2.3 数据量与标签分布

基本上是1:1

=== train ===
label 0: 4801
label 1: 4799,
total: 9600

=== validation ===
label 0: 607
label 1: 593
total: 1200

=== test ===
label 0: 592
label 1: 608
total: 1200

3.BERT模型训练与评估

3.1 python环境安装

  • 安装conda环境 , 环境放到数据盘:
    mkdir /root/autodl-tmp/xxzhenv
    conda create --prefix /root/autodl-tmp/xxzhenv/bertpython=3.10-y
    conda config --add envs_dirs /root/autodl-tmp/xxzhenv
    conda activate bert

pip install transformers datasets tokenizers

3.2 设置huggingface镜像与学术加速

设置国内镜像
export HF_ENDPOINT=https://hf-mirror.com
autodl的学术加速
source /etc/network_turbo

解决下面这种网络问题:

File “/root/autodl-tmp/xxzhenv/xtuner-env/lib/python3.10/site-packages/datasets/load.py”, line 1551, in dataset_module_factory
raise ConnectionError(f"Couldn’t reach ‘{path}’ on the Hub ({e.class.name})") from e
ConnectionError: Couldn’t reach ‘seamew/THUCNewsText’ on the Hub (LocalEntryNotFoundError)

3.3 模型介绍

基础模型: bert-base-chinese

3.3 模型下载

from transformers import AutoModelForCausalLM, BertTokenizer, BertForSequenceClassification

model_dir = '/root/autodl-tmp/models_xxzh/bert-base-chinese'

# 加载模型和分词器
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', cache_dir=model_dir)

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese',cache_dir=model_dir)

print(model)


4. 模型训练

4.1 代码

transformer是4.46版,采用xtuner的conda环境直接训练

#模型训练
import torch
from MyData import MyDataset
from torch.utils.data import DataLoader
from net import Model
#from transformers import BertTokenizer,AdamW
from transformers import BertTokenizer
from transformers.optimization import Adafactor

#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定义训练的轮次(将整个数据集训练完一次为一轮)
EPOCH = 30000

#加载字典和分词器
token = BertTokenizer.from_pretrained(r"/root/autodl-tmp/models_xxzh/bert-base-chinese/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea")

#将传入的字符串进行编码
def collate_fn(data):
    sents = [i[0]for i in data]
    label = [i[1] for i in data]
    #编码
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        # 当句子长度大于max_length(上限是model_max_length)时,截断
        truncation=True,
        max_length=512,
        # 一律补0到max_length
        padding="max_length",
        # 可取值为tf,pt,np,默认为list
        return_tensors="pt",
        # 返回序列长度
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    label = torch.LongTensor(label)
    return input_ids,attention_mask,token_type_ids,label



#创建数据集
train_dataset = MyDataset("train")
train_loader = DataLoader(
    dataset=train_dataset,
    #训练批次
    batch_size=100,
    #打乱数据集
    shuffle=True,
    #舍弃最后一个批次的数据,防止形状出错
    drop_last=True,
    #对加载的数据进行编码
    collate_fn=collate_fn
)
if __name__ == '__main__':
    #开始训练
    print(DEVICE)
    model = Model().to(DEVICE)
    #定义优化器
    #optimizer = AdamW(model.parameters())
        #定义优化器
    optimizer = Adafactor(model.parameters())
    #定义损失函数
    loss_func = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCH):
        for i,(input_ids,attention_mask,token_type_ids,label) in enumerate(train_loader):
            #将数据放到DVEVICE上面
            input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),label.to(DEVICE)
            #前向计算(将数据输入模型得到输出)
            out = model(input_ids,attention_mask,token_type_ids)
            #根据输出计算损失
            loss = loss_func(out,label)
            #根据误差优化参数
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #每隔5个批次输出训练信息
            if i%5 ==0:
                out = out.argmax(dim=1)
                #计算训练精度
                acc = (out==label).sum().item()/len(label)
                print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")
        #每训练完一轮,保存一次参数
        torch.save(model.state_dict(),f"params/{epoch}_bert.pth")
        print(epoch,"参数保存成功!")


4.2 训练日志

观察日志,发现,准确率,其实会时高时低。

(/root/autodl-tmp/xxzhenv/xtuner-env) root@autodl-container-015c4bb84b-f47d8cb7:~/autodl-tmp/xxzh/bert# python train.py
cuda
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
cuda
epoch:0,i:0,loss:0.7195608615875244,acc:0.51
epoch:0,i:5,loss:0.6917376518249512,acc:0.55
epoch:0,i:10,loss:0.649858832359314,acc:0.7
epoch:0,i:15,loss:0.6406018137931824,acc:0.75
epoch:0,i:20,loss:0.5932396054267883,acc:0.77
epoch:0,i:25,loss:0.5811040997505188,acc:0.8
epoch:0,i:30,loss:0.577368974685669,acc:0.76
epoch:0,i:35,loss:0.5450020432472229,acc:0.79
epoch:0,i:40,loss:0.5337967276573181,acc:0.81
epoch:0,i:45,loss:0.4742681086063385,acc:0.86
epoch:0,i:50,loss:0.5025675296783447,acc:0.84
epoch:0,i:55,loss:0.4638785123825073,acc:0.84
epoch:0,i:60,loss:0.4489150941371918,acc:0.85
epoch:0,i:65,loss:0.4568423926830292,acc:0.81
epoch:0,i:70,loss:0.5100483894348145,acc:0.77
epoch:0,i:75,loss:0.4181244373321533,acc:0.86
epoch:0,i:80,loss:0.4122348725795746,acc:0.85
epoch:0,i:85,loss:0.3839280307292938,acc:0.89
epoch:0,i:90,loss:0.429963618516922,acc:0.8
epoch:0,i:95,loss:0.34574413299560547,acc:0.89
0 参数保存成功!
epoch:1,i:0,loss:0.40835708379745483,acc:0.84
epoch:1,i:5,loss:0.40548789501190186,acc:0.9
epoch:1,i:10,loss:0.4293215870857239,acc:0.82
epoch:1,i:15,loss:0.39299094676971436,acc:0.86
epoch:1,i:20,loss:0.39639824628829956,acc:0.87
epoch:1,i:25,loss:0.3679128885269165,acc:0.86
epoch:1,i:30,loss:0.3286331295967102,acc:0.89
epoch:1,i:35,loss:0.40142449736595154,acc:0.81
epoch:1,i:40,loss:0.425790399312973,acc:0.82
epoch:1,i:45,loss:0.29961785674095154,acc:0.91
epoch:1,i:50,loss:0.2763792872428894,acc:0.93
epoch:1,i:55,loss:0.359414279460907,acc:0.86
epoch:1,i:60,loss:0.36357879638671875,acc:0.87
epoch:1,i:65,loss:0.431721031665802,acc:0.8
epoch:1,i:70,loss:0.29267191886901855,acc:0.9
epoch:1,i:75,loss:0.32830092310905457,acc:0.88
epoch:1,i:80,loss:0.2944716811180115,acc:0.92
epoch:1,i:85,loss:0.36355409026145935,acc:0.83
epoch:1,i:90,loss:0.2875462472438812,acc:0.87
epoch:1,i:95,loss:0.4102298319339752,acc:0.79
1 参数保存成功!
epoch:2,i:0,loss:0.33749282360076904,acc:0.87
epoch:2,i:5,loss:0.3382044732570648,acc:0.87
epoch:2,i:10,loss:0.2942817211151123,acc:0.88
epoch:2,i:15,loss:0.3789863586425781,acc:0.82
epoch:2,i:20,loss:0.40764111280441284,acc:0.85
epoch:2,i:25,loss:0.3074020743370056,acc:0.92
epoch:2,i:30,loss:0.3119829297065735,acc:0.87
epoch:2,i:35,loss:0.275304526090
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值