PyTorch 现有网络模型的使用及修改
先用最简单的 VGG分类模型 作为案例
最常用的是VGG16 和 VGG19两个版本
-
pretrained
如果为true 下载的网络模型中的参数在ImageNet数据集中已经训练好。
如果为False 下载的网络模型中的参数没有训练过。 -
progress
如果为True 显示下载进度条
如果为False 不显示下载进度条
由于ImageNet数据集太大,就不下载了。
vgg16 架构
我们先看看看到 vgg_16 的网络架构
import torchvision.datasets
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False) # pretrained=False, 只加载网络模型, 不需要下载, 参数都是默认的
vgg16_true = torchvision.models.