Week TR5: Transformers in Practice: Text Classification

Tasks:
● Understand the logic of the code in this article and run it successfully
● Tune the code based on your own understanding so that the accuracy reaches 70%

1. Preparation

1.1. Environment Setup

This is a hands-on example of simple text classification built with PyTorch using a Transformer model.

import torch, torchvision
print(torch.__version__)  # note: these are double underscores
print(torchvision.__version__)

Output

2.0.0+cpu
0.15.1+cpu

1.2. Loading the Data

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # suppress warnings

# Windows 10 here: use the GPU if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Output

device(type='cpu')

import pandas as pd

# Load the custom Chinese dataset
train_data = pd.read_csv('./TR5/train.csv', sep='\t', header=None)
train_data.head()

Output

0 1
0 还有双鸭山到淮阴的汽车票吗13号的 Travel-Query
1 从这里怎么回家 Travel-Query
2 随便播放一首专辑阁楼里的佛里的歌 Music-Play
3 给看一下墓王之王嘛 FilmTele-Play
4 我想看挑战两把s686打突变团竞的游戏视频 Video-Play

# Build a dataset iterator
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y
        
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_data[0].values[:]

Output

array(['还有双鸭山到淮阴的汽车票吗13号的', '从这里怎么回家', '随便播放一首专辑阁楼里的佛里的歌', ...,
       '黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席', '百事盖世群星星光演唱会有谁', '下周一视频会议的闹钟帮我开开'],
      dtype=object)

train_data[1].values[:]

Output

array(['Travel-Query', 'Travel-Query', 'Music-Play', ..., 'Radio-Listen',
       'Video-Play', 'Alarm-Update'], dtype=object)

2. Data Preprocessing

2.1. Building the Vocabulary

The jieba word-segmentation library needs to be installed separately:
● Command line: pip install jieba

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

# Chinese word segmentation with jieba
tokenizer = jieba.lcut

def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # set the default index; out-of-vocabulary tokens map to it

Output

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\xzy\AppData\Local\Temp\jieba.cache
Loading model cost 0.953 seconds.
Prefix dict has been built successfully.

vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

Output

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]

label_name = list(set(train_data[1].values[:]))
print(label_name)

Output

['Audio-Play', 'Music-Play', 'Weather-Query', 'Alarm-Update', 'Radio-Listen', 'TVProgram-Play', 'Travel-Query', 'FilmTele-Play', 'Calendar-Query', 'HomeAppliance-Control', 'Video-Play', 'Other']

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

Output

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
10

2.2. Generating Batches and the DataLoader

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    
    for (_text, _label) in batch:
        # label list
        label_list.append(label_pipeline(_label))
        
        # text list
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        
        # token count of each sentence; the cumulative sums below become the EmbeddingBag offsets
        offsets.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list  = torch.cat(text_list)
    offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0)  # cumulative sum along dim: start offset of each sentence
    
    return text_list.to(device), label_list.to(device), offsets.to(device)
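
A hedged sketch of what one collated batch looks like (the two sample sentences below are only illustrations; any text/label pair from the dataset works the same way):

# Illustrative call: a batch of two (text, label) pairs
sample_batch = [('随便播放一首歌', 'Music-Play'), ('从这里怎么回家', 'Travel-Query')]
text, label, offsets = collate_batch(sample_batch)
print(text)     # all token IDs of both sentences concatenated into one 1-D tensor
print(label)    # tensor with the two class indices
print(offsets)  # e.g. tensor([0, n1]), where n1 is the token count of the first sentence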

2.3. Building the Dataset

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

BATCH_SIZE = 4 

train_iter    = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

The to_map_style_dataset() function

It converts an iterable-style dataset into a map-style dataset, so that elements can be accessed conveniently by index (for example, an integer). In PyTorch, datasets come in two flavors: iterable-style and map-style (a small sketch of both follows the list below).
● An iterable-style dataset implements __iter__(); its elements can be iterated over, but they cannot be accessed by index.
● A map-style dataset implements __getitem__() and __len__(); specific elements can be accessed directly by index, and the dataset's size can be queried.
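
As a hedged illustration (the class names below are made up for this example and are not part of torchtext), the two styles differ only in which protocol methods they implement:

from torch.utils.data import Dataset, IterableDataset

# Iterable-style: only __iter__, no random access by index
class MyIterableTexts(IterableDataset):
    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels
    def __iter__(self):
        return iter(zip(self.texts, self.labels))

# Map-style: __getitem__ and __len__, so dataset[i] and len(dataset) both work
class MyMapTexts(Dataset):
    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels
    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]
    def __len__(self):
        return len(self.texts)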

3. Building the Model

3.1. Defining the Positional Encoding

import math, os, torch

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=500):
        super(PositionalEncoding, self).__init__()

        # Create a zero tensor of shape [max_len, embed_dim]
        pe = torch.zeros(max_len, embed_dim) 
        # Create a position-index tensor of shape [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 

        # Frequency term; note that the original Transformer paper uses a base of 10000.0 here
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(100.0) / embed_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term) # PE(pos, 2i)
        pe[:, 1::2] = torch.cos(position * div_term) # PE(pos, 2i+1)
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Register the positional encoding as a buffer: it is not a trainable parameter,
        # but it is saved along with the model's state
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the positional encoding to the input tensor (mind the shape of the encoding)
        x = x + self.pe[:x.size(0)]
        return x
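
For reference, the sinusoidal encoding implemented above follows the formula from the original Transformer paper (which uses a base of 10000; the code above uses 100):

PE(pos, 2i)   = sin( pos / 10000^(2i / d_model) )
PE(pos, 2i+1) = cos( pos / 10000^(2i / d_model) )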

3.2. Defining the Transformer Model


from tempfile import TemporaryDirectory
from typing   import Tuple
from torch    import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class, nhead=8, d_hid=256, nlayers=12, dropout=0.1):
        super().__init__()

        self.embedding = nn.EmbeddingBag(vocab_size,   # vocabulary size
                                         embed_dim,    # embedding dimension
                                         sparse=False)
        
        self.pos_encoder = PositionalEncoding(embed_dim)

        # Define the encoder layers
        encoder_layers           = TransformerEncoderLayer(embed_dim, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embed_dim           = embed_dim
        self.linear              = nn.Linear(embed_dim*4, num_class)
        
    def forward(self, src, offsets, src_mask=None):

        src    = self.embedding(src, offsets)
        src    = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)

        # Note: this reshape assumes the fixed batch size of 4 used below (BATCH_SIZE = 4)
        output = output.view(4, self.embed_dim*4)
        output = self.linear(output)
 
        return output

3.3. Initializing the Model

vocab_size = len(vocab)       # vocabulary size
embed_dim  = 64               # embedding dimension
num_class  = len(label_name)  # number of classes

# Create the Transformer model and move it to the device
model = TransformerModel(vocab_size, 
                         embed_dim, 
                         num_class).to(device)
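
As an optional sanity check (not part of the original code), the number of trainable parameters can be printed like this:

# Count the trainable parameters as a rough measure of model capacity
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {num_params:,}')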

3.4. Defining the Training Function

import time

def train(dataloader):
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 300
    start_time   = time.time()

    for idx, (text, label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                    # zero the gradients

        loss = criterion(predicted_label, label) # loss between predictions and the true labels
        loss.backward()                          # backpropagation
        optimizer.step()                         # update the parameters
        
        # accumulate accuracy and loss
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)
        
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

3.5. Defining the Evaluation Function

def evaluate(dataloader):
    model.eval()  # switch to evaluation mode
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text, label, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            
            loss = criterion(predicted_label, label)  # compute the loss
            # accumulate evaluation metrics
            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            train_loss  += loss.item()
            total_count += label.size(0)
            
    return total_acc/total_count, train_loss/total_count

4. Training the Model

4.1. Model Training

# Hyperparameters
EPOCHS     = 10 

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    
    # Get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,
                                           time.time() - epoch_start_time,
                                           val_acc,val_loss,lr))

    print('-' * 69)

Output

| epoch 1 |  300/2420 batches | train_acc 0.105 train_loss 0.63515
| epoch 1 |  600/2420 batches | train_acc 0.103 train_loss 0.62862
| epoch 1 |  900/2420 batches | train_acc 0.109 train_loss 0.61628
| epoch 1 | 1200/2420 batches | train_acc 0.134 train_loss 0.59848
| epoch 1 | 1500/2420 batches | train_acc 0.116 train_loss 0.59714
| epoch 1 | 1800/2420 batches | train_acc 0.126 train_loss 0.58824
| epoch 1 | 2100/2420 batches | train_acc 0.147 train_loss 0.59021
| epoch 1 | 2400/2420 batches | train_acc 0.147 train_loss 0.58219
---------------------------------------------------------------------
| epoch 1 | time: 81.20s | valid_acc 0.166 valid_loss 0.572 | lr 0.010000
---------------------------------------------------------------------
| epoch 2 |  300/2420 batches | train_acc 0.166 train_loss 0.57671
| epoch 2 |  600/2420 batches | train_acc 0.149 train_loss 0.57718
| epoch 2 |  900/2420 batches | train_acc 0.153 train_loss 0.57841
| epoch 2 | 1200/2420 batches | train_acc 0.168 train_loss 0.57620
| epoch 2 | 1500/2420 batches | train_acc 0.152 train_loss 0.57920
| epoch 2 | 1800/2420 batches | train_acc 0.159 train_loss 0.57392
| epoch 2 | 2100/2420 batches | train_acc 0.158 train_loss 0.57644
| epoch 2 | 2400/2420 batches | train_acc 0.187 train_loss 0.57352
---------------------------------------------------------------------
| epoch 2 | time: 80.59s | valid_acc 0.210 valid_loss 0.561 | lr 0.010000
---------------------------------------------------------------------
| epoch 3 |  300/2420 batches | train_acc 0.178 train_loss 0.56841
| epoch 3 |  600/2420 batches | train_acc 0.182 train_loss 0.56651
| epoch 3 |  900/2420 batches | train_acc 0.191 train_loss 0.55880
| epoch 3 | 1200/2420 batches | train_acc 0.212 train_loss 0.55917
| epoch 3 | 1500/2420 batches | train_acc 0.209 train_loss 0.55905
| epoch 3 | 1800/2420 batches | train_acc 0.190 train_loss 0.56497
| epoch 3 | 2100/2420 batches | train_acc 0.225 train_loss 0.55538
| epoch 3 | 2400/2420 batches | train_acc 0.195 train_loss 0.56107
---------------------------------------------------------------------
| epoch 3 | time: 80.77s | valid_acc 0.223 valid_loss 0.549 | lr 0.010000
---------------------------------------------------------------------
| epoch 4 |  300/2420 batches | train_acc 0.221 train_loss 0.55027
| epoch 4 |  600/2420 batches | train_acc 0.226 train_loss 0.54617
| epoch 4 |  900/2420 batches | train_acc 0.243 train_loss 0.54574
| epoch 4 | 1200/2420 batches | train_acc 0.223 train_loss 0.55473
| epoch 4 | 1500/2420 batches | train_acc 0.218 train_loss 0.55534
| epoch 4 | 1800/2420 batches | train_acc 0.236 train_loss 0.54059
| epoch 4 | 2100/2420 batches | train_acc 0.228 train_loss 0.54930
| epoch 4 | 2400/2420 batches | train_acc 0.226 train_loss 0.55326
---------------------------------------------------------------------
| epoch 4 | time: 81.09s | valid_acc 0.239 valid_loss 0.539 | lr 0.010000
---------------------------------------------------------------------
| epoch 5 |  300/2420 batches | train_acc 0.228 train_loss 0.54374
| epoch 5 |  600/2420 batches | train_acc 0.211 train_loss 0.54772
| epoch 5 |  900/2420 batches | train_acc 0.230 train_loss 0.54833
| epoch 5 | 1200/2420 batches | train_acc 0.226 train_loss 0.54882
| epoch 5 | 1500/2420 batches | train_acc 0.217 train_loss 0.54486
| epoch 5 | 1800/2420 batches | train_acc 0.231 train_loss 0.54067
| epoch 5 | 2100/2420 batches | train_acc 0.245 train_loss 0.53641
| epoch 5 | 2400/2420 batches | train_acc 0.238 train_loss 0.53832
---------------------------------------------------------------------
| epoch 5 | time: 80.30s | valid_acc 0.250 valid_loss 0.531 | lr 0.010000
---------------------------------------------------------------------
| epoch 6 |  300/2420 batches | train_acc 0.233 train_loss 0.53834
| epoch 6 |  600/2420 batches | train_acc 0.229 train_loss 0.54007
| epoch 6 |  900/2420 batches | train_acc 0.240 train_loss 0.53498
| epoch 6 | 1200/2420 batches | train_acc 0.265 train_loss 0.53296
| epoch 6 | 1500/2420 batches | train_acc 0.237 train_loss 0.53516
| epoch 6 | 1800/2420 batches | train_acc 0.244 train_loss 0.54253
| epoch 6 | 2100/2420 batches | train_acc 0.263 train_loss 0.53246
| epoch 6 | 2400/2420 batches | train_acc 0.272 train_loss 0.52636
---------------------------------------------------------------------
| epoch 6 | time: 80.11s | valid_acc 0.236 valid_loss 0.543 | lr 0.010000
---------------------------------------------------------------------
| epoch 7 |  300/2420 batches | train_acc 0.247 train_loss 0.53724
| epoch 7 |  600/2420 batches | train_acc 0.277 train_loss 0.52268
| epoch 7 |  900/2420 batches | train_acc 0.287 train_loss 0.52461
| epoch 7 | 1200/2420 batches | train_acc 0.245 train_loss 0.52172
| epoch 7 | 1500/2420 batches | train_acc 0.253 train_loss 0.52076
| epoch 7 | 1800/2420 batches | train_acc 0.262 train_loss 0.51814
| epoch 7 | 2100/2420 batches | train_acc 0.277 train_loss 0.51824
| epoch 7 | 2400/2420 batches | train_acc 0.300 train_loss 0.51197
---------------------------------------------------------------------
| epoch 7 | time: 80.58s | valid_acc 0.301 valid_loss 0.502 | lr 0.010000
---------------------------------------------------------------------
| epoch 8 |  300/2420 batches | train_acc 0.290 train_loss 0.51114
| epoch 8 |  600/2420 batches | train_acc 0.299 train_loss 0.50069
| epoch 8 |  900/2420 batches | train_acc 0.299 train_loss 0.49917
| epoch 8 | 1200/2420 batches | train_acc 0.320 train_loss 0.49608
| epoch 8 | 1500/2420 batches | train_acc 0.341 train_loss 0.48615
| epoch 8 | 1800/2420 batches | train_acc 0.315 train_loss 0.50020
| epoch 8 | 2100/2420 batches | train_acc 0.366 train_loss 0.47658
| epoch 8 | 2400/2420 batches | train_acc 0.343 train_loss 0.48388
---------------------------------------------------------------------
| epoch 8 | time: 80.48s | valid_acc 0.367 valid_loss 0.471 | lr 0.010000
---------------------------------------------------------------------
| epoch 9 |  300/2420 batches | train_acc 0.355 train_loss 0.47828
| epoch 9 |  600/2420 batches | train_acc 0.358 train_loss 0.47669
| epoch 9 |  900/2420 batches | train_acc 0.369 train_loss 0.46768
| epoch 9 | 1200/2420 batches | train_acc 0.368 train_loss 0.47074
| epoch 9 | 1500/2420 batches | train_acc 0.385 train_loss 0.46331
| epoch 9 | 1800/2420 batches | train_acc 0.355 train_loss 0.47316
| epoch 9 | 2100/2420 batches | train_acc 0.380 train_loss 0.46985
| epoch 9 | 2400/2420 batches | train_acc 0.384 train_loss 0.46793
---------------------------------------------------------------------
| epoch 9 | time: 80.06s | valid_acc 0.381 valid_loss 0.456 | lr 0.010000
---------------------------------------------------------------------
| epoch 10 |  300/2420 batches | train_acc 0.380 train_loss 0.45352
| epoch 10 |  600/2420 batches | train_acc 0.395 train_loss 0.44962
| epoch 10 |  900/2420 batches | train_acc 0.407 train_loss 0.44455
| epoch 10 | 1200/2420 batches | train_acc 0.397 train_loss 0.44962
| epoch 10 | 1500/2420 batches | train_acc 0.401 train_loss 0.44773
| epoch 10 | 1800/2420 batches | train_acc 0.422 train_loss 0.43027
| epoch 10 | 2100/2420 batches | train_acc 0.454 train_loss 0.42070
| epoch 10 | 2400/2420 batches | train_acc 0.446 train_loss 0.43559
---------------------------------------------------------------------
| epoch 10 | time: 80.64s | valid_acc 0.454 valid_loss 0.413 | lr 0.010000
---------------------------------------------------------------------

4.2. Model Evaluation

test_acc, test_loss = evaluate(valid_dataloader)
print('Model accuracy: {:5.4f}'.format(test_acc))

Output

Model accuracy: 0.4471

5. Tuning the Code to Reach 70% Accuracy

To improve the model's accuracy, the following adjustments can be considered:

5.1. Adjusting Model Hyperparameters

● Learning rate: a learning rate that is too high or too low hurts both convergence speed and final quality; try lowering or raising it in the code (a hedged sketch follows this list).

● Hidden size: increasing the hidden size or adding more layers can improve the model's expressive power.

● Batch size: try increasing or decreasing batch_size and observe the effect.
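
As a hedged example of such adjustments (Adam and the specific values below are illustrative choices, not tuned results from this article), the relevant settings could be changed before rebuilding the dataloaders and the model:

# Illustrative hyperparameter changes; rebuild the dataloaders and the model after changing them
BATCH_SIZE = 16    # note: forward() above hard-codes a batch size of 4 and would need to be generalized first
embed_dim  = 128   # a larger embedding dimension increases model capacity
optimizer  = torch.optim.Adam(model.parameters(), lr=1e-3)  # Adam often converges faster than plain SGD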

5.2. Train for More Epochs
Increasing the number of epochs gives the model more time to learn and adjust its parameters. Train longer, but watch out for overfitting.

5.3. Regularization

● Adding Dropout helps prevent overfitting and improves generalization.
● L2 regularization: add weight decay to the optimizer (see the sketch after this list).
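
A minimal sketch of both options, assuming the model and optimizer defined above (the dropout value and the weight_decay coefficient are illustrative):

# L2 regularization via weight decay in the optimizer (weight_decay is the L2 coefficient)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)

# Stronger dropout: TransformerModel already exposes a dropout argument (0.1 by default)
model = TransformerModel(vocab_size, embed_dim, num_class, dropout=0.3).to(device)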

5.4. Data Augmentation
Collect more data or generate additional training samples with data-augmentation techniques. Greater data diversity usually improves generalization.

5.5. Use a More Complex Model Architecture
Try a more expressive architecture, such as a bidirectional LSTM, a GRU, or a larger Transformer; these may handle sequence data better.

5.6. Adjust the Loss Function
If the current loss function is not performing well, consider switching to another loss or adding custom adjustments on top of it (a hedged example follows).
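
One simple, hedged example is label smoothing, which CrossEntropyLoss supports directly (the 0.1 value is illustrative):

# Label smoothing softens the one-hot targets and can improve generalization
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)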

5.7. Check the Dataset
Make sure the label distribution is reasonable and that there are no mislabeled samples. If the classes are imbalanced, consider resampling or a weighted loss function (see the sketch below).
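
A hedged sketch of a class-weighted loss, using train_data and label_name from above (inverse class frequency is one common weighting choice):

# Weight each class inversely to its frequency in the training data
counts  = train_data[1].value_counts()
weights = torch.tensor([len(train_data) / counts[name] for name in label_name],
                       dtype=torch.float).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=weights)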

5.8. Experiment and Iterate
Run experiments to see how different combinations of parameters and methods affect performance, and settle on the configuration that works best.

These methods can be used individually or in combination to improve accuracy. A good starting point is increasing the number of epochs and adjusting the learning rate, then gradually trying a more complex model and regularization. Tuning is an iterative process that should be adjusted continually based on experimental results.

To push the accuracy toward 70% in this exercise, one option is simply to train for more epochs: in the code above, change EPOCHS from 10 to 100, i.e.:

# Hyperparameters
EPOCHS     = 100

With everything else unchanged, the training results are as follows:

| epoch 1 |  300/2420 batches | train_acc 0.124 train_loss 0.63172
| epoch 1 |  600/2420 batches | train_acc 0.100 train_loss 0.62975
| epoch 1 |  900/2420 batches | train_acc 0.108 train_loss 0.61340
| epoch 1 | 1200/2420 batches | train_acc 0.128 train_loss 0.60664
| epoch 1 | 1500/2420 batches | train_acc 0.102 train_loss 0.59968
| epoch 1 | 1800/2420 batches | train_acc 0.114 train_loss 0.59483
| epoch 1 | 2100/2420 batches | train_acc 0.141 train_loss 0.58566
| epoch 1 | 2400/2420 batches | train_acc 0.128 train_loss 0.58129
---------------------------------------------------------------------
| epoch 1 | time: 85.62s | valid_acc 0.141 valid_loss 0.585 | lr 0.010000
---------------------------------------------------------------------
| epoch 2 |  300/2420 batches | train_acc 0.157 train_loss 0.58087
| epoch 2 |  600/2420 batches | train_acc 0.157 train_loss 0.57831
| epoch 2 |  900/2420 batches | train_acc 0.153 train_loss 0.58050
| epoch 2 | 1200/2420 batches | train_acc 0.172 train_loss 0.57228
| epoch 2 | 1500/2420 batches | train_acc 0.187 train_loss 0.56080
| epoch 2 | 1800/2420 batches | train_acc 0.163 train_loss 0.57341
| epoch 2 | 2100/2420 batches | train_acc 0.169 train_loss 0.56785
| epoch 2 | 2400/2420 batches | train_acc 0.174 train_loss 0.56988
---------------------------------------------------------------------
| epoch 2 | time: 86.00s | valid_acc 0.194 valid_loss 0.564 | lr 0.010000
---------------------------------------------------------------------
| epoch 3 |  300/2420 batches | train_acc 0.194 train_loss 0.56447
| epoch 3 |  600/2420 batches | train_acc 0.187 train_loss 0.55973
| epoch 3 |  900/2420 batches | train_acc 0.198 train_loss 0.55564
| epoch 3 | 1200/2420 batches | train_acc 0.189 train_loss 0.55915
| epoch 3 | 1500/2420 batches | train_acc 0.199 train_loss 0.55970
| epoch 3 | 1800/2420 batches | train_acc 0.200 train_loss 0.56135
| epoch 3 | 2100/2420 batches | train_acc 0.225 train_loss 0.54863
| epoch 3 | 2400/2420 batches | train_acc 0.216 train_loss 0.54849
---------------------------------------------------------------------
| epoch 3 | time: 84.61s | valid_acc 0.243 valid_loss 0.542 | lr 0.010000
---------------------------------------------------------------------
| epoch 4 |  300/2420 batches | train_acc 0.226 train_loss 0.55006
| epoch 4 |  600/2420 batches | train_acc 0.217 train_loss 0.55003
| epoch 4 |  900/2420 batches | train_acc 0.240 train_loss 0.54791
| epoch 4 | 1200/2420 batches | train_acc 0.253 train_loss 0.54397
| epoch 4 | 1500/2420 batches | train_acc 0.233 train_loss 0.54547
| epoch 4 | 1800/2420 batches | train_acc 0.233 train_loss 0.55099
| epoch 4 | 2100/2420 batches | train_acc 0.227 train_loss 0.54291
| epoch 4 | 2400/2420 batches | train_acc 0.223 train_loss 0.54319
---------------------------------------------------------------------
| epoch 4 | time: 84.61s | valid_acc 0.261 valid_loss 0.548 | lr 0.010000
---------------------------------------------------------------------
| epoch 5 |  300/2420 batches | train_acc 0.233 train_loss 0.54866
| epoch 5 |  600/2420 batches | train_acc 0.238 train_loss 0.53594
| epoch 5 |  900/2420 batches | train_acc 0.250 train_loss 0.53580
| epoch 5 | 1200/2420 batches | train_acc 0.236 train_loss 0.53639
| epoch 5 | 1500/2420 batches | train_acc 0.242 train_loss 0.53502
| epoch 5 | 1800/2420 batches | train_acc 0.229 train_loss 0.54436
| epoch 5 | 2100/2420 batches | train_acc 0.266 train_loss 0.52806
| epoch 5 | 2400/2420 batches | train_acc 0.241 train_loss 0.53787
---------------------------------------------------------------------
| epoch 5 | time: 84.84s | valid_acc 0.240 valid_loss 0.537 | lr 0.010000
---------------------------------------------------------------------
| epoch 6 |  300/2420 batches | train_acc 0.245 train_loss 0.53467
| epoch 6 |  600/2420 batches | train_acc 0.241 train_loss 0.53455
| epoch 6 |  900/2420 batches | train_acc 0.253 train_loss 0.52606
| epoch 6 | 1200/2420 batches | train_acc 0.250 train_loss 0.53322
| epoch 6 | 1500/2420 batches | train_acc 0.276 train_loss 0.52478
| epoch 6 | 1800/2420 batches | train_acc 0.251 train_loss 0.52935
| epoch 6 | 2100/2420 batches | train_acc 0.260 train_loss 0.52503
| epoch 6 | 2400/2420 batches | train_acc 0.239 train_loss 0.53837
---------------------------------------------------------------------
| epoch 6 | time: 84.49s | valid_acc 0.260 valid_loss 0.536 | lr 0.010000
---------------------------------------------------------------------
| epoch 7 |  300/2420 batches | train_acc 0.263 train_loss 0.51957
| epoch 7 |  600/2420 batches | train_acc 0.249 train_loss 0.52756
| epoch 7 |  900/2420 batches | train_acc 0.267 train_loss 0.52570
| epoch 7 | 1200/2420 batches | train_acc 0.266 train_loss 0.52688
| epoch 7 | 1500/2420 batches | train_acc 0.268 train_loss 0.52311
| epoch 7 | 1800/2420 batches | train_acc 0.238 train_loss 0.53244
| epoch 7 | 2100/2420 batches | train_acc 0.253 train_loss 0.52637
| epoch 7 | 2400/2420 batches | train_acc 0.273 train_loss 0.52541
---------------------------------------------------------------------
| epoch 7 | time: 84.32s | valid_acc 0.269 valid_loss 0.521 | lr 0.010000
---------------------------------------------------------------------
| epoch 8 |  300/2420 batches | train_acc 0.256 train_loss 0.52177
| epoch 8 |  600/2420 batches | train_acc 0.254 train_loss 0.52789
| epoch 8 |  900/2420 batches | train_acc 0.269 train_loss 0.52848
| epoch 8 | 1200/2420 batches | train_acc 0.284 train_loss 0.52198
| epoch 8 | 1500/2420 batches | train_acc 0.270 train_loss 0.51472
| epoch 8 | 1800/2420 batches | train_acc 0.261 train_loss 0.52358
| epoch 8 | 2100/2420 batches | train_acc 0.253 train_loss 0.52217
| epoch 8 | 2400/2420 batches | train_acc 0.270 train_loss 0.51121
---------------------------------------------------------------------
| epoch 8 | time: 86.19s | valid_acc 0.271 valid_loss 0.517 | lr 0.010000
---------------------------------------------------------------------
| epoch 9 |  300/2420 batches | train_acc 0.248 train_loss 0.51928
| epoch 9 |  600/2420 batches | train_acc 0.281 train_loss 0.51930
| epoch 9 |  900/2420 batches | train_acc 0.272 train_loss 0.51636
| epoch 9 | 1200/2420 batches | train_acc 0.270 train_loss 0.51573
| epoch 9 | 1500/2420 batches | train_acc 0.277 train_loss 0.51272
| epoch 9 | 1800/2420 batches | train_acc 0.268 train_loss 0.51255
| epoch 9 | 2100/2420 batches | train_acc 0.284 train_loss 0.51215
| epoch 9 | 2400/2420 batches | train_acc 0.271 train_loss 0.50947
---------------------------------------------------------------------
| epoch 9 | time: 84.34s | valid_acc 0.296 valid_loss 0.507 | lr 0.010000
---------------------------------------------------------------------
| epoch 10 |  300/2420 batches | train_acc 0.277 train_loss 0.50457
| epoch 10 |  600/2420 batches | train_acc 0.278 train_loss 0.51906
| epoch 10 |  900/2420 batches | train_acc 0.295 train_loss 0.50676
| epoch 10 | 1200/2420 batches | train_acc 0.302 train_loss 0.49880
| epoch 10 | 1500/2420 batches | train_acc 0.304 train_loss 0.50554
| epoch 10 | 1800/2420 batches | train_acc 0.314 train_loss 0.49315
| epoch 10 | 2100/2420 batches | train_acc 0.346 train_loss 0.48392
| epoch 10 | 2400/2420 batches | train_acc 0.338 train_loss 0.47602
---------------------------------------------------------------------
| epoch 10 | time: 84.58s | valid_acc 0.342 valid_loss 0.486 | lr 0.010000
---------------------------------------------------------------------
| epoch 11 |  300/2420 batches | train_acc 0.348 train_loss 0.47524
| epoch 11 |  600/2420 batches | train_acc 0.386 train_loss 0.47113
| epoch 11 |  900/2420 batches | train_acc 0.362 train_loss 0.46883
| epoch 11 | 1200/2420 batches | train_acc 0.383 train_loss 0.45975
| epoch 11 | 1500/2420 batches | train_acc 0.381 train_loss 0.46418
| epoch 11 | 1800/2420 batches | train_acc 0.380 train_loss 0.45996
| epoch 11 | 2100/2420 batches | train_acc 0.386 train_loss 0.46528
| epoch 11 | 2400/2420 batches | train_acc 0.410 train_loss 0.45139
---------------------------------------------------------------------
| epoch 11 | time: 83.80s | valid_acc 0.410 valid_loss 0.445 | lr 0.010000
---------------------------------------------------------------------
| epoch 12 |  300/2420 batches | train_acc 0.415 train_loss 0.44190
| epoch 12 |  600/2420 batches | train_acc 0.401 train_loss 0.44386
| epoch 12 |  900/2420 batches | train_acc 0.428 train_loss 0.43816
| epoch 12 | 1200/2420 batches | train_acc 0.427 train_loss 0.43997
| epoch 12 | 1500/2420 batches | train_acc 0.426 train_loss 0.43172
| epoch 12 | 1800/2420 batches | train_acc 0.409 train_loss 0.44360
| epoch 12 | 2100/2420 batches | train_acc 0.439 train_loss 0.43052
| epoch 12 | 2400/2420 batches | train_acc 0.446 train_loss 0.43177
---------------------------------------------------------------------
| epoch 12 | time: 83.92s | valid_acc 0.458 valid_loss 0.420 | lr 0.010000
---------------------------------------------------------------------
| epoch 13 |  300/2420 batches | train_acc 0.463 train_loss 0.41759
| epoch 13 |  600/2420 batches | train_acc 0.474 train_loss 0.40258
| epoch 13 |  900/2420 batches | train_acc 0.463 train_loss 0.41299
| epoch 13 | 1200/2420 batches | train_acc 0.476 train_loss 0.41546
| epoch 13 | 1500/2420 batches | train_acc 0.464 train_loss 0.40948
| epoch 13 | 1800/2420 batches | train_acc 0.468 train_loss 0.41841
| epoch 13 | 2100/2420 batches | train_acc 0.490 train_loss 0.39884
| epoch 13 | 2400/2420 batches | train_acc 0.487 train_loss 0.40473
---------------------------------------------------------------------
| epoch 13 | time: 83.64s | valid_acc 0.500 valid_loss 0.402 | lr 0.010000
---------------------------------------------------------------------
| epoch 14 |  300/2420 batches | train_acc 0.515 train_loss 0.38339
| epoch 14 |  600/2420 batches | train_acc 0.519 train_loss 0.37501
| epoch 14 |  900/2420 batches | train_acc 0.517 train_loss 0.38280
| epoch 14 | 1200/2420 batches | train_acc 0.517 train_loss 0.38195
| epoch 14 | 1500/2420 batches | train_acc 0.525 train_loss 0.39008
| epoch 14 | 1800/2420 batches | train_acc 0.525 train_loss 0.38158
| epoch 14 | 2100/2420 batches | train_acc 0.525 train_loss 0.38559
| epoch 14 | 2400/2420 batches | train_acc 0.529 train_loss 0.38228
---------------------------------------------------------------------
| epoch 14 | time: 86.00s | valid_acc 0.566 valid_loss 0.366 | lr 0.010000
---------------------------------------------------------------------
| epoch 15 |  300/2420 batches | train_acc 0.553 train_loss 0.37311
| epoch 15 |  600/2420 batches | train_acc 0.564 train_loss 0.35000
| epoch 15 |  900/2420 batches | train_acc 0.558 train_loss 0.36826
| epoch 15 | 1200/2420 batches | train_acc 0.547 train_loss 0.36430
| epoch 15 | 1500/2420 batches | train_acc 0.564 train_loss 0.36633
| epoch 15 | 1800/2420 batches | train_acc 0.561 train_loss 0.35983
| epoch 15 | 2100/2420 batches | train_acc 0.593 train_loss 0.33570
| epoch 15 | 2400/2420 batches | train_acc 0.594 train_loss 0.33010
---------------------------------------------------------------------
| epoch 15 | time: 84.05s | valid_acc 0.589 valid_loss 0.342 | lr 0.010000
---------------------------------------------------------------------