pytorch模型部署

PyTorch 模型部署是将训练好的模型应用到生产环境中的过程。常见的部署方式包括使用 TorchScriptONNXTorchServeFlask/Django 等框架。以下是 PyTorch 模型部署的详细教程。


1. 使用 TorchScript 部署

TorchScript 是 PyTorch 提供的一种模型序列化工具,可以将 PyTorch 模型转换为可以在 C++ 或 Python 中运行的脚本。

1.1 将模型转换为 TorchScript

import torch
import torchvision

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()  # 设置为评估模式

# 创建示例输入
example_input = torch.rand(1, 3, 224, 224)

# 将模型转换为 TorchScript
traced_script_module = torch.jit.trace(model, example_input)

# 保存 TorchScript 模型
traced_script_module.save("resnet18_traced.pt")

1.2 加载 TorchScript 模型

# 加载 TorchScript 模型
loaded_model = torch.jit.load("resnet18_traced.pt")

# 使用模型进行推理
output = loaded_model(example_input)
print(output)

2. 使用 ONNX 部署

ONNX(Open Neural Network Exchange)是一种开放的模型格式,支持跨平台部署。

2.1 将 PyTorch 模型导出为 ONNX

import torch
import torchvision

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()  # 设置为评估模式

# 创建示例输入
example_input = torch.rand(1, 3, 224, 224)

# 导出模型为 ONNX 格式
torch.onnx.export(
    model,                      # 模型
    example_input,              # 示例输入
    "resnet18.onnx",            # 保存路径
    input_names=["input"],      # 输入名称
    output_names=["output"],    # 输出名称
    dynamic_axes={"input": {0: "batch_size"},  # 支持动态 batch size
    opset_version=11            # ONNX 版本
)

2.2 使用 ONNX Runtime 进行推理

pip install onnxruntime
import onnxruntime
import numpy as np

# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("resnet18.onnx")

# 准备输入数据
ort_inputs = {ort_session.get_inputs()[0].name: example_input.numpy()}

# 进行推理
ort_outputs = ort_session.run(None, ort_inputs)
print(ort_outputs)

3. 使用 TorchServe 部署

TorchServe 是 PyTorch 官方提供的模型服务化工具,支持高性能的模型部署。

3.1 安装 TorchServe

pip install torchserve torch-model-archiver

3.2 打包模型

torch-model-archiver --model-name resnet18 --version 1.0 --model-file model.py --serialized-file resnet18.pth --handler image_classifier

3.3 启动 TorchServe

torchserve --start --model-store /path/to/model_store --models resnet18=resnet18.mar

3.4 发送请求

curl -X POST http://127.0.0.1:8080/predictions/resnet18 -T example.jpg

4. 使用 Flask/Django 部署

Flask 和 Django 是常用的 Python Web 框架,可以将 PyTorch 模型封装为 REST API。

4.1 使用 Flask 部署

pip install flask
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image

app = Flask(__name__)

# 加载模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# 图像预处理
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files['file']
    image = Image.open(file.stream).convert('RGB')
    input_tensor = preprocess_image(image)

    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = torch.max(output, 1)

    return jsonify({"class_id": predicted.item()})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

4.2 发送请求

curl -X POST -F "file=@example.jpg" http://127.0.0.1:5000/predict

5. 使用 Docker 部署

Docker 可以将模型和依赖打包成容器,方便部署到任何环境中。

5.1 创建 Dockerfile

FROM pytorch/pytorch:latest

# 安装依赖
RUN pip install flask

# 复制代码
COPY app.py /app/app.py
COPY resnet18.pth /app/resnet18.pth

# 设置工作目录
WORKDIR /app

# 暴露端口
EXPOSE 5000

# 启动应用
CMD ["python", "app.py"]

5.2 构建 Docker 镜像

docker build -t pytorch-flask-app .

5.3 运行 Docker 容器

docker run -p 5000:5000 pytorch-flask-app

6. 使用 FastAPI 部署

FastAPI 是一个高性能的 Python Web 框架,适合部署机器学习模型。

6.1 安装 FastAPI

pip install fastapi uvicorn

6.2 创建 FastAPI 应用

from fastapi import FastAPI, File, UploadFile
import torch
import torchvision.transforms as transforms
from PIL import Image

app = FastAPI()

# 加载模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# 图像预处理
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    image = Image.open(file.file).convert('RGB')
    input_tensor = preprocess_image(image)

    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = torch.max(output, 1)

    return {"class_id": predicted.item()}

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

6.3 启动 FastAPI 服务

uvicorn app:app --reload

总结

PyTorch 模型部署有多种方式,包括 TorchScript、ONNX、TorchServe、Flask/Django 和 FastAPI 等。根据你的需求选择合适的部署方式:

  • TorchScript:适合 C++ 或 Python 环境。
  • ONNX:适合跨平台部署。
  • TorchServe:适合高性能模型服务化。
  • Flask/Django/FastAPI:适合封装为 REST API。
  • Docker:适合容器化部署。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小赖同学啊

感谢上帝的投喂

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值