加载预训练网络模型并加载权重
resnet50=torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
in_features=resnet50.fc.in_features
# 将原resnet50网络中的最后一个全连接层改成10分类的输出
resnet50.fc=nn.Linear(in_features,10)
resnet50=resnet50.to(device)
因为resnet50网络需要输入224x224x3大小的图片
因此对网络接收的输入也要做相应的调整
tf=torchvision.transforms.Compose([
torchvision.transforms.Resize(size=(224,224)),
torchvision.transforms.Grayscale(num_output_channels=3),
torchvision.transforms.ToTensor(),
# torchvision.transforms.Normalize((0.1307,),(0.3081,))
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
固定卷积层参数
# 固定卷积层的参数
optim=torch.optim.Adam(resnet50.fc.parameters(),lr=0.001)
完整代码:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
tf=torchvision