import torch
import torchvision
from torch import optim,nn
import visdom
from torch.utils.data import Dataset,DataLoader
# 从其他函数里调用预处理和网络结构
from main import Pokemon
from model import ResNet18
batch_sz=32
lr=1e-3
epochs=10
device = torch.device('cuda')
# 下面这句保证实验可以很好的复现
torch.manual_seed(1234)
# 把数据集加载进来
# 新建了三个db对象
train_db = Pokemon('pokeman',224,mode='train')
valid_db = Pokemon('pokeman',224,mode='valid')
test_db = Pokemon('pokeman',224,mode='test')
# 新建三个loader对象
train_loader = DataLoader(train_db,batch_size=batch_sz,shuffle=True,num_workers=4)
valid_loader = DataLoader(valid_db,batch_size=batch_sz,shuffle=False,num_workers=2)
test_loader = DataLoader(test_db,batch_size=batch_sz,shuffle=False,num_workers=2)
viz = visdom.Visdom()
# valid的准确度测试
def evalute(model,loader):
# 这个测试针对valid和test是一样的功能,都可以用
correct = 0
total = len(loader.dataset)
for x,y in loader:
x,y=x.to(device),y.to(device)
# 做一个前向运算,这里不需要反向传播,所以放到一个无梯度的环境里
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred,y).sum().float().item()
return correct / total
def main():
# 创建train的逻辑
# 首先创建模型
model = ResNet18().to(device)
optimizer = optim.Adam(model.parameters(),lr=lr)
criteon = nn.CrossEntropyLoss() # 这个不需要softmax
# 预设的状态
best_acc,best_epoch=0,0
# 全局step
global_step = 0
viz.line([0],[-1],win='loss',opts=dict(title='loss'))
viz.line([0],[-1],win='valid_acc',opts=dict(title='valid_acc'))
# 创建train的逻辑
for epoch in range(epochs):
for step,(x,y) in enumerate(train_loader):
# x:[b,3,224,224] y:[b]
x,y=x.to(device),y.to(device)
# 放入模型中
logits=model(x)
loss = criteon(logits,y)
# 得到loss后,做一个优化器
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()],[global_step],win='loss',update='append')
global_step += 1
# 约定一段时间做一个valid,比如两个epoch做一次
# 需要做一个valid的准确度测试,见函数def evalute(model,db)
if epochs % 2 == 0:
valid_acc = evalute(model,valid_loader)
# 根据上述得到的valid_acc,选择要不要保存这个状态,所以需要预设一个状态
if valid_acc > best_acc:
best_epoch = epoch
best_acc = valid_acc
# 要把这个模型保存起来,因为下次要从最好的模型上加载状态值
torch.save(model.state_dict(),'best.model')
viz.line([valid_acc],[global_step],win='valid_acc',update='append')
print('best acc:',best_acc,'best_epoch:',best_epoch)
# 然后从最好的状态加载这个模型
model.load_state_dict(torch.load('best.model'))
print('loaded from ckpt!')
test_acc = evalute(model,test_loader)
print('test acc',test_acc)
# 到这里我们只是保存下了最好的模型,并做了测试,还需要把train的进度保存下来,用visdom,见P32,59
if __name__ == '__main__':
main()
训练过程中,选中一个最好的模型,用来测试。
结果如下: