1、下载VGG16的模型(True、False):
import torchvision
# train_data = torchvision.datasets.ImageNet('dataset', split='train', download=True,
# transform=torchvision.transforms.ToTensor())
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
False代表下载原始参数;
True代表下载已经训练好的参数;
2、将VGG16应用在CIFAR10数据集上:
(1)下载数据集:
train_data = torchvision.datasets.CIFAR10('dataset', train=True, transform=torchvision.transforms.ToTensor(),
download=True)
(2)修改网络:
A:添加一层(在外面加):
参数为:in_feature = 1000, out_feature = 10
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
B:添加一层(在里面加):
参数为:in_feature = 1000, out_feature = 10
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
C:修改网络:
修改参数为:in_feature = 4096, out_feature = 10
vgg16_false.classifier[6] = nn.Linear(4096, 10)