3.2 模型保存、加载

部署运行你感兴趣的模型镜像

Table of Contents

模型保存

常规模型加载 

模型详细信息保存

模型参数冻结


模型保存

# 保存整个网络

torch.save(net, PATH)

# 保存网络中的参数, 速度快,占空间少

torch.save(net.state_dict(),PATH)

常规模型加载 

model_dict=torch.load(PATH)

model_dict=model.load_state_dict(torch.load(PATH))

模型详细信息保存

torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,

'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},

checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

eg:    
for epoch in range(init_epoch, basic_configs['num_epochs']):
        logger.info("Begin training epoch {}".format(epoch + 1))
        train_function(epoch)

        net_checkpoint_name = args.exp + "_net_epoch" + str(epoch + 1)
        net_checkpoint_path = os.path.join(exp_ckpt_dir, net_checkpoint_name)
        net_state = {"epoch": epoch + 1,
                     "network": net.module.state_dict()}
        torch.save(net_state, net_checkpoint_path)

模型参数冻结 



for p in self.parameters():
    p.requires_grad = False


class RESNET_YYT(nn.Module):

    def init(self, model, pretrained):

        super(RESNET_YYT, self).__init__()

        self.resnet = model(pretrained)

        for p in self.parameters():

            p.requires_grad = False

        self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))

        self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))

        self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))

打印网络层数

model_dict = torch.load('/001_net_epoch5loss0.08947')['network'] 
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
    print(i, p)

 

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

在机器学习模型保存时,是否包含训练配置信息取决于模型保存方式和使用的库。通常情况下,训练配置信息(如优化器参数、损失函数、学习率、训练轮数等)可以通过模型保存时的附加数据进行存储,以便后续恢复模型训练状态或重新训练模型。 ### 3.1 模型保存与训练配置信息的关系 在 PyTorch 和 TensorFlow 等主流深度学习框架中,模型保存时可以将训练配置信息作为附加数据嵌入到模型文件中。例如,在 PyTorch 中,开发者可以使用 `torch.save()` 函数将模型状态字典(`state_dict`)、优化器状态、训练轮数和配置参数一并保存: ```python torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch, 'config': config_dict, }, "model.pth") ``` 这种方式允许在后续加载模型时恢复训练状态,从而实现断点续训或模型调试。类似地,在 TensorFlow 中,`tf.keras.Model.save()` 方法默认保存模型结构、权重以及训练配置(如优化器状态和损失函数),支持完整的模型恢复 [^1]。 ### 3.2 保存训练配置信息的常见做法 为确保模型训练状态可恢复,通常建议在模型保存时显式记录训练配置信息。这些信息包括: - **模型结构参数**:如输入维度、隐藏层大小等; - **优化器参数**:如学习率、动量、权重衰减等; - **训练元数据**:如训练轮数、批量大小、数据增强策略等; - **数据预处理参数**:如图像归一化均值和标准差等。 例如,在 PyTorch 中加载包含训练配置的模型文件时,可以提取并使用这些配置信息: ```python checkpoint = torch.load("model.pth") model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] config = checkpoint['config'] ``` 这种方式有助于在部署或调试过程中确保模型行为与训练阶段一致 [^1]。 ### 3.3 模型部署与训练配置的持续维护 模型部署后,训练配置信息仍然具有重要意义。例如,在模型性能监测和再训练过程中,需要依据原始训练配置进行数据预处理、超参数设置和性能评估。若未保存训练配置信息,可能导致模型在新数据上的预测结果与训练阶段不一致,影响模型的可解释性和部署稳定性 [^2]。 此外,在分布式训练场景中,训练配置信息(如 `tf.distribute.MultiWorkerMirroredStrategy()` 的使用)决定了模型的训练方式和硬件资源分配,因此在模型保存时应明确记录相关策略配置,以确保多节点训练的可复现性 [^3]。 ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值