PyTorch比较好的代码写法

本文介绍如何在修改模型网络结构后,正确载入旧有模型预训练的参数,并探讨了在优化器中过滤不需要训练参数的方法及其可能的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 修改模型网络结构后,如何载入旧有模型预训练的参数到修改后的模型

def load_checkpoint(model, checkpoint):
        # 修改后的模型的参数
        model_dict = model.state_dict()
        # 旧有模型结构的预训练网络模型,其中['state_dict']保存的模型参数
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 将训练好的参数update到model_dict当中
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)

        model.load_state_dict(model_dict)

2 过滤掉冻结不训练参数的code

optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5
        )

疑问:不过滤有影响吗?我觉得没有哎,反正也不更新,待验证...

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值