PyTorch载入预训练权重方法和冻结权重方法

本文介绍如何在PyTorch中加载预训练模型权重,并提供多种方法调整模型结构以适应不同的任务需求,包括直接加载、修改网络结构及冻结训练等。

载入预训练权重

1. 直接载入预训练权重

简单粗暴法:

pretrain_weights_path = "./resnet50.pth"
net.load_state_dict(torch.load(pretrain_weights_path))

如果这里的pretrain_weights与我们训练的网络不同,一般指的是包含大于模型参数时,可以修改为

net.load_state_dict(torch.load(pretrain_weights_path), strict=False)

2. 修改网络结构

常用方法1:

model_weight_path = "resnet34pre.pth"
net.load_state_dict(torch.load(model_weight_path))
# 这里假设最后一层为FC层,使用迁移学习,将分类结果修改
# net是实例化的resnet网络,in_features是网络输入结构参数,最后的5是修改的输出参数
inchannel = net.fc.in_features
net.fc = nn
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fighting_1997

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值