PyTorch迁移VGG16
目录
1. 下载已经具备最优参数的模型
这需要对我们之前使用的model=Models()代码部分进行替换,因为我们已经不需要再自己搭建和定义训练的模型了,而是通过代码自动下载模型并直接调用,具体代码如下:
import torchvision.models as models
model = models.vgg16(pretrained=True)
1)代码含义
指定进行下载的模型为VGG16,并通过设置参数pretrained=True,来实现下载模型附带了已经优化好的模型参数。
2)查看迁移模型细节
2. 对当前迁移过来的模型进行调整
尽管迁移学习要求我们需要解决的问题之间具有很强的相似性,但是每个问题对最后输出的结果会有不一样的要求,而承担整个模型输出分类工作的是卷积神经网络模型中的全连接层,所以在迁移学习中调整最多的也是全连接部分。其基本思路是冻结卷积神经网络中全连接层之前的全部网络层次,让这些被冻结的网络层次中的参数在模型训练过程中不进行梯度更新,能够被优化的参数仅仅是没被冻结的全连接层的全部参数。
由于迁移过来的VGG16架构模型在最后输出的结果是1000个,而我们的Dogs vs. Cats只需输出两个结果,所以全连接层必须进行调整,具体实现代码如下:
for parma in model.parameters():
parma.requires_grad = False
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 2))
Use_gpu = torch.cuda.is_available()
if Use_gpu:
model = model.cuda()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.00001)
1)代码含义
- 首先,对原模型中的参数进行遍历操作,将参数parma.requires_grad全部设置为False,这样对应的参数将不计算梯度,当然也就不会进行梯度更新了,这就是之前所说的冻结操作;
- 然后,定义新的全连接层结构并赋值给model.classifier,在完成新的全连接层定义之后,全连接层中的parma.requires_grad参数会被默认重置为True,所以不需要再次遍历参数来进行解冻操作;
- 将优化函数中负责优化参数设置为全连接层中的所有参数,即对model.classifier.parameters这部分参数进行优化。
2)查看调整后的模型细节
3. 完整代码
import torch
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
import time
%matplotlib inline
data_dir = './data/DogsVSCats'
# 定义要对数据进行的处理
data_transform = {x: transforms.Compose([transforms.Resize([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
for x in ["train", "valid"]}
# 数据载入
image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x),
transform=data_transform[x])
for x in ["train", "valid"]}
# 数据装载
dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x],
batch_size=16,
shuffle=True)
for x in ["train", "valid"]}
X_example, y_example = next(iter(dataloader["train"]))
# print(u'X_example个数{}'.format(len(X_example)))
# print(u'y_example个数{}'.format(len(y_example)))
# print(X_example.shape)
# print(y_example.shape)
# 验证独热编码的对应关系
index_classes = image_datasets["train"].class_to_idx
# print(index_classes)
# 使用example_classes存放原始标签的结果
example_classes = image_datasets["train"].classes
# print(example_classes)
# 图片预览
img = torchvision.utils.m