神奇宝贝五分类:train and test

本文介绍使用PyTorch框架进行图像分类任务的过程,包括数据集加载、模型定义、训练循环及模型验证等关键步骤,并展示了如何利用Visdom进行训练过程的可视化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torchvision
from torch import optim,nn
import visdom
from torch.utils.data import Dataset,DataLoader

# 从其他函数里调用预处理和网络结构
from main import Pokemon
from model import ResNet18


batch_sz=32
lr=1e-3
epochs=10

device = torch.device('cuda')
# 下面这句保证实验可以很好的复现
torch.manual_seed(1234)


# 把数据集加载进来
# 新建了三个db对象
train_db = Pokemon('pokeman',224,mode='train')
valid_db = Pokemon('pokeman',224,mode='valid')
test_db = Pokemon('pokeman',224,mode='test')
# 新建三个loader对象
train_loader = DataLoader(train_db,batch_size=batch_sz,shuffle=True,num_workers=4)
valid_loader = DataLoader(valid_db,batch_size=batch_sz,shuffle=False,num_workers=2)
test_loader = DataLoader(test_db,batch_size=batch_sz,shuffle=False,num_workers=2)


viz = visdom.Visdom()

# valid的准确度测试
def evalute(model,loader):
    # 这个测试针对valid和test是一样的功能,都可以用
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:

        x,y=x.to(device),y.to(device)
        # 做一个前向运算,这里不需要反向传播,所以放到一个无梯度的环境里
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred,y).sum().float().item()

    return correct / total

def main():
    # 创建train的逻辑

    # 首先创建模型
    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(),lr=lr)
    criteon = nn.CrossEntropyLoss()  # 这个不需要softmax

    # 预设的状态
    best_acc,best_epoch=0,0

    # 全局step
    global_step = 0
    viz.line([0],[-1],win='loss',opts=dict(title='loss'))
    viz.line([0],[-1],win='valid_acc',opts=dict(title='valid_acc'))

    # 创建train的逻辑
    for epoch in range(epochs):

        for step,(x,y) in enumerate(train_loader):

            # x:[b,3,224,224] y:[b]
            x,y=x.to(device),y.to(device)

            # 放入模型中
            logits=model(x)
            loss = criteon(logits,y)

            # 得到loss后,做一个优化器
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()],[global_step],win='loss',update='append')
            global_step += 1

            # 约定一段时间做一个valid,比如两个epoch做一次
            # 需要做一个valid的准确度测试,见函数def evalute(model,db)
        if epochs % 2 == 0:

            valid_acc = evalute(model,valid_loader)
            # 根据上述得到的valid_acc,选择要不要保存这个状态,所以需要预设一个状态
            if valid_acc > best_acc:
                best_epoch = epoch
                best_acc = valid_acc

                # 要把这个模型保存起来,因为下次要从最好的模型上加载状态值
                torch.save(model.state_dict(),'best.model')

                viz.line([valid_acc],[global_step],win='valid_acc',update='append')


    print('best acc:',best_acc,'best_epoch:',best_epoch)

    # 然后从最好的状态加载这个模型
    model.load_state_dict(torch.load('best.model'))
    print('loaded from ckpt!')

    test_acc = evalute(model,test_loader)
    print('test acc',test_acc)
    # 到这里我们只是保存下了最好的模型,并做了测试,还需要把train的进度保存下来,用visdom,见P32,59




if __name__ == '__main__':
    main()

训练过程中,选中一个最好的模型,用来测试。

结果如下:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值