- 可以看这里,比较与tensorflow版本的区别
- 在输入格式方面,pytorch是NCHW,tensorflow是NHWC
- 网络返回 `log_softmax` 时,应该使用 `nll_loss`
- MyData.py
import os
import cv2
import random
import numpy as np
from torch.utils.data.dataset import Dataset
class MnistDataset(Dataset):
    """MNIST dataset backed by a text index file.

    Each line of *dataset_path* is expected to be ``<image_path> <label>``
    (whitespace separated).  Images are read with OpenCV, scaled to
    [0, 1], and returned channel-first (CHW) to match PyTorch's NCHW
    input convention.
    """

    def __init__(self, dataset_path):
        # Read the whole index file once; one sample per line.
        with open(dataset_path) as f:
            self.all = f.readlines()
        self.len = len(self.all)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Wrap the index so any non-negative index maps to a valid sample.
        index = index % self.len
        sp = self.all[index].split()
        im = cv2.imread(sp[0])
        if im is None:
            # cv2.imread returns None on failure instead of raising; fail
            # loudly here rather than with a cryptic AttributeError below.
            raise FileNotFoundError('cannot read image: ' + sp[0])
        im = im.astype(np.float32) / 255.0
        # HWC (OpenCV) -> CHW (PyTorch NCHW convention).
        im = np.transpose(im, (2, 0, 1))
        lab = int(sp[1])
        return im, lab
def collate(batch):
    """Custom collate_fn: stack images into one ndarray, keep labels
    (named ``bboxes`` here, following the file's convention) as a plain
    Python list."""
    imgs = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return np.array(imgs), labels
import numpy as np
import torch
from torch.utils.data import DataLoader
import MyData
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import os,sys
class Mnist(nn.Module):
    """LeNet-style CNN classifier for 28x28 RGB MNIST images.

    Input:  float tensor of shape (N, 3, 28, 28).
    Output: (N, 10) log-probabilities (log_softmax), intended to be
            paired with ``F.nll_loss``.
    """

    def __init__(self):
        super(Mnist, self).__init__()
        # Feature extractor: conv -> pool -> conv -> pool.
        # Spatial sizes: 28x28 -> 28x28 -> 14x14 -> 10x10 -> 5x5.
        self.conv_list = nn.ModuleList([
            nn.Conv2d(3, 6, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5, 1),
            nn.MaxPool2d(2),
        ])
        # Classifier head: 16 * 5 * 5 = 400 features -> 10 classes.
        self.linear_list = nn.ModuleList([
            nn.Linear(400, 200),
            nn.Linear(200, 10),
        ])

    def forward(self, x):
        for layer in self.conv_list:
            x = layer(x)
        x = x.contiguous().view(-1, 400)
        x = F.relu(self.linear_list[0](x))
        x = self.linear_list[1](x)
        return F.log_softmax(x, dim=1)
# Pick GPU 0 when CUDA is available, otherwise run on CPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
# i is the epoch counter; it may be advanced below when resuming.
i=0
net1 = Mnist()
# Resume from the epoch-0 snapshot if one exists; otherwise start fresh.
if os.path.exists('checkpoints/first.pt'):
    checkpoint = torch.load('checkpoints/first.pt', map_location='cpu')
    net1.load_state_dict(checkpoint['model'])
    net1.to(device).train()
    # Create the optimizer first, then restore its saved state; the
    # checkpoint's param groups carry the recorded lr/momentum settings.
    optimizer = optim.SGD(net1.parameters(), lr=1e-4)
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Continue from the epoch after the one that was saved.
    i=checkpoint['epoch'] + 1
    # Free the (potentially large) checkpoint dict before training.
    del checkpoint
else:
    optimizer = optim.SGD(net1.parameters(), lr=1e-4, momentum=0.99)
    net1.to(device).train()
# Dataset index files: each line is "<image_path> <label>".
# drop_last=True keeps every batch full, which the evaluation code
# below relies on when normalizing by cnt * batch_size.
train_dataset = MyData.MnistDataset('/home/lwd/data/mnist/train.txt')
train_data = DataLoader(train_dataset, shuffle = True, batch_size = 32, num_workers = 4, pin_memory=True,drop_last=True)
test_dataset = MyData.MnistDataset('/home/lwd/data/mnist/test.txt')
test_data = DataLoader(test_dataset, shuffle = False, batch_size = 4, num_workers = 4, pin_memory=True,drop_last=True)
# Train for at most 100 epochs (i may start above 0 when resuming).
while i < 100:
    net1.train()
    for batch_idx, (batch_image, batch_label) in enumerate(train_data):
        optimizer.zero_grad()
        output = net1(batch_image.to(device))
        # The network returns log_softmax, so nll_loss is the matching
        # criterion (together they form cross-entropy).
        loss = F.nll_loss(output, batch_label.to(device))
        loss.backward()
        optimizer.step()
        if i == 0:
            # NOTE(review): on epoch 0 this saves 'first.pt' after a
            # single batch and breaks out of the epoch early —
            # presumably to create the resume snapshot that the setup
            # code loads; confirm this truncated first epoch is intended.
            checkpoint = {'epoch': i,
                          'loss': loss,
                          'model': net1.state_dict(),
                          'optimizer': optimizer.state_dict()}
            torch.save(checkpoint, 'checkpoints/first.pt')
            break
    # --- evaluate on the test split after each epoch ---
    net1.eval()
    test_loss = 0
    correct = 0
    cnt = 0
    with torch.no_grad():
        for batch_idx, (batch_image, batch_label) in enumerate(test_data):
            output = net1(batch_image.to(device))
            target=batch_label.to(device)
            # Sum (not mean) the per-sample losses; normalized below.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Predicted class = argmax over the 10 log-probabilities.
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum().item()
            cnt += 1
    # drop_last=True on the loader means every batch is full, so
    # cnt * batch_size equals the number of samples evaluated.
    test_loss /= cnt*test_data.batch_size
    correct/= cnt*test_data.batch_size
    print(test_loss, correct)
    # Early-stop and save the final checkpoint once the average test
    # loss drops below the threshold.
    if test_loss < 0.06:
        checkpoint = {'epoch': i,
                      'loss': loss,
                      'model': net1.state_dict(),
                      'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, 'checkpoints/last.pt')
        break
    i+=1