第2章 预训练网络
讨论3种常用的预训练模型:
1、根据内容对图像进行标记(识别)
2、从真实图像中生成新图像(GAN)
3、使用正确的英语句子来描述图像内容(自然语言)
2.1 获取一个预训练好的网络用于图像识别
ImageNet数据集,用于大规模视觉识别挑战赛。
所有预训练好的模型都在TorchVision中。
2.1.1 导入已有的模型
所有模型都在torchvison的models中。导入并查看。
from torchvision import models
dir(models)
输出的是所有torchvison里面集成的模型框架。其中首字母大写的是一些流行的模型。小写的名字是快捷函数,返回实例化模型函数。
1.1.1 AlexNet模型
实例化AlexNet。
alexnet=models.AlexNet()
alexnet
可以像函数一样调用它。给alexnet输入数据,就会通过正向传播(forward pass)得到输出。比如output=alexnet(input)。由于网络没有初始化,没有经过训练。所以一般先要将模型从头训练或者加载训练好的网络。然后再调用。
1.1.2 Resnet模型
(1)加载在ImageNet数据集上训练好的权重,来实例化ResNet101
resnet=models.resnet101(pretrained=True)
resnet
然后就开始下载,下载完成后查看resnet101的结构。
神经网络由许多模块构成,包含过滤器和非线性函数,fc层结束,输出每个类的分数。
预训练好的模型可以跟函数一样调用,并输入图片实现预测。
(2)定义预处理函数:
from torchvision import transforms
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]
)
])
预处理包括:图像缩放到2