载入预训练权重
1. 直接载入预训练权重
简单粗暴法:
pretrain_weights_path = "./resnet50.pth"
net.load_state_dict(torch.load(pretrain_weights_path))
如果这里的pretrain_weights与我们训练的网络不同,一般指的是包含大于模型参数时,可以修改为
net.load_state_dict(torch.load(pretrain_weights_path), strict=False)
2. 修改网络结构
常用方法1:
model_weight_path = "resnet34pre.pth"
net.load_state_dict(torch.load(model_weight_path))
# 这里假设最后一层为FC层,使用迁移学习,将分类结果修改
# net是实例化的resnet网络,in_features是网络输入结构参数,最后的5是修改的输出参数
inchannel = net.fc.in_features
net.fc = nn

本文介绍如何在PyTorch中加载预训练模型权重,并提供多种方法调整模型结构以适应不同的任务需求,包括直接加载、修改网络结构及冻结训练等。
最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



