import torch
import torchvision
from torch import nn
import warnings
warnings.filterwarnings("ignore")
vgg16_false = torchvision.models.vgg16(weights=False)
vgg16_true = torchvision.models.vgg16(weights='VGG16_Weights.DEFAULT') #加载默认权重的VGG16
print(vgg16_true)
vgg16_true.add_module('add_linear',nn.Linear(1000,10)) #在最后加上一个线性层
print(vgg16_true)
torch.save(vgg16_true, "model1") #参数与模型结构都保存
torch.save(vgg16_true.state_dict(), "model2") #只保留参数,适合大模型的保存
vgg16_load1 = torch.load("model1")
vgg16_load2 = torch.load("model2")
print(vgg16_load1,vgg16_load2)
模型的改进、保存、加载
最新推荐文章于 2024-06-24 19:45:00 发布