Pytorch 保存加载模型时的坑

Pytorch 保存加载模型时的坑

在说Pytorch保存加载模型时的坑之前,先介绍一下pytorch对训练好的模型如何进行保存和加载。

方法1:保存模型的参数和结构信息

保存:

model=MobileNetV2(n_class=2)#加载模型
############进行训练##########
 model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])#用多gpus 训练×××关键
############进行训练##########
torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar"))#保存模型

恢复:

model=torch.load(args.load_path)#

这种方法会出现一个问题:当利用pytorch 1.0.0 保存好了模型后,加载时利用pytorch1.1.0 进行load() 时回报错,所以官方推荐使用第二种方法进行加载

方法二:官方推荐的方法,只保存和恢复模型中的参数

一个完整的例子:
迁移学习加载模型(此时 checkpoint 字典只有 state_dict ):

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
if args.resume:#模型路径
    if os.path.isfile(args.resume):
        print(("=> loading checkpoint '{}'".format(args.resume)))
        checkpoint = torch.load(args.resume)#获取模型参数
         #因为我修改网络模型进行迁移学习,这一步是在checkpoint里获取没有修改的模型参数state_dict
        state_dict = {k: v for k, v in checkpoint.items() if k in model_dict.keys()}
        model_dict.update(state_dict)#更新已经保存的参数至model_dict
        model.load_state_dict(model_dict)#加载模型参数
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume)))

保存:–这里有坑

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数×××××这里埋个坑××××
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))      

再加载:

print("start loading cls model")
model=MobileNetV2(n_class=2)
if os.path.isfile(args.load_path):
    state_dict=torch.load(args.load_path)
    print(state_dict['epoch'])#获取保存的参数 对应key值的参数
    print(state_dict['epoch_acc'])
    params=state_dict["model_state_dict"] 
    for param_tensor in params:#打印参数信息
         print(param_tensor,"\t",params[param_tensor].size())
    model.load_state_dict(params)
    print("load cls model successfully")             

填坑

这段保存模型参数的代码

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar")) 

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar")) 

与这段的不同在于model.module.state_dict()与model.state_dict()的区别
现在来打印一下

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
model_dict----------model.module.state_dict()---------model.state_dict()三者参数的对应的名称(这里只打印几个)
model_dict:
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.module.state_dict():
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.state_dict():
module.features.0.0.weight 	 torch.Size([32, 3, 3, 3])
module.features.0.1.weight 	 torch.Size([32])
module.features.0.1.bias 	 torch.Size([32])
module.features.0.1.running_mean 	 torch.Size([32])
module.features.0.1.running_var 	 torch.Size([32])
module.features.0.1.num_batches_tracked 	 torch.Size([])

用多gpus进行训练后直接用model.state_dict()进行保存的模型,每个层参数的名称前面会加上module,这时候再用单卡 gpu model_dict加载model.state_dict()参数时会出现名称不匹配的情况。
因此保存模型时注意使用model.module.state_dict():

总结

1.多gpus训练 用model.state_dict() 保存前面会加上网络参数名称前会加上 module
2.单gpus加载模型,需要去掉网络参数名称前加上的module
两种方法:
(1) 用model.module.state_dict()保存
(2) 去掉网络参数名称前会加上的module再加载模型
3.推荐多gpus训练使用model.module.state_dict()保存,然后单gpu加载,
此时如果还需要多gpu训练可以在加载模型参数后使用torch.nn.DataParallel进行训练

还有另外的思路可参考 @[参考这里][参考这里]多gpu训练(https://blog.youkuaiyun.com/CV_YOU/article/details/86670188)(https://blog.youkuaiyun.com/qq_32998593/article/details/89343507)

<think>好的,用户想了解如何在C++中加载MNIST数据集,用于测试从PyTorch导出的ResNet50模型。首先,我需要理清楚整个流程,确保每一步都正确无误。 首先,用户提到的是ResNet50模型,而MNIST数据集是28x28的灰度图像。但ResNet50通常接受的是224x224的RGB三通道图像。这里可能存在输入尺寸和通道数不匹配的问题。需要提醒用户注意这一点,可能需要进行预处理,比如调整图像大小、转换为三通道,或者修改模型的第一层卷积层。这一步很关键,否则模型无法正确处理MNIST的数据。 接下来,C++加载MNIST数据集。MNIST的数据集格式是固定的,包含图像和标签的文件。通常图像文件是二进制格式,前面有头信息,包括魔数、图像数量、行数和列数。需要指导用户如何正确读取这些二进制文件,处理字节序的问题,可能需要使用标准的文件读取方法,或者第三方库如OpenCV来辅助读取图像。 然后是关于模型导出。用户提到从PyTorch导出模型,应该是指将训练好的模型保存为某种格式,比如TorchScript。需要说明如何使用PyTorch的torch.jit.trace或script来导出模型,并确保在C++环境中能够加载。这里可能涉及到LibTorch的安装和配置,用户需要正确链接库文件。 在C++部分,加载模型需要使用LibTorch的API,比如torch::jit::load。同,读取到的MNIST图像数据需要转换为Torch张量,调整形状和类型,进行归一化处理(如除以255),并确保输入符合模型的期望格式。可能还需要处理批量数据,但如果是单张测试,可能只需扩展一个批次维度。 此外,内存管理和错误处理也是需要注意的地方。比如,读取文件是否正确打开了文件,读取头信息是否正确,张量创建是否成功,模型加载是否有异常等。这些细节如果处理不好,可能导致程序崩溃或结果错误。 用户可能需要一个完整的代码示例,包括读取MNIST文件、预处理图像、加载模型、运行推理和输出结果。但也要提醒他们根据实际情况调整路径和参数,比如文件路径、模型路径,以及预处理步骤是否符合他们的模型要求。 最后,还要考虑性能问题,比如如果处理大量图像,可能需要批量处理以提高效率,但如果是测试单个图像,可能不需要。此外,验证模型输出是否正确,比如检查预测的标签是否合理,或者对比Python中的结果是否一致,以确保导出和加载过程无误。 总结下来,步骤应该是: 1. 导出PyTorch模型为TorchScript。 2. 准备MNIST数据集文件。 3. 在C++中使用LibTorch加载模型。 4. 读取并预处理MNIST数据。 5. 运行推理并处理结果。 每个步骤都需要详细说明,特别是可能遇到的,比如输入尺寸不匹配、字节序问题、文件路径错误等。同,提供代码片段帮助用户快速实现,并解释关键部分的作用,确保用户理解每一步的操作。</think>在C++中加载MNIST数据集并测试PyTorch导出的ResNet50模型,需遵循以下步骤: --- ### 一、核心问题分析 1. **输入维度适配**:ResNet50默认输入为`3x224x224`,而MNIST是`1x28x28`的灰度图,需进行上采样和通道复制 2. **数据格式转换**:MNIST原始数据为二进制文件,需解析为张量格式 3. **模型兼容性**:需将PyTorch模型导出为TorchScript格式供C++使用 --- ### 二、实现步骤详解 #### 步骤1:导出PyTorch模型 ```python # Python端导出模型 import torch import torchvision model = torchvision.models.resnet50(pretrained=True) model.eval() # 设置为推理模式 # 修改第一层适配MNIST单通道输入 model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) # 创建示例输入并导出 example = torch.rand(1, 1, 224, 224) # 假设已上采样到224x224 traced_script_module = torch.jit.trace(model, example) traced_script_module.save("resnet50_mnist.pt") ``` #### 步骤2:MNIST数据加载(C++实现) ```cpp #include <torch/script.h> #include <fstream> // 读取MNIST图像文件的函数 torch::Tensor load_mnist_images(const std::string& path) { std::ifstream file(path, std::ios::binary); // 解析文件头(32位大端格式) int32_t magic_number, num_images, rows, cols; file.read((char*)&magic_number, sizeof(magic_number)); file.read((char*)&num_images, sizeof(num_images)); file.read((char*)&rows, sizeof(rows)); file.read((char*)&cols, sizeof(cols)); // 转换字节序 magic_number = __builtin_bswap32(magic_number); num_images = __builtin_bswap32(num_images); rows = __builtin_bswap32(rows); cols = __builtin_bswap32(cols); // 读取图像数据 auto tensor = torch::empty({num_images, 1, rows, cols}, torch::kByte); file.read((char*)tensor.data_ptr(), tensor.numel()); // 转换为float并归一化 tensor = tensor.to(torch::kFloat32).div_(255); // 上采样到224x224并复制通道 tensor = torch::nn::functional::interpolate( tensor, torch::nn::functional::InterpolateFuncOptions() .mode(torch::kNearest) .size({224, 224}) ).repeat({1, 3, 1, 1}); // 复制单通道到三通道 return tensor; } ``` #### 步骤3:模型加载与推理 ```cpp int main() { // 加载模型 torch::jit::script::Module model; try { model = torch::jit::load("resnet50_mnist.pt"); } catch (const c10::Error& e) { std::cerr << "模型加载失败: " << e.what() << std::endl; return -1; } // 加载测试数据 auto test_images = load_mnist_images("t10k-images-idx3-ubyte"); // 执行推理 auto output = model.forward({test_images}).toTensor(); // 获取预测结果 auto preds = std::get<1>(torch::max(output, 1)); std::cout << "预测结果: " << preds.slice(/*dim=*/0, /*start=*/0, /*end=*/5) << std::endl; return 0; } ``` --- ### 三、关键注意事项 1. **数据预处理一致性**:必须与模型训练的预处理一致(归一化方式等) 2. **内存管理**:使用`libtorch`需注意内存泄漏问题,推荐使用智能指针 3. **性能优化**:可启用OpenMP并行化: ```cpp torch::set_num_threads(4); // 设置线程数 at::globalContext().setBenchmarkCuDNN(true); ``` --- ### 四、编译配置示例 使用CMake需包含LibTorch: ```cmake cmake_minimum_required(VERSION 3.0) project(mnist_resnet) find_package(Torch REQUIRED) add_executable(main main.cpp) target_link_libraries(main "${TORCH_LIBRARIES}") set_property(TARGET main PROPERTY CXX_STANDARD 14) ``` --- ### 五、常见问题解决 1. **维度不匹配错误**:检查模型输入层和预处理后的张量形状 2. **精度下降**:确认C++与Python的预处理完全一致 3. **文件读取失败**:确保MNIST文件路径正确,使用绝对路径更可靠 实际部署建议使用OpenCV优化图像处理流程,并添加异常处理机制保证稳定性。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值