3.2 模型保存、加载

部署运行你感兴趣的模型镜像

Table of Contents

模型保存

常规模型加载 

模型详细信息保存

模型参数冻结


模型保存

# 保存整个网络

torch.save(net, PATH)

# 保存网络中的参数, 速度快,占空间少

torch.save(net.state_dict(),PATH)

常规模型加载 

model_dict=torch.load(PATH)

model_dict=model.load_state_dict(torch.load(PATH))

模型详细信息保存

torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,

'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},

checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

eg:    
for epoch in range(init_epoch, basic_configs['num_epochs']):
        logger.info("Begin training epoch {}".format(epoch + 1))
        train_function(epoch)

        net_checkpoint_name = args.exp + "_net_epoch" + str(epoch + 1)
        net_checkpoint_path = os.path.join(exp_ckpt_dir, net_checkpoint_name)
        net_state = {"epoch": epoch + 1,
                     "network": net.module.state_dict()}
        torch.save(net_state, net_checkpoint_path)

模型参数冻结 



for p in self.parameters():
    p.requires_grad = False


class RESNET_YYT(nn.Module):

    def init(self, model, pretrained):

        super(RESNET_YYT, self).__init__()

        self.resnet = model(pretrained)

        for p in self.parameters():

            p.requires_grad = False

        self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))

        self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))

        self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))

打印网络层数

model_dict = torch.load('/001_net_epoch5loss0.08947')['network'] 
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
    print(i, p)

 

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值