【Libtorch部署】pkl 转 pt

本文介绍了如何将PyTorch的pkl格式模型文件转换为pt格式,以便于在C++环境中部署Libtorch使用。提供了两种不同的转换方法,并强调了在不同环境下(CPU/GPU)模型的具体操作细节。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在C++部署Libtorch的时候,需要调用 pt 格式的模型,因此需要提前将 pkl 格式的模型转成 pt 格式。

  • CPU 版本
import torch
from myModels.UNet import UNet as unet

# Download and load the pre-trained model
# 模型加载方式1
# model = torch.load("checkpoints/model_50.pkl",map_location='cpu')

# 模型加载方式2,需要提前加载模型框架(断点)
num_channels = 3
num_classes = 1
model = unet(n_classes=num_classes,n_channels=num_channels).to(device='cpu')

checkpoint = torch.load('./checkpoints(Unet)/model.pkl')            # 加载断点
model.load_state_dict(checkpoint['net'])            # 加载模型可学习参数

# 模型加载方式3,需要提前加载模型框架(非断点)
# model.load_state_dict(torch.load('checkpoints/model_50.pkl'))

model.eval()

example_input = torch.rand(1, 3, 256, 256)
script_module = torch.jit.trace(model, example_input)
script_module.save('./checkpoints(Unet)/model_cpu.pt')
- GPU 版本(看看就好,其实用不着)
import torch
from myModels.UNet import UNet as unet
# 模型加载
num_channels = 3
num_classes = 1
model = unet(n_classes=num_classes,n_channels=num_channels).to(device='cuda')

checkpoint = torch.load('./checkpoints(Unet)/model.pkl')            # 加载断点
model.load_state_dict(checkpoint['net'])            # 加载模型可学习参数

model.cuda()
model.eval()

example_input = torch.rand(1, 3, 256, 256).cuda()  # 生成测试sample
script_module = torch.jit.trace(model, example_input)
script_module.save('./checkpoints(Unet)/model_gpu.pt')

【P.S.】

  Python中无需将模型转成GPU版本的 pt,因为CPU版本的 pt 在C++中Release模型下也可以调用CUDA,相反GPU版本的 pt 在C++中Debug模型还会报错。
  在Libtorch中:Debug 模型下,默认设备使用的是 CPU; Releas 模型下,根据用户设置使用 GPU / CPU。

<think>好的,用户现在问的是哪些常见的框架支持导出机器学习模型供后端调用。这个问题是在之前的对话中提到的相关问题中的第一个,所以用户可能是在进一步了解模型部署的具体技术栈。 首先,我需要回想一下常见的机器学习框架,比如TensorFlow、PyTorch、Scikit-learn、ONNX、Keras、XGBoost/LightGBM等,这些都是常用的,并且它们都有导出模型的功能。接下来要确认每个框架支持的导出格式,比如TensorFlow的SavedModel或HDF5,PyTorch的TorchScript或ONNX,Scikit-learn的pickle文件等。 然后,用户可能想知道这些导出格式如何在后端调用,是否需要特定的库或工具。比如,TensorFlow可以用TF Serving,PyTorch可以用TorchServe,或者通过REST API调用。此外,用户可能关心跨平台兼容性,比如ONNX的作用,以及不同框架之间的互操作性。 还要考虑用户的实际应用场景,比如他们是否在寻找生产环境中的高效部署方案,或者是否需要考虑模型的大小和推理速度,比如量化或剪枝技术。另外,用户可能对数据预处理和后处理的集成有疑问,但当前问题主要集中在导出模型的支持框架上。 需要确保回答覆盖主流框架,每个框架的导出方法、常用工具,以及一些优缺点或适用场景。比如,TensorFlow适合大规模部署PyTorch动态图方便但可能需要静态,ONNX解决跨框架问题等。 同时,要注意用户可能的隐藏需求,比如是否在比较不同框架的部署难易度,或者是否有特定的性能要求。回答中需要简明扼要,结构清晰,分点列出框架、导出格式、调用方法,并给出例子,帮助用户快速理解和选择适合自己项目的工具。</think>以下是常见的机器学习框架及其支持的模型导出格式和后端调用方式,按技术特性分类说明: ### 一、深度学习框架 1. **TensorFlow** - **导出格式** - SavedModel(目录结构,含计算图和权重) - Frozen Graph(`.pb` 文件,固化计算图) - HDF5(`.h5`,兼容 Keras 模型) - **调用方式** - 使用 `tf.saved_model.load()` 直接加载 - 通过 TensorFlow Serving 部署为 gRPC/REST 服务 - 编译为 TensorFlow Lite 格式(移动端优化) 2. **PyTorch** - **导出格式** - TorchScript(`.pt` 或 `.pth`,静态计算图) - ONNX(跨框架通用格式) - **调用方式** - 用 `torch.jit.load()` 加载 TorchScript - 通过 TorchServe 部署为微服务 - 换为 C++ 可调用的 LibTorch 格式 3. **Keras** - **导出格式** - HDF5(`.h5`) - SavedModel(与 TensorFlow 兼容) - **调用方式** - 通过 `keras.models.load_model()` 加载 - 集成到 Flask/Django 等 Web 框架中直接调用 ### 二、传统机器学习框架 4. **Scikit-learn** - **导出格式** - Pickle(`.pkl`,Python 原生序列化) - Joblib(`.joblib`,优化大数组存储) - **调用方式** - 用 `pickle.load()` 或 `joblib.load()` 反序列化 - 通过 MLflow 包装为 API 服务 5. **XGBoost/LightGBM** - **导出格式** - 二进制模型文件(`.model`) - JSON 格式(可读性强,便于调试) - **调用方式** - 调用 `xgb.Booster()` 或 `lgb.Booster` 加载 - 换为 ONNX 实现跨语言调用 ### 三、跨框架解决方案 6. **ONNX(Open Neural Network Exchange)** - **核心能力** - 将 TensorFlow/PyTorch/MXNet 等模型换为 `.onnx` 格式 - 支持 C++/C#/Java 等语言调用 - **运行时库** - ONNX Runtime(微软优化,支持 GPU 加速) ### 四、部署优化工具 7. **专用服务器方案** - **TensorFlow Serving**:低延迟,自动版本管理 - **TorchServe**:支持动态批处理和多模型管理 - **Triton Inference Server**(NVIDIA):多框架支持,GPU 优化 ### 五、实际项目选型建议 - **Web 服务轻量级部署**:优先选 Pickle(Scikit-learn)或 Flask 直接加载 - **高并发生产环境**:用 TensorFlow Serving/TorchServe - **边缘设备部署**:换为 TensorFlow Lite 或 ONNX + ONNX Runtime - **跨语言调用需求**:统一导出为 ONNX 格式 *示例代码(PyTorch 模型调用)*: ```python # 导出 TorchScript scripted_model = torch.jit.script(model) scripted_model.save("model.pt") # 后端加载推理 model = torch.jit.load("model.pt") inputs = torch.tensor(preprocessed_data) outputs = model(inputs) ``` 不同框架导出的模型在实际调用时需注意**输入维度匹配**、**运行环境依赖库版本**以及**线程安全**等问题。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值