深度学习实战案例一:第四课:开始训练,验证,测试

Python 文件可以相互调用同一个文件夹下面的其他 py 文件,因此我们导入设计好的 data_loader.py 的

Pokemon类,和resnet.py的ResNet18类,
from    data_loader import Pokemon
from    resnet import ResNet18

下面是全部代码

import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    data_loader import Pokemon
from    resnet import ResNet18

# --- Hyper-parameters ---------------------------------------------------
batchsz = 32          # mini-batch size shared by all three loaders
picture_resize = 224  # images resized to 224x224 (standard ResNet input)
lr = 1e-3             # Adam learning rate
epochs = 10           # number of training epochs

device = torch.device('cuda')  # NOTE(review): assumes a CUDA GPU is present — confirm
torch.manual_seed(1234)        # fix RNG seed for reproducibility


# --- Data pipeline ------------------------------------------------------
# Pokemon(root, resize, mode) yields the train/val/test split of one folder
# (see data_loader.py for the split logic).
train_db = Pokemon('pokemon', picture_resize, mode='train')
val_db = Pokemon('pokemon', picture_resize, mode='val')
test_db = Pokemon('pokemon', picture_resize, mode='test')
# Only the training loader shuffles; evaluation order is irrelevant.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

# Visdom client for live curves (requires a running `python -m visdom.server`).
viz = visdom.Visdom()


def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: classifier producing per-class logits of shape [b, n_classes].
        loader: DataLoader yielding ``(x, y)`` batches, ``y`` holding
            integer class indices.

    Returns:
        float: fraction of correctly classified samples in [0, 1];
        0.0 for an empty dataset.
    """
    model.eval()  # disable dropout / use running BatchNorm statistics

    total = len(loader.dataset)
    if total == 0:
        # Guard: an empty split would otherwise raise ZeroDivisionError below.
        return 0.0

    correct = 0
    # Enter no_grad once for the whole loop instead of once per batch.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def main():
    """Train ResNet18 on the Pokemon dataset, checkpoint the best model by
    validation accuracy, then report accuracy on the test split.

    Side effects: writes 'best.mdl' to the working directory and streams
    loss / val_acc curves to the Visdom server.
    """
    model = ResNet18(5).to(device)  # 5 output classes
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()  # expects raw logits + integer targets

    best_acc, best_epoch = 0, 0
    global_step = 0
    # Seed both Visdom curves with a dummy point at x = -1.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()  # back to train mode (evalute() switches to eval)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch (the original `if epoch % 1 == 0` guard
        # was always true, so it has been removed).
        val_acc = evalute(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc

            # Persist the best weights so they can be restored for testing.
            torch.save(model.state_dict(), 'best.mdl')
            # NOTE(review): the val_acc curve is only updated on improvement,
            # matching the original; move this out of the branch for a point
            # every epoch.
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # Restore the best checkpoint before measuring test accuracy.
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)



# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值