简介:
该项目是基于python的Flask后端的脑卒中康复系统,前端界面使用pyqt搭建并搭载在本地电脑,前端的主机负责肌电数据采集和康复训练,后端的Flask框架搭载在实验室服务器上主要负责模型的训练。所以在项目实战中碰到在后端训练的深度学习模型,返回至前端后无法直接调用。
1.常规保存网络模型的方式和加载模型的方式
保存方式:
torch.save(model, model_path, _use_new_zipfile_serialization=False)
加载方式:
model = torch.load(model_path)
其中model_path为模型保存路径,这种保存方式和加载方式还是比较简洁实用的,但是针对上述说的训练和实时测试分布在多个主机便不适用。如果使用这种方式在实际过程中会报这个错误:
Exception in thread Thread-1 (model_test):
Traceback (most recent call last):
File "E:\ANACONDA\envs\py310\lib\threading.py", line 1016, in _bootstrap_inner
self.run()
File "E:\ANACONDA\envs\py310\lib\threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "F:\flask_jiemian\myo_trans融合\realtime_predict.py", line 162, in model_test
Plot(listener, model_path).main()
File "F:\flask_jiemian\myo_trans融合\realtime_predict.py", line 89, in main
model = torch.load(model_path)
File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 713, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 930, in _legacy_load
result = unpickler.load()
File "E:\ANACONDA\envs\py310\lib\site-packages\torch\serialization.py", line 746, in find_class
return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'cross_vit_re_attention_channel_at'
2. 为什么常规方式不适用训练和测试分开呢?
因为常规的保存方式是保存整个模型对象,包括模型的架构信息和模型参数。尽管这样可以方便地直接加载模型对象,但在许多情况下,不一定适用于我的应用场景,这种方式有如下几个缺点:
- 当用
torch.save保存整个模型时,模型架构及其参数都被序列化,任何包含自定义层、模块的模型都要求完全相同的代码结构以及版本才能正确加载。 - 比如你在一个虚拟环境中用新版本的 PyTorch 保存了模型,另一个环境中的旧版本可能无法加载,反之亦然。而
state_dict只存储参数,不包含架构,因此只要架构一致,不同版本的 PyTorch 也可以较容易地兼容。 - 使用
torch.save(model, model_path)方式,保存了模型的完整结构。在加载时,PyTorch 会尝试还原完整的模型结构,因此你无法轻松地修改或调试模型的某一层或模块。 - 而使用
state_dict,可以根据需求在相似的结构上加载参数,即使结构有小改动也更灵活。 - 保存整个模型对象会包含所有模型定义和模型参数,通常比只保存
state_dict的文件体积大。而state_dict只保存权重和偏置等必要参数,在长时间训练或需要频繁保存的场景下更省存储空间。 - 如果需要跨平台、跨设备部署模型,比如从服务器部署到边缘设备,或者从 GPU 设备转移到 CPU 上,加载整个模型对象可能会产生兼容性问题。
state_dict更轻便且不与硬件设置绑定,使得模型可以在不同的硬件之间更方便地共享。
3. 修改后保存方式和加载方式
保存方式:
torch.save(model.state_dict(), model_path)
加载方式:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossViT(
image_size=(8, 200),
num_classes=8,
depth=1,
sm_dim=16,
sm_patch_size=(8, 10),
sm_enc_depth=1,
sm_enc_heads=8,
sm_enc_mlp_dim=512,
lg_dim=64,
lg_patch_size=(8, 20),
lg_enc_depth=4,
lg_enc_heads=8,
lg_enc_mlp_dim=512,
cross_attn_depth=2,
cross_attn_heads=8,
dropout=0.1,
emb_dropout=0.1,
channel=8
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
为什么突然多了CrossVit呢?这是因为当使用 state_dict 保存加载模型参数时,CrossViT 的具体参数仍需定义,因为 state_dict 仅保存模型的权重和偏置,而不包含模型的架构信息。加载时需要先构建一个与保存时完全相同的模型架构,然后才能通过 load_state_dict 加载这些参数。
通过该方式,问题得到很好的解决,具体项目部署过程中有些许小差错,解决后项目功能实现。
Python项目中深度学习模型保存与加载方式改进
817

被折叠的 条评论
为什么被折叠?



