PyTorch 现有网络模型的使用及修改
先用最简单的 VGG分类模型 作为案例
最常用的是VGG16 和 VGG19两个版本
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N2vi6w67-1645014558525)(C:\Users\14158\AppData\Roaming\Typora\typora-user-images\1644756654647.png)]](https://i-blog.csdnimg.cn/blog_migrate/24c6213fe405936197270f8d3546b900.png)
-
pretrained
如果为true 下载的网络模型中的参数在ImageNet数据集中已经训练好。
如果为False 下载的网络模型中的参数没有训练过。 -
progress
如果为True 显示下载进度条
如果为False 不显示下载进度条
由于ImageNet数据集太大,就不下载了。
vgg16 架构
我们先看看看到 vgg_16 的网络架构
import torchvision.datasets
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False) # pretrained=False, 只加载网络模型, 不需要下载, 参数都是默认的
vgg16_true = torchvision.models.

本文介绍了如何在PyTorch中使用VGG16模型,包括预训练模型的选择,如何查看网络架构,以及如何根据需求添加和修改网络层。针对分类任务,演示了当目标类别数不同于原始模型时,如何调整最后的线性层以适应新的任务。此外,还讨论了如何修改现有模型中的特定层参数。
最低0.47元/天 解锁文章
1049

被折叠的 条评论
为什么被折叠?



