说明
使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即模型下载的默认路径
C:\用户名.cache\torch.checkpoints下,
修改方法
- 发现pytorch的默认下载路径是由load_state_dict_from_url函数进行控制,那么就好办了,只需要找到这个函数进行修改即可
- 由于我是下载vgg16,所以我先找到
vgg.py
源码,位于python路径
torchvision/models/vgg.py
(其中python路径即是你安装python所在的地址,假设你是用Anaconda创建了一个名为envtest的虚拟环境,那么所有安装的库都会Anaconda/envs/envtest/Lib/site-packages
这个文件夹下) - 接下来就是套娃操作在
vgg.py
直接搜索load_state_dict_from_url
发现有如下语句在这里
显然,需要到torch安装路径下找到hub.py - 找到
hub.py
,搜索load_state_dict_from_url
,成功找到如下代码
可发现model_dir参数即为下载模型的默认路径,所以直接将model_dir = None换成model_dir = 想要的模型下载绝对路径即可,感兴趣的同学可以仔细专研,这里就不过多阐述