训练一个分类器
通过之前的学习,相信你已经学会如何定义一个神经网络,计算损失和更新模型的权重值。那么,现在你可能会思考一个问题:
数据是什么?
通常,当你在处理图像、文本、音频或视频数据时,你可以用标准的python 包将这些数据加载到numpy array中,然后转换为一个torch.*tensor。
(1)对于图像,常用的包有Pillow,OpenCV;
(2)对于音频,常用的包有scipy和librosa;
(3)对于文本,可以用raw Python或CyPython加载,或者也可以使用NLTK和SpaCy。
尤其对于视觉,我们已经建立了一个torchvision的包,它包含对通用数据(例如ImageNet, CIFAR10,MNIST等)的加载器和和对图像的数据传输器,即torchvision.datasets和torch.utils.data.DataLoader.这样为我们提供了很多方便并且不用重写代码。
在接下来的教程中,我们会使用CIFAR10数据集。这个数据集包含的类有:"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"。在CIFAR10中的图片的尺寸(size)为 3×32×32,即 3-channel color images of 32×32 pixels of size。
cifar10:
训练一个图像分类器:
目录:
(1) Loading and Normalizing the CIFAR10 training and test dataset using torchvision;
(2) Define a Convolution Neural Network
(3) Define a loss function
(4) Train the network on the training data
(5) Test the network on the test data
(1) Loading and normalizing CIFAR10
使用torchvision加载CIFAR10数据
import torch
import torchvision
import torchvision.transforms as transforms
torchvision数据集输出去的图像是[0, 1]的PIL Image图像。我们将这些图像转换为Tensor并且归一化在[-1, 1]的范围内。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoder(trainset, batch_size = 4, shuffle = True, num_workers = 2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoder(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Out:Files already downloaded and verified
Files already downloaded and verified
下面我们show一些training images:
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # un-normalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
#show images
imshow(torchvision.utils.make_grid(images))
#print labels
print(''.join('%5s' % classes[labels[j]] for j in range(4)))
Out:ship truck horse horse
(2) Define a convolution neural network
将之前的卷积神经网络部分的代码复制过来,并且将1-channel改为3channels.
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
(3) Define a loss function and optimizer
这里我们使用分类交叉熵损失和有动量的SGD:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
(4) Train the Network
接下来的事情就变得很有趣了。我们通过循环遍历数据,将数据输入到网络并且做优化。
for epoch in range(2): #loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
#get the inputs
inputs, labels = data
#zero the parameter gradients
optimizer.zero_grad()
#forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
#print statistics
running_loss += loss.item()
if i % 2000 == 1999: #print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' % (epoch + 1, i+1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
out:[1, 2000] loss: 2.199
[1, 4000] loss: 1.856
[1, 6000] loss: 1.688
[1, 8000] loss: 1.606
[1, 10000] loss: 1.534
[1, 12000] loss: 1.488
[2, 2000] loss: 1.420
[2, 4000] loss: 1.384
[2, 6000] loss: 1.336
[2, 8000] loss: 1.351
[2, 10000] loss: 1.309
[2, 12000] loss: 1.277
Finished Training
(5) Test the network on the test data
我们已经在trainning dataset上把网络训练了两次。但是我们还是需要去检验一下网络是否学到了一些东西。
我们通过预测神经网络输出的类来检验网络,并将预测的结果和ground-truth对比。如果预测正确,我们将其加入到正确预测的列表中。
首先,我们先看一下test set中的图片。
dataiter = iter(testloader)
images, labels = dataiter.next()
#print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: '.join('%5s' % classes[labels[j]] for j in range(4)))
out:
GroundTruth: cat ship ship plane
接下来,让我们看一下神经网络的预测结果:
outputs= net(images)
输出的结果是10个类的能量,一个类的能量越高,网络的预测就越准确。最大能量的指数为:
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ''.join('%5s' % classes[predicted[j]] for j in range(4)))
Out:Predicted: cat car car plane
以上的结果看起来还不错!
接下来我们看看网络在整个数据集上的表现:
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct / total))
Out:Accuracy of the network on the 10000 test images: 53 %
看看哪些类预测的更准,那些类预测的不准:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print("Accuracy of %5s : %2d %%" % classes[i], 100*class_correct[i] / class_total[i])
Out:
Accuracy of plane : 60 %
Accuracy of car : 75 %
Accuracy of bird : 33 %
Accuracy of cat : 50 %
Accuracy of deer : 26 %
Accuracy of dog : 47 %
Accuracy of frog : 54 %
Accuracy of horse : 66 %
Accuracy of ship : 48 %
Accuracy of truck : 70 %
(6) Trainning on GPU
就像我们将一个Tensor闯传入GPU一样,我们可以将神经网络传到GPU上。
如果电脑已经安装CUDA,我们先定义CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print(device)
Out:cuda:0
接下来我们都是假设设备是有CUDA的设备。
然后这些方法将会递归的遍历所有Modules并且将他们的参数和缓存传入到CUDA tensors中:
net.to(device)
记住,你还必须将每一步的inputs和targets传入GPU中:
inputs, labels = input.to(device), labels.to(device)