Week TR5: Transformer in Practice: Text Classification

>- **🍨 This post is a learning-record blog from the [🔗365天深度学习训练营] program**
>- **🍖 Original author: [K同学啊]**

My previous posts can be found in: Deep Learning Summary

 This week's tasks:

  • Understand the code logic in this post and run it successfully
  • Tune the code based on your own understanding to bring the accuracy up to 70%

1. Preparation

🏡 My environment:

  • Language: Python 3.11
  • IDE: PyCharm
  • Deep learning framework: PyTorch
    • torch==2.0.0+cu118
    • torchvision==0.15.0+cu118
  • GPU: NVIDIA GeForce GTX 1660

1.1. Environment Setup

       This is a hands-on example of simple text classification implemented in PyTorch with a Transformer.

import torch,torchvision
print(torch.__version__)
print(torchvision.__version__)

Output:

2.0.0+cu118
0.15.0+cu118

1.2. Loading the Data

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Output:

device(type='cuda')

Import the data:

import pandas as pd

# Load the custom Chinese dataset
train_data=pd.read_csv(r'E:\DATABASE\TR-series\TR5\train.csv',sep='\t',header=None)
train_data.head()

Output:

# Build a dataset iterator
def custom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y
        
train_iter=custom_data_iter(train_data[0].values[:],train_data[1].values[:])

2. Data Preprocessing

2.1. Building the Vocabulary

The jieba word-segmentation library needs to be installed; the install command is as follows:

cmd: pip install jieba -i https://pypi.tuna.tsinghua.edu.cn/simple

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

# Chinese word-segmentation method
tokenizer=jieba.lcut

def yield_token(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)
        
vocab=build_vocab_from_iterator(yield_token(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # set the default index; it is returned for any token not found in the vocabulary

Output:

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\cyb\AppData\Local\Temp\jieba.cache
Loading model cost 0.639 seconds.
Prefix dict has been built successfully.

vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

Output:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]

label_name=list(set(train_data[1].values[:]))
print(label_name)

Output:

['Video-Play', 'HomeAppliance-Control', 'Travel-Query', 'Audio-Play', 'Other', 'Calendar-Query', 'Music-Play', 'Alarm-Update', 'Radio-Listen', 'Weather-Query', 'TVProgram-Play', 'FilmTele-Play']

text_pipeline=lambda x:vocab(tokenizer(x))
label_pipeline=lambda x:label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

Output:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
0

2.2. Generating Batches and Iterators

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list,text_list,offsets=[],[],[0]
    
    for (_text,_label) in batch:
        # collect the labels
        label_list.append(label_pipeline(_label))
        
        # collect the tokenized texts
        processed_text=torch.tensor(text_pipeline(_text),dtype=torch.int64)
        text_list.append(processed_text)
        
        # record each text's token count; used below to build the offsets
        offsets.append(processed_text.size(0))
        
    label_list=torch.tensor(label_list,dtype=torch.int64)
    text_list=torch.cat(text_list)
    offsets=torch.tensor(offsets[:-1]).cumsum(dim=0) # cumulative sum along dim 0: the start index of each text
    
    return text_list.to(device),label_list.to(device),offsets.to(device)
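
A worked sketch of the offsets logic, with hypothetical lengths: suppose a batch holds three texts of 3, 5, and 2 tokens. Their token ids are concatenated into one flat 1-D tensor of 10 ids, and offsets records where each text starts, which is exactly the form nn.EmbeddingBag consumes:

# lengths appended after the initial 0 -> [0, 3, 5, 2]
# offsets[:-1]                         -> [0, 3, 5]
# .cumsum(dim=0)                       -> [0, 3, 8]   start index of each text in the flat tensor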

2.3. Building the Dataset

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

BATCH_SIZE=4

train_iter=custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset=to_map_style_dataset(train_iter)

# 80/20 train/validation split; the second length is derived so the two always sum to len(train_dataset)
num_train=int(len(train_dataset)*0.8)
split_train_,split_valid_=random_split(
    train_dataset,
    [num_train,len(train_dataset)-num_train])

train_dataloader=DataLoader(split_train_,batch_size=BATCH_SIZE,
                            shuffle=True,collate_fn=collate_batch)
valid_dataloader=DataLoader(split_valid_,batch_size=BATCH_SIZE,
                            shuffle=True,collate_fn=collate_batch)
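
A quick way to sanity-check the pipeline is to pull a single batch from the loader; this sketch only prints shapes, so it makes no assumptions about the data itself:

# text is a flat 1-D tensor of token ids for the whole batch,
# label holds BATCH_SIZE class indices, offsets holds BATCH_SIZE start positions
text_batch,label_batch,offset_batch=next(iter(train_dataloader))
print(text_batch.shape,label_batch.shape,offset_batch.shape)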

The to_map_style_dataset() function

       converts an iterable-style dataset into a map-style dataset. This conversion lets us access elements of the dataset conveniently by index (e.g., an integer). In PyTorch, datasets come in two types, iterable-style and map-style (see the sketch after this list):

  • An iterable-style dataset implements the __iter__() method: its elements can be accessed by iteration, but not by index.
  • A map-style dataset implements the __getitem__() and __len__() methods: specific elements can be accessed directly by index, and the dataset's size can be queried.
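
A minimal sketch of the difference, probing the map-style train_dataset built above:

print(len(train_dataset))  # __len__: total number of samples
print(train_dataset[0])    # __getitem__: the (text, label) pair at index 0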

3. Building the Model

3.1. Defining the Positional Encoding

import math,torch

class PositionalEncoding(nn.Module):
    def __init__(self,embed_dim,max_len=500):
        super(PositionalEncoding,self).__init__()
        
        # create a zero tensor of shape [max_len, embed_dim]
        pe=torch.zeros(max_len,embed_dim)
        # create a position-index tensor of shape [max_len, 1]
        position=torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
        
        # frequency term 1/10000^(2i/embed_dim), the constant from the original Transformer paper
        div_term=torch.exp(torch.arange(0,embed_dim,2).float()*(-math.log(10000.0)/embed_dim))
        
        pe[:,0::2]=torch.sin(position*div_term) # PE(pos,2i)
        pe[:,1::2]=torch.cos(position*div_term) # PE(pos,2i+1)
        pe=pe.unsqueeze(0).transpose(0,1) # final shape [max_len, 1, embed_dim]
        
        # register the encoding as a buffer: it is saved with the model but excluded from gradient updates
        self.register_buffer('pe',pe)
        
    def forward(self,x):
        # add the positional encoding to the input; note how the shapes broadcast
        x=x+self.pe[:x.size(0)]
        return x
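
A small shape check of the module (toy dimensions chosen purely for illustration): the stored pe buffer has shape [max_len, 1, embed_dim], so an input laid out as [seq_len, batch, embed_dim] passes through with its shape unchanged:

pos_enc=PositionalEncoding(embed_dim=64)
dummy=torch.zeros(10,1,64)    # [seq_len, batch, embed_dim]
print(pos_enc(dummy).shape)   # torch.Size([10, 1, 64])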

3.2. Defining the Transformer Model

from torch import nn
from torch.nn import TransformerEncoder,TransformerEncoderLayer

class TransformerModel(nn.Module):
    
    def __init__(self,vocab_size,embed_dim,num_class,nhead=8,
                 d_hid=256,nlayers=12,dropout=0.1):
        super().__init__()
        self.embedding=nn.EmbeddingBag(vocab_size, # vocabulary size
                                       embed_dim, # embedding dimension
                                       sparse=False)
        self.pos_encoder=PositionalEncoding(embed_dim)
        
        # define the encoder layers
        encoder_layers=TransformerEncoderLayer(embed_dim,nhead,d_hid,dropout)
        self.transformer_encoder=TransformerEncoder(encoder_layers,nlayers)
        self.embed_dim=embed_dim
        # note: the *4 ties this layer to a batch size of exactly 4 (see forward)
        self.linear=nn.Linear(embed_dim*4,num_class)
        
    def forward(self,src,offsets,src_mask=None):
        src=self.embedding(src,offsets)   # [batch, embed_dim]
        src=self.pos_encoder(src)         # broadcasts against pe to [batch, batch, embed_dim]
        output=self.transformer_encoder(src,src_mask)
        
        # flatten; with BATCH_SIZE=4 this yields [4, embed_dim*4], matching self.linear
        output=output.reshape(output.size(0),-1)
        output=self.linear(output)
        
        return output
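
A forward-pass smoke test with toy sizes (the values below are illustrative assumptions): as written, the model only accepts batches of exactly 4, since the reshape must produce embed_dim*4 features for the linear layer:

toy_model=TransformerModel(vocab_size=100,embed_dim=64,num_class=12)
toy_text=torch.randint(0,100,(20,))    # 20 token ids, flattened across the batch
toy_offsets=torch.tensor([0,5,11,16])  # 4 sequences -> batch size 4
print(toy_model(toy_text,toy_offsets).shape)  # torch.Size([4, 12])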

3.3. Initializing the Model

vocab_size=len(vocab) # vocabulary size
embed_dim=64 # embedding dimension
num_class=len(label_name) # number of classes

# create the Transformer model and move it to the device
model=TransformerModel(vocab_size,embed_dim,num_class).to(device)
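
An optional sanity check, counting the trainable parameters (a sketch; the exact number depends on the vocabulary size):

num_params=sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable parameters: {:,}'.format(num_params))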

3.4. Defining the Training Function

import time

def train(dataloader):
    model.train()  # switch to training mode
    total_acc,train_loss,total_count=0,0,0
    log_interval=300
    start_time=time.time()
    
    for idx,(text,label,offsets) in enumerate(dataloader):
        predicted_label=model(text,offsets)
        optimizer.zero_grad()  # zero out the gradients
        
        loss=criterion(predicted_label,label) # compute the loss between predictions and ground-truth labels
        loss.backward() # backpropagation
        optimizer.step() # update the parameters
        
        # accumulate accuracy and loss
        total_acc+=(predicted_label.argmax(1)==label).sum().item()
        train_loss+=loss.item()
        total_count+=label.size(0)
        
        if idx % log_interval==0 and idx>0:
            elapsed=time.time()-start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:5.2f}  train_loss {:4.5f}'.format(epoch,idx,len(dataloader),
                                                                   total_acc/total_count*100,
                                                                   train_loss/total_count))
            total_acc,train_loss,total_count=0,0,0
            start_time=time.time()

3.5. Defining the Evaluation Function

def evaluate(dataloader):
    model.eval()  # switch to evaluation mode
    total_acc,val_loss,total_count=0,0,0
    
    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            predicted_label=model(text,offsets)
            
            loss=criterion(predicted_label,label) # compute the loss
            # accumulate evaluation metrics
            total_acc+=(predicted_label.argmax(1)==label).sum().item()
            val_loss+=loss.item()
            total_count+=label.size(0)
            
    return total_acc/total_count,val_loss/total_count

4. Training the Model

4.1. Model Training

# Hyperparameters
EPOCHS=10

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=1e-2)

for epoch in range(1,EPOCHS+1):
    epoch_start_time=time.time()
    train(train_dataloader)
    val_acc,val_loss=evaluate(valid_dataloader)
    
    # get the current learning rate
    lr=optimizer.state_dict()['param_groups'][0]['lr']
    
    print('-'*69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,
                                                                     time.time()-epoch_start_time,
                                                                     val_acc,val_loss,lr))
    print('-'*69)

Output:

| epoch 1 |  300/2420 batches | train_acc 10.05  train_loss 0.63597
| epoch 1 |  600/2420 batches | train_acc 10.75  train_loss 0.61639
| epoch 1 |  900/2420 batches | train_acc 11.75  train_loss 0.61587
| epoch 1 | 1200/2420 batches | train_acc 11.75  train_loss 0.60497
| epoch 1 | 1500/2420 batches | train_acc 11.75  train_loss 0.59741
| epoch 1 | 1800/2420 batches | train_acc 13.17  train_loss 0.58865
| epoch 1 | 2100/2420 batches | train_acc 14.33  train_loss 0.58118
| epoch 1 | 2400/2420 batches | train_acc 16.92  train_loss 0.58660
---------------------------------------------------------------------
| epoch 1 | time: 72.37s | valid_acc 0.164 valid_loss 0.579 | lr 0.010000
---------------------------------------------------------------------
| epoch 2 |  300/2420 batches | train_acc 17.36  train_loss 0.57580
| epoch 2 |  600/2420 batches | train_acc 19.58  train_loss 0.56323
| epoch 2 |  900/2420 batches | train_acc 16.17  train_loss 0.57542
| epoch 2 | 1200/2420 batches | train_acc 17.08  train_loss 0.57320
| epoch 2 | 1500/2420 batches | train_acc 19.33  train_loss 0.56107
| epoch 2 | 1800/2420 batches | train_acc 18.58  train_loss 0.56723
| epoch 2 | 2100/2420 batches | train_acc 19.33  train_loss 0.56513
| epoch 2 | 2400/2420 batches | train_acc 19.42  train_loss 0.56916
---------------------------------------------------------------------
| epoch 2 | time: 72.19s | valid_acc 0.206 valid_loss 0.560 | lr 0.010000
---------------------------------------------------------------------
| epoch 3 |  300/2420 batches | train_acc 21.51  train_loss 0.55742
| epoch 3 |  600/2420 batches | train_acc 20.08  train_loss 0.56014
| epoch 3 |  900/2420 batches | train_acc 18.00  train_loss 0.56565
| epoch 3 | 1200/2420 batches | train_acc 20.50  train_loss 0.55969
| epoch 3 | 1500/2420 batches | train_acc 22.33  train_loss 0.55079
| epoch 3 | 1800/2420 batches | train_acc 21.42  train_loss 0.55814
| epoch 3 | 2100/2420 batches | train_acc 23.58  train_loss 0.54727
| epoch 3 | 2400/2420 batches | train_acc 20.33  train_loss 0.55949
---------------------------------------------------------------------
| epoch 3 | time: 72.61s | valid_acc 0.213 valid_loss 0.561 | lr 0.010000
---------------------------------------------------------------------
| epoch 4 |  300/2420 batches | train_acc 22.43  train_loss 0.54979
| epoch 4 |  600/2420 batches | train_acc 22.75  train_loss 0.54472
| epoch 4 |  900/2420 batches | train_acc 22.67  train_loss 0.54405
| epoch 4 | 1200/2420 batches | train_acc 22.67  train_loss 0.55083
| epoch 4 | 1500/2420 batches | train_acc 22.67  train_loss 0.54703
| epoch 4 | 1800/2420 batches | train_acc 24.42  train_loss 0.53693
| epoch 4 | 2100/2420 batches | train_acc 23.58  train_loss 0.55088
| epoch 4 | 2400/2420 batches | train_acc 24.08  train_loss 0.54040
---------------------------------------------------------------------
| epoch 4 | time: 73.21s | valid_acc 0.255 valid_loss 0.533 | lr 0.010000
---------------------------------------------------------------------
| epoch 5 |  300/2420 batches | train_acc 25.33  train_loss 0.53609
| epoch 5 |  600/2420 batches | train_acc 24.00  train_loss 0.53878
| epoch 5 |  900/2420 batches | train_acc 25.58  train_loss 0.54059
| epoch 5 | 1200/2420 batches | train_acc 26.08  train_loss 0.52828
| epoch 5 | 1500/2420 batches | train_acc 24.50  train_loss 0.52981
| epoch 5 | 1800/2420 batches | train_acc 24.08  train_loss 0.54025
| epoch 5 | 2100/2420 batches | train_acc 25.50  train_loss 0.54025
| epoch 5 | 2400/2420 batches | train_acc 27.50  train_loss 0.52163
---------------------------------------------------------------------
| epoch 5 | time: 72.70s | valid_acc 0.261 valid_loss 0.530 | lr 0.010000
---------------------------------------------------------------------
| epoch 6 |  300/2420 batches | train_acc 27.33  train_loss 0.52560
| epoch 6 |  600/2420 batches | train_acc 27.92  train_loss 0.52092
| epoch 6 |  900/2420 batches | train_acc 29.50  train_loss 0.50868
| epoch 6 | 1200/2420 batches | train_acc 30.42  train_loss 0.51428
| epoch 6 | 1500/2420 batches | train_acc 30.08  train_loss 0.51098
| epoch 6 | 1800/2420 batches | train_acc 31.75  train_loss 0.50787
| epoch 6 | 2100/2420 batches | train_acc 30.50  train_loss 0.50646
| epoch 6 | 2400/2420 batches | train_acc 31.08  train_loss 0.50609
---------------------------------------------------------------------
| epoch 6 | time: 71.89s | valid_acc 0.319 valid_loss 0.502 | lr 0.010000
---------------------------------------------------------------------
| epoch 7 |  300/2420 batches | train_acc 32.81  train_loss 0.49740
| epoch 7 |  600/2420 batches | train_acc 36.50  train_loss 0.48570
| epoch 7 |  900/2420 batches | train_acc 32.17  train_loss 0.50369
| epoch 7 | 1200/2420 batches | train_acc 35.83  train_loss 0.48545
| epoch 7 | 1500/2420 batches | train_acc 35.83  train_loss 0.47964
| epoch 7 | 1800/2420 batches | train_acc 35.00  train_loss 0.47124
| epoch 7 | 2100/2420 batches | train_acc 37.08  train_loss 0.47590
| epoch 7 | 2400/2420 batches | train_acc 37.92  train_loss 0.47795
---------------------------------------------------------------------
| epoch 7 | time: 73.88s | valid_acc 0.382 valid_loss 0.465 | lr 0.010000
---------------------------------------------------------------------
| epoch 8 |  300/2420 batches | train_acc 39.12  train_loss 0.47037
| epoch 8 |  600/2420 batches | train_acc 38.00  train_loss 0.46415
| epoch 8 |  900/2420 batches | train_acc 39.00  train_loss 0.45649
| epoch 8 | 1200/2420 batches | train_acc 39.92  train_loss 0.45573
| epoch 8 | 1500/2420 batches | train_acc 42.25  train_loss 0.44419
| epoch 8 | 1800/2420 batches | train_acc 40.08  train_loss 0.45657
| epoch 8 | 2100/2420 batches | train_acc 40.83  train_loss 0.45417
| epoch 8 | 2400/2420 batches | train_acc 41.25  train_loss 0.44707
---------------------------------------------------------------------
| epoch 8 | time: 73.51s | valid_acc 0.421 valid_loss 0.435 | lr 0.010000
---------------------------------------------------------------------
| epoch 9 |  300/2420 batches | train_acc 45.43  train_loss 0.43704
| epoch 9 |  600/2420 batches | train_acc 43.08  train_loss 0.43955
| epoch 9 |  900/2420 batches | train_acc 45.00  train_loss 0.42919
| epoch 9 | 1200/2420 batches | train_acc 44.08  train_loss 0.43416
| epoch 9 | 1500/2420 batches | train_acc 45.33  train_loss 0.41806
| epoch 9 | 1800/2420 batches | train_acc 44.83  train_loss 0.42926
| epoch 9 | 2100/2420 batches | train_acc 48.17  train_loss 0.41192
| epoch 9 | 2400/2420 batches | train_acc 49.42  train_loss 0.40714
---------------------------------------------------------------------
| epoch 9 | time: 73.83s | valid_acc 0.498 valid_loss 0.392 | lr 0.010000
---------------------------------------------------------------------
| epoch 10 |  300/2420 batches | train_acc 50.75  train_loss 0.39862
| epoch 10 |  600/2420 batches | train_acc 50.00  train_loss 0.40449
| epoch 10 |  900/2420 batches | train_acc 50.25  train_loss 0.40058
| epoch 10 | 1200/2420 batches | train_acc 50.58  train_loss 0.40011
| epoch 10 | 1500/2420 batches | train_acc 53.83  train_loss 0.38994
| epoch 10 | 1800/2420 batches | train_acc 53.92  train_loss 0.37948
| epoch 10 | 2100/2420 batches | train_acc 52.75  train_loss 0.37443
| epoch 10 | 2400/2420 batches | train_acc 53.67  train_loss 0.38868
---------------------------------------------------------------------
| epoch 10 | time: 74.10s | valid_acc 0.558 valid_loss 0.367 | lr 0.010000
---------------------------------------------------------------------

4.2. Model Evaluation

test_acc,test_loss=evaluate(valid_dataloader)
print('Model accuracy: {:5.4f}'.format(test_acc))

Output:

Model accuracy: 0.5541

4.3. Improving Accuracy

       Since the model was trained for only 10 epochs, accuracy stood at just 55.8%, so training was continued for another 30 epochs. Accuracy improved to 75% and the validation loss fell to 0.228. Due to time constraints, no further training was carried out.
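
A sketch of how the extra epochs were run: since model and optimizer keep their state, rerunning the training loop from section 4.1 with EPOCHS=30 simply continues from the weights reached after epoch 10 (note that the epoch counter restarts at 1 in the logs below):

EPOCHS=30  # additional epochs; continues from the previous run's state

for epoch in range(1,EPOCHS+1):
    train(train_dataloader)
    val_acc,val_loss=evaluate(valid_dataloader)
    print('| epoch {:1d} | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,val_acc,val_loss))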

| epoch 1 |  300/2420 batches | train_acc 54.49  train_loss 0.36607
| epoch 1 |  600/2420 batches | train_acc 54.67  train_loss 0.36844
| epoch 1 |  900/2420 batches | train_acc 54.92  train_loss 0.36920
| epoch 1 | 1200/2420 batches | train_acc 55.58  train_loss 0.37345
| epoch 1 | 1500/2420 batches | train_acc 57.75  train_loss 0.35501
| epoch 1 | 1800/2420 batches | train_acc 56.17  train_loss 0.37109
| epoch 1 | 2100/2420 batches | train_acc 55.67  train_loss 0.35130
| epoch 1 | 2400/2420 batches | train_acc 59.08  train_loss 0.34572
---------------------------------------------------------------------
| epoch 1 | time: 73.36s | valid_acc 0.566 valid_loss 0.368 | lr 0.010000
---------------------------------------------------------------------
| epoch 2 |  300/2420 batches | train_acc 56.73  train_loss 0.35899
| epoch 2 |  600/2420 batches | train_acc 57.83  train_loss 0.35252
| epoch 2 |  900/2420 batches | train_acc 60.42  train_loss 0.32902
| epoch 2 | 1200/2420 batches | train_acc 59.92  train_loss 0.33609
| epoch 2 | 1500/2420 batches | train_acc 62.17  train_loss 0.33100
| epoch 2 | 1800/2420 batches | train_acc 62.33  train_loss 0.31901
| epoch 2 | 2100/2420 batches | train_acc 62.08  train_loss 0.33091
| epoch 2 | 2400/2420 batches | train_acc 58.33  train_loss 0.33574
---------------------------------------------------------------------
| epoch 2 | time: 73.68s | valid_acc 0.624 valid_loss 0.314 | lr 0.010000
---------------------------------------------------------------------
| epoch 3 |  300/2420 batches | train_acc 63.37  train_loss 0.31629
| epoch 3 |  600/2420 batches | train_acc 60.75  train_loss 0.32827
| epoch 3 |  900/2420 batches | train_acc 61.92  train_loss 0.31922
| epoch 3 | 1200/2420 batches | train_acc 64.50  train_loss 0.31255
| epoch 3 | 1500/2420 batches | train_acc 64.42  train_loss 0.31176
| epoch 3 | 1800/2420 batches | train_acc 65.42  train_loss 0.29570
| epoch 3 | 2100/2420 batches | train_acc 62.50  train_loss 0.30939
| epoch 3 | 2400/2420 batches | train_acc 64.00  train_loss 0.30784
---------------------------------------------------------------------
| epoch 3 | time: 72.58s | valid_acc 0.663 valid_loss 0.291 | lr 0.010000
---------------------------------------------------------------------
| epoch 4 |  300/2420 batches | train_acc 66.36  train_loss 0.28185
| epoch 4 |  600/2420 batches | train_acc 63.83  train_loss 0.30157
| epoch 4 |  900/2420 batches | train_acc 67.17  train_loss 0.28690
| epoch 4 | 1200/2420 batches | train_acc 64.42  train_loss 0.29970
| epoch 4 | 1500/2420 batches | train_acc 64.17  train_loss 0.30939
| epoch 4 | 1800/2420 batches | train_acc 68.25  train_loss 0.28026
| epoch 4 | 2100/2420 batches | train_acc 65.83  train_loss 0.29020
| epoch 4 | 2400/2420 batches | train_acc 66.75  train_loss 0.28798
---------------------------------------------------------------------
| epoch 4 | time: 72.30s | valid_acc 0.636 valid_loss 0.306 | lr 0.010000
---------------------------------------------------------------------
| epoch 5 |  300/2420 batches | train_acc 67.03  train_loss 0.28070
| epoch 5 |  600/2420 batches | train_acc 65.92  train_loss 0.28669
| epoch 5 |  900/2420 batches | train_acc 67.08  train_loss 0.27294
| epoch 5 | 1200/2420 batches | train_acc 70.08  train_loss 0.26934
| epoch 5 | 1500/2420 batches | train_acc 67.75  train_loss 0.27605
| epoch 5 | 1800/2420 batches | train_acc 68.17  train_loss 0.26254
| epoch 5 | 2100/2420 batches | train_acc 66.08  train_loss 0.28709
| epoch 5 | 2400/2420 batches | train_acc 68.08  train_loss 0.26833
---------------------------------------------------------------------
| epoch 5 | time: 72.30s | valid_acc 0.633 valid_loss 0.319 | lr 0.010000
---------------------------------------------------------------------
| epoch 6 |  300/2420 batches | train_acc 67.94  train_loss 0.26281
| epoch 6 |  600/2420 batches | train_acc 69.92  train_loss 0.25797
| epoch 6 |  900/2420 batches | train_acc 70.42  train_loss 0.26665
| epoch 6 | 1200/2420 batches | train_acc 69.25  train_loss 0.26434
| epoch 6 | 1500/2420 batches | train_acc 70.00  train_loss 0.26321
| epoch 6 | 1800/2420 batches | train_acc 69.50  train_loss 0.26537
| epoch 6 | 2100/2420 batches | train_acc 71.83  train_loss 0.25851
| epoch 6 | 2400/2420 batches | train_acc 69.83  train_loss 0.26199
---------------------------------------------------------------------
| epoch 6 | time: 72.17s | valid_acc 0.689 valid_loss 0.267 | lr 0.010000
---------------------------------------------------------------------
| epoch 7 |  300/2420 batches | train_acc 71.76  train_loss 0.25320
| epoch 7 |  600/2420 batches | train_acc 72.00  train_loss 0.24158
| epoch 7 |  900/2420 batches | train_acc 70.25  train_loss 0.26129
| epoch 7 | 1200/2420 batches | train_acc 69.83  train_loss 0.24659
| epoch 7 | 1500/2420 batches | train_acc 69.67  train_loss 0.25484
| epoch 7 | 1800/2420 batches | train_acc 71.25  train_loss 0.24846
| epoch 7 | 2100/2420 batches | train_acc 71.42  train_loss 0.24982
| epoch 7 | 2400/2420 batches | train_acc 70.67  train_loss 0.25304
---------------------------------------------------------------------
| epoch 7 | time: 71.04s | valid_acc 0.675 valid_loss 0.280 | lr 0.010000
---------------------------------------------------------------------
| epoch 8 |  300/2420 batches | train_acc 73.26  train_loss 0.22670
| epoch 8 |  600/2420 batches | train_acc 72.83  train_loss 0.23569
| epoch 8 |  900/2420 batches | train_acc 74.00  train_loss 0.23236
| epoch 8 | 1200/2420 batches | train_acc 73.42  train_loss 0.23499
| epoch 8 | 1500/2420 batches | train_acc 71.67  train_loss 0.24370
| epoch 8 | 1800/2420 batches | train_acc 74.33  train_loss 0.22888
| epoch 8 | 2100/2420 batches | train_acc 73.67  train_loss 0.23660
| epoch 8 | 2400/2420 batches | train_acc 72.50  train_loss 0.24079
---------------------------------------------------------------------
| epoch 8 | time: 70.63s | valid_acc 0.680 valid_loss 0.275 | lr 0.010000
---------------------------------------------------------------------
| epoch 9 |  300/2420 batches | train_acc 74.17  train_loss 0.22665
| epoch 9 |  600/2420 batches | train_acc 73.67  train_loss 0.23029
| epoch 9 |  900/2420 batches | train_acc 74.17  train_loss 0.22595
| epoch 9 | 1200/2420 batches | train_acc 73.58  train_loss 0.23056
| epoch 9 | 1500/2420 batches | train_acc 73.50  train_loss 0.23499
| epoch 9 | 1800/2420 batches | train_acc 76.00  train_loss 0.22081
| epoch 9 | 2100/2420 batches | train_acc 72.92  train_loss 0.23259
| epoch 9 | 2400/2420 batches | train_acc 72.08  train_loss 0.22822
---------------------------------------------------------------------
| epoch 9 | time: 71.16s | valid_acc 0.698 valid_loss 0.253 | lr 0.010000
---------------------------------------------------------------------
| epoch 10 |  300/2420 batches | train_acc 74.00  train_loss 0.22515
| epoch 10 |  600/2420 batches | train_acc 76.25  train_loss 0.21451
| epoch 10 |  900/2420 batches | train_acc 74.58  train_loss 0.21628
| epoch 10 | 1200/2420 batches | train_acc 75.83  train_loss 0.20905
| epoch 10 | 1500/2420 batches | train_acc 74.42  train_loss 0.22759
| epoch 10 | 1800/2420 batches | train_acc 75.42  train_loss 0.20945
| epoch 10 | 2100/2420 batches | train_acc 75.50  train_loss 0.21331
| epoch 10 | 2400/2420 batches | train_acc 72.25  train_loss 0.23877
---------------------------------------------------------------------
| epoch 10 | time: 70.13s | valid_acc 0.725 valid_loss 0.241 | lr 0.010000
---------------------------------------------------------------------
| epoch 11 |  300/2420 batches | train_acc 74.75  train_loss 0.21838
| epoch 11 |  600/2420 batches | train_acc 78.08  train_loss 0.20571
| epoch 11 |  900/2420 batches | train_acc 76.42  train_loss 0.20334
| epoch 11 | 1200/2420 batches | train_acc 77.17  train_loss 0.20223
| epoch 11 | 1500/2420 batches | train_acc 74.17  train_loss 0.22616
| epoch 11 | 1800/2420 batches | train_acc 74.75  train_loss 0.21948
| epoch 11 | 2100/2420 batches | train_acc 78.42  train_loss 0.19643
| epoch 11 | 2400/2420 batches | train_acc 76.25  train_loss 0.20678
---------------------------------------------------------------------
| epoch 11 | time: 69.36s | valid_acc 0.648 valid_loss 0.325 | lr 0.010000
---------------------------------------------------------------------
| epoch 12 |  300/2420 batches | train_acc 75.33  train_loss 0.20955
| epoch 12 |  600/2420 batches | train_acc 75.92  train_loss 0.20972
| epoch 12 |  900/2420 batches | train_acc 75.92  train_loss 0.20551
| epoch 12 | 1200/2420 batches | train_acc 77.33  train_loss 0.19269
| epoch 12 | 1500/2420 batches | train_acc 76.58  train_loss 0.19830
| epoch 12 | 1800/2420 batches | train_acc 78.33  train_loss 0.19465
| epoch 12 | 2100/2420 batches | train_acc 76.92  train_loss 0.19649
| epoch 12 | 2400/2420 batches | train_acc 77.17  train_loss 0.20489
---------------------------------------------------------------------
| epoch 12 | time: 69.44s | valid_acc 0.714 valid_loss 0.247 | lr 0.010000
---------------------------------------------------------------------
| epoch 13 |  300/2420 batches | train_acc 79.90  train_loss 0.18461
| epoch 13 |  600/2420 batches | train_acc 77.50  train_loss 0.19824
| epoch 13 |  900/2420 batches | train_acc 80.42  train_loss 0.18300
| epoch 13 | 1200/2420 batches | train_acc 76.08  train_loss 0.20041
| epoch 13 | 1500/2420 batches | train_acc 76.58  train_loss 0.19852
| epoch 13 | 1800/2420 batches | train_acc 76.67  train_loss 0.20634
| epoch 13 | 2100/2420 batches | train_acc 77.50  train_loss 0.20142
| epoch 13 | 2400/2420 batches | train_acc 75.50  train_loss 0.21208
---------------------------------------------------------------------
| epoch 13 | time: 69.42s | valid_acc 0.717 valid_loss 0.242 | lr 0.010000
---------------------------------------------------------------------
| epoch 14 |  300/2420 batches | train_acc 79.90  train_loss 0.17942
| epoch 14 |  600/2420 batches | train_acc 77.83  train_loss 0.19173
| epoch 14 |  900/2420 batches | train_acc 78.67  train_loss 0.18850
| epoch 14 | 1200/2420 batches | train_acc 80.08  train_loss 0.17583
| epoch 14 | 1500/2420 batches | train_acc 77.42  train_loss 0.19312
| epoch 14 | 1800/2420 batches | train_acc 78.50  train_loss 0.18778
| epoch 14 | 2100/2420 batches | train_acc 77.67  train_loss 0.19237
| epoch 14 | 2400/2420 batches | train_acc 77.25  train_loss 0.19629
---------------------------------------------------------------------
| epoch 14 | time: 69.33s | valid_acc 0.737 valid_loss 0.248 | lr 0.010000
---------------------------------------------------------------------
| epoch 15 |  300/2420 batches | train_acc 80.07  train_loss 0.17568
| epoch 15 |  600/2420 batches | train_acc 77.92  train_loss 0.18195
| epoch 15 |  900/2420 batches | train_acc 78.92  train_loss 0.18959
| epoch 15 | 1200/2420 batches | train_acc 77.83  train_loss 0.18873
| epoch 15 | 1500/2420 batches | train_acc 79.08  train_loss 0.17979
| epoch 15 | 1800/2420 batches | train_acc 80.17  train_loss 0.18310
| epoch 15 | 2100/2420 batches | train_acc 79.83  train_loss 0.17387
| epoch 15 | 2400/2420 batches | train_acc 77.50  train_loss 0.19704
---------------------------------------------------------------------
| epoch 15 | time: 69.73s | valid_acc 0.730 valid_loss 0.239 | lr 0.010000
---------------------------------------------------------------------
| epoch 16 |  300/2420 batches | train_acc 81.06  train_loss 0.16741
| epoch 16 |  600/2420 batches | train_acc 80.83  train_loss 0.16927
| epoch 16 |  900/2420 batches | train_acc 81.42  train_loss 0.16229
| epoch 16 | 1200/2420 batches | train_acc 78.67  train_loss 0.18782
| epoch 16 | 1500/2420 batches | train_acc 77.42  train_loss 0.19803
| epoch 16 | 1800/2420 batches | train_acc 81.33  train_loss 0.17238
| epoch 16 | 2100/2420 batches | train_acc 79.33  train_loss 0.18567
| epoch 16 | 2400/2420 batches | train_acc 78.25  train_loss 0.18489
---------------------------------------------------------------------
| epoch 16 | time: 69.80s | valid_acc 0.724 valid_loss 0.236 | lr 0.010000
---------------------------------------------------------------------
| epoch 17 |  300/2420 batches | train_acc 81.48  train_loss 0.16361
| epoch 17 |  600/2420 batches | train_acc 80.00  train_loss 0.18248
| epoch 17 |  900/2420 batches | train_acc 82.67  train_loss 0.15687
| epoch 17 | 1200/2420 batches | train_acc 78.50  train_loss 0.18885
| epoch 17 | 1500/2420 batches | train_acc 80.33  train_loss 0.16683
| epoch 17 | 1800/2420 batches | train_acc 79.67  train_loss 0.18124
| epoch 17 | 2100/2420 batches | train_acc 80.17  train_loss 0.17808
| epoch 17 | 2400/2420 batches | train_acc 79.42  train_loss 0.18256
---------------------------------------------------------------------
| epoch 17 | time: 69.36s | valid_acc 0.748 valid_loss 0.224 | lr 0.010000
---------------------------------------------------------------------
| epoch 18 |  300/2420 batches | train_acc 83.55  train_loss 0.15356
| epoch 18 |  600/2420 batches | train_acc 80.92  train_loss 0.16303
| epoch 18 |  900/2420 batches | train_acc 79.58  train_loss 0.17592
| epoch 18 | 1200/2420 batches | train_acc 81.08  train_loss 0.16351
| epoch 18 | 1500/2420 batches | train_acc 81.17  train_loss 0.17023
| epoch 18 | 1800/2420 batches | train_acc 80.50  train_loss 0.17150
| epoch 18 | 2100/2420 batches | train_acc 81.08  train_loss 0.17239
| epoch 18 | 2400/2420 batches | train_acc 78.50  train_loss 0.19141
---------------------------------------------------------------------
| epoch 18 | time: 68.89s | valid_acc 0.737 valid_loss 0.232 | lr 0.010000
---------------------------------------------------------------------
| epoch 19 |  300/2420 batches | train_acc 82.56  train_loss 0.15452
| epoch 19 |  600/2420 batches | train_acc 83.75  train_loss 0.14618
| epoch 19 |  900/2420 batches | train_acc 83.08  train_loss 0.15686
| epoch 19 | 1200/2420 batches | train_acc 81.92  train_loss 0.15739
| epoch 19 | 1500/2420 batches | train_acc 81.42  train_loss 0.16449
| epoch 19 | 1800/2420 batches | train_acc 83.50  train_loss 0.14977
| epoch 19 | 2100/2420 batches | train_acc 81.92  train_loss 0.16178
| epoch 19 | 2400/2420 batches | train_acc 80.50  train_loss 0.16922
---------------------------------------------------------------------
| epoch 19 | time: 68.77s | valid_acc 0.751 valid_loss 0.234 | lr 0.010000
---------------------------------------------------------------------
| epoch 20 |  300/2420 batches | train_acc 81.40  train_loss 0.16502
| epoch 20 |  600/2420 batches | train_acc 83.33  train_loss 0.14418
| epoch 20 |  900/2420 batches | train_acc 81.00  train_loss 0.16009
| epoch 20 | 1200/2420 batches | train_acc 83.25  train_loss 0.16221
| epoch 20 | 1500/2420 batches | train_acc 80.42  train_loss 0.17046
| epoch 20 | 1800/2420 batches | train_acc 82.42  train_loss 0.15513
| epoch 20 | 2100/2420 batches | train_acc 83.50  train_loss 0.14309
| epoch 20 | 2400/2420 batches | train_acc 83.58  train_loss 0.15196
---------------------------------------------------------------------
| epoch 20 | time: 69.46s | valid_acc 0.749 valid_loss 0.244 | lr 0.010000
---------------------------------------------------------------------
| epoch 21 |  300/2420 batches | train_acc 84.22  train_loss 0.13575
| epoch 21 |  600/2420 batches | train_acc 82.67  train_loss 0.15082
| epoch 21 |  900/2420 batches | train_acc 81.50  train_loss 0.16861
| epoch 21 | 1200/2420 batches | train_acc 82.50  train_loss 0.15516
| epoch 21 | 1500/2420 batches | train_acc 82.42  train_loss 0.15020
| epoch 21 | 1800/2420 batches | train_acc 82.33  train_loss 0.15762
| epoch 21 | 2100/2420 batches | train_acc 82.67  train_loss 0.15195
| epoch 21 | 2400/2420 batches | train_acc 82.17  train_loss 0.15492
---------------------------------------------------------------------
| epoch 21 | time: 70.53s | valid_acc 0.754 valid_loss 0.229 | lr 0.010000
---------------------------------------------------------------------
| epoch 22 |  300/2420 batches | train_acc 82.31  train_loss 0.15517
| epoch 22 |  600/2420 batches | train_acc 84.92  train_loss 0.13876
| epoch 22 |  900/2420 batches | train_acc 81.75  train_loss 0.15917
| epoch 22 | 1200/2420 batches | train_acc 84.92  train_loss 0.13322
| epoch 22 | 1500/2420 batches | train_acc 83.08  train_loss 0.14900
| epoch 22 | 1800/2420 batches | train_acc 81.42  train_loss 0.15705
| epoch 22 | 2100/2420 batches | train_acc 80.83  train_loss 0.16701
| epoch 22 | 2400/2420 batches | train_acc 83.08  train_loss 0.15327
---------------------------------------------------------------------
| epoch 22 | time: 69.65s | valid_acc 0.724 valid_loss 0.260 | lr 0.010000
---------------------------------------------------------------------
| epoch 23 |  300/2420 batches | train_acc 84.97  train_loss 0.14119
| epoch 23 |  600/2420 batches | train_acc 81.58  train_loss 0.15883
| epoch 23 |  900/2420 batches | train_acc 83.75  train_loss 0.14659
| epoch 23 | 1200/2420 batches | train_acc 82.75  train_loss 0.14945
| epoch 23 | 1500/2420 batches | train_acc 83.42  train_loss 0.14179
| epoch 23 | 1800/2420 batches | train_acc 82.83  train_loss 0.15463
| epoch 23 | 2100/2420 batches | train_acc 84.75  train_loss 0.13010
| epoch 23 | 2400/2420 batches | train_acc 83.67  train_loss 0.13842
---------------------------------------------------------------------
| epoch 23 | time: 69.79s | valid_acc 0.746 valid_loss 0.243 | lr 0.010000
---------------------------------------------------------------------
| epoch 24 |  300/2420 batches | train_acc 84.80  train_loss 0.13544
| epoch 24 |  600/2420 batches | train_acc 85.75  train_loss 0.12646
| epoch 24 |  900/2420 batches | train_acc 86.25  train_loss 0.12283
| epoch 24 | 1200/2420 batches | train_acc 82.42  train_loss 0.14701
| epoch 24 | 1500/2420 batches | train_acc 82.25  train_loss 0.14768
| epoch 24 | 1800/2420 batches | train_acc 84.17  train_loss 0.14072
| epoch 24 | 2100/2420 batches | train_acc 83.33  train_loss 0.14110
| epoch 24 | 2400/2420 batches | train_acc 82.67  train_loss 0.14679
---------------------------------------------------------------------
| epoch 24 | time: 68.95s | valid_acc 0.734 valid_loss 0.267 | lr 0.010000
---------------------------------------------------------------------
| epoch 25 |  300/2420 batches | train_acc 84.39  train_loss 0.13833
| epoch 25 |  600/2420 batches | train_acc 83.83  train_loss 0.13872
| epoch 25 |  900/2420 batches | train_acc 85.17  train_loss 0.13148
| epoch 25 | 1200/2420 batches | train_acc 83.58  train_loss 0.14485
| epoch 25 | 1500/2420 batches | train_acc 84.92  train_loss 0.13479
| epoch 25 | 1800/2420 batches | train_acc 86.08  train_loss 0.13403
| epoch 25 | 2100/2420 batches | train_acc 82.50  train_loss 0.14975
| epoch 25 | 2400/2420 batches | train_acc 83.17  train_loss 0.15161
---------------------------------------------------------------------
| epoch 25 | time: 68.88s | valid_acc 0.726 valid_loss 0.242 | lr 0.010000
---------------------------------------------------------------------
| epoch 26 |  300/2420 batches | train_acc 85.96  train_loss 0.11894
| epoch 26 |  600/2420 batches | train_acc 84.58  train_loss 0.13726
| epoch 26 |  900/2420 batches | train_acc 84.50  train_loss 0.13487
| epoch 26 | 1200/2420 batches | train_acc 84.58  train_loss 0.14240
| epoch 26 | 1500/2420 batches | train_acc 86.83  train_loss 0.11572
| epoch 26 | 1800/2420 batches | train_acc 85.00  train_loss 0.13577
| epoch 26 | 2100/2420 batches | train_acc 85.25  train_loss 0.12339
| epoch 26 | 2400/2420 batches | train_acc 85.33  train_loss 0.13181
---------------------------------------------------------------------
| epoch 26 | time: 68.36s | valid_acc 0.732 valid_loss 0.263 | lr 0.010000
---------------------------------------------------------------------
| epoch 27 |  300/2420 batches | train_acc 86.05  train_loss 0.12249
| epoch 27 |  600/2420 batches | train_acc 85.92  train_loss 0.12441
| epoch 27 |  900/2420 batches | train_acc 86.58  train_loss 0.12297
| epoch 27 | 1200/2420 batches | train_acc 82.25  train_loss 0.14503
| epoch 27 | 1500/2420 batches | train_acc 84.67  train_loss 0.13531
| epoch 27 | 1800/2420 batches | train_acc 83.75  train_loss 0.13247
| epoch 27 | 2100/2420 batches | train_acc 84.42  train_loss 0.13053
| epoch 27 | 2400/2420 batches | train_acc 85.08  train_loss 0.13622
---------------------------------------------------------------------
| epoch 27 | time: 69.02s | valid_acc 0.757 valid_loss 0.236 | lr 0.010000
---------------------------------------------------------------------
| epoch 28 |  300/2420 batches | train_acc 87.46  train_loss 0.11854
| epoch 28 |  600/2420 batches | train_acc 85.33  train_loss 0.12381
| epoch 28 |  900/2420 batches | train_acc 86.50  train_loss 0.11198
| epoch 28 | 1200/2420 batches | train_acc 84.67  train_loss 0.13816
| epoch 28 | 1500/2420 batches | train_acc 85.33  train_loss 0.12409
| epoch 28 | 1800/2420 batches | train_acc 86.33  train_loss 0.11955
| epoch 28 | 2100/2420 batches | train_acc 86.17  train_loss 0.12414
| epoch 28 | 2400/2420 batches | train_acc 84.17  train_loss 0.14352
---------------------------------------------------------------------
| epoch 28 | time: 71.55s | valid_acc 0.699 valid_loss 0.344 | lr 0.010000
---------------------------------------------------------------------
| epoch 29 |  300/2420 batches | train_acc 87.38  train_loss 0.11934
| epoch 29 |  600/2420 batches | train_acc 85.08  train_loss 0.12648
| epoch 29 |  900/2420 batches | train_acc 85.17  train_loss 0.12210
| epoch 29 | 1200/2420 batches | train_acc 86.67  train_loss 0.11674
| epoch 29 | 1500/2420 batches | train_acc 82.58  train_loss 0.13760
| epoch 29 | 1800/2420 batches | train_acc 86.00  train_loss 0.12176
| epoch 29 | 2100/2420 batches | train_acc 86.08  train_loss 0.12268
| epoch 29 | 2400/2420 batches | train_acc 88.25  train_loss 0.10436
---------------------------------------------------------------------
| epoch 29 | time: 69.59s | valid_acc 0.737 valid_loss 0.261 | lr 0.010000
---------------------------------------------------------------------
| epoch 30 |  300/2420 batches | train_acc 84.80  train_loss 0.12498
| epoch 30 |  600/2420 batches | train_acc 86.92  train_loss 0.11238
| epoch 30 |  900/2420 batches | train_acc 87.50  train_loss 0.11427
| epoch 30 | 1200/2420 batches | train_acc 87.00  train_loss 0.11437
| epoch 30 | 1500/2420 batches | train_acc 86.17  train_loss 0.11522
| epoch 30 | 1800/2420 batches | train_acc 86.83  train_loss 0.11858
| epoch 30 | 2100/2420 batches | train_acc 86.17  train_loss 0.11800
| epoch 30 | 2400/2420 batches | train_acc 86.17  train_loss 0.12578
---------------------------------------------------------------------
| epoch 30 | time: 68.85s | valid_acc 0.750 valid_loss 0.228 | lr 0.010000
---------------------------------------------------------------------
