在一台主机上的虚拟环境训练的模型权重,怎么直接在另一台主机上调用使用?

Python项目中深度学习模型保存与加载方式改进

简介:

该项目是基于python的Flask后端的脑卒中康复系统,前端界面使用pyqt搭建并搭载在本地电脑,前端的主机负责肌电数据采集和康复训练,后端的Flask框架搭载在实验室服务器上主要负责模型的训练。所以在项目实战中碰到在后端训练的深度学习模型,返回至前端后无法直接调用。

1.常规保存网络模型的方式和加载模型的方式

保存方式:

torch.save(model, model_path, _use_new_zipfile_serialization=False)

加载方式:

model = torch.load(model_path)

其中model_path为模型保存路径,这种保存方式和加载方式还是比较简洁实用的,但是针对上述说的训练和实时测试分布在多个主机便不适用。如果使用这种方式在实际过程中会报这个错误:

Exception in thread Thread-1 (model_test):
Traceback (most recent call last):
  File "E:\ANACONDA\envs\py310\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "E:\ANACONDA\envs\py310\lib\threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "F:\flask_jiemian\myo_trans融合\realtime_predict.py", line 162, in model_test
    Plot(listener, model_path).main()
  File "F:\flask_jiemian\myo_trans融合\realtime_predict.py", line 89, in main
    model = torch.load(model_path)
  File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 713, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 930, in _legacy_load
    result = unpickler.load()
  File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 746, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'cross_vit_re_attention_channel_at'

2. 为什么常规方式不适用训练和测试分开呢?

因为常规的保存方式是保存整个模型对象,包括模型的架构信息和模型参数。尽管这样可以方便地直接加载模型对象,但在许多情况下,不一定适用于我的应用场景,这种方式有如下几个缺点:

  • 当用 torch.save 保存整个模型时,模型架构及其参数都被序列化,任何包含自定义层、模块的模型都要求完全相同的代码结构以及版本才能正确加载。
  • 比如你在一个虚拟环境中用新版本的 PyTorch 保存了模型,另一个环境中的旧版本可能无法加载,反之亦然。而 state_dict 只存储参数,不包含架构,因此只要架构一致,不同版本的 PyTorch 也可以较容易地兼容。
  • 使用 torch.save(model, model_path) 方式,保存了模型的完整结构。在加载时,PyTorch 会尝试还原完整的模型结构,因此你无法轻松地修改或调试模型的某一层或模块。
  • 而使用 state_dict,可以根据需求在相似的结构上加载参数,即使结构有小改动也更灵活。
  • 保存整个模型对象会包含所有模型定义和模型参数,通常比只保存 state_dict 的文件体积大。而 state_dict 只保存权重和偏置等必要参数,在长时间训练或需要频繁保存的场景下更省存储空间。
  • 如果需要跨平台、跨设备部署模型,比如从服务器部署到边缘设备,或者从 GPU 设备转移到 CPU 上,加载整个模型对象可能会产生兼容性问题。
  • state_dict 更轻便且不与硬件设置绑定,使得模型可以在不同的硬件之间更方便地共享。

3. 修改后保存方式和加载方式

保存方式:

torch.save(model.state_dict(), model_path)

加载方式:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossViT(
            image_size=(8, 200),
            num_classes=8,
            depth=1,
            sm_dim=16,
            sm_patch_size=(8, 10),
            sm_enc_depth=1,
            sm_enc_heads=8,
            sm_enc_mlp_dim=512,
            lg_dim=64,
            lg_patch_size=(8, 20),
            lg_enc_depth=4,
            lg_enc_heads=8,
            lg_enc_mlp_dim=512,
            cross_attn_depth=2,
            cross_attn_heads=8,
            dropout=0.1,
            emb_dropout=0.1,
            channel=8
        ).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

为什么突然多了CrossVit呢?这是因为当使用 state_dict 保存加载模型参数时,CrossViT 的具体参数仍需定义,因为 state_dict 仅保存模型的权重和偏置,而不包含模型的架构信息。加载时需要先构建一个与保存时完全相同的模型架构,然后才能通过 load_state_dict 加载这些参数。

通过该方式,问题得到很好的解决,具体项目部署过程中有些许小差错,解决后项目功能实现。

PyTorch版的YOLOv5是轻量而高性能的实时目标检测方法。利用YOLOv5训练完自己的数据集后,如何向大众展示并提供落地的服务呢?   本课程将提供相应的解决方案,具体讲述如何使用Web应用程序框架Flask进行YOLOv5的Web应用部署。用户可通过客户端浏览器上传图片,经服务器处理后返回图片检测数据并在浏览器中绘制检测结果。   本课程的YOLOv5使用ultralytics/yolov5,在Ubuntu系统上做项目演示,并提供在Windows系统上的部署方式文档。 本项目采取前后端分离的系统架构和开发方式,减少前后端的耦合。课程包括:YOLOv5的安装、 Flask的安装、YOLOv5的检测API接口python代码、 Flask的服务程序的python代码、前端html代码、CSS代码、Javascript代码、系统部署演示、生产系统部署建议等。   本人推出了有关YOLOv5目标检测的系列课程。请持续关注该系列的其它视频课程,包括:《YOLOv5(PyTorch)目标检测实战:训练自己的数据集》Ubuntu系统 https://edu.youkuaiyun.com/course/detail/30793 Windows系统 https://edu.youkuaiyun.com/course/detail/30923 《YOLOv5(PyTorch)目标检测:原理与源码解析》https://edu.youkuaiyun.com/course/detail/31428 《YOLOv5(PyTorch)目标检测实战:Flask Web部署》https://edu.youkuaiyun.com/course/detail/31087 《YOLOv5(PyTorch)目标检测实战:TensorRT加速部署》https://edu.youkuaiyun.com/course/detail/32303
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值