PyTorch 模型部署是将训练好的模型应用到生产环境中的过程。常见的部署方式包括使用 TorchScript、ONNX、TorchServe 和 Flask/Django 等框架。以下是 PyTorch 模型部署的详细教程。
1. 使用 TorchScript 部署
TorchScript 是 PyTorch 提供的一种模型序列化工具,可以将 PyTorch 模型转换为可以在 C++ 或 Python 中运行的脚本。
1.1 将模型转换为 TorchScript
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval() # 设置为评估模式
# 创建示例输入
example_input = torch.rand(1, 3, 224, 224)
# 将模型转换为 TorchScript
traced_script_module = torch.jit.trace(model, example_input)
# 保存 TorchScript 模型
traced_script_module.save("resnet18_traced.pt")
1.2 加载 TorchScript 模型
# 加载 TorchScript 模型
loaded_model = torch.jit.load("resnet18_traced.pt")
# 使用模型进行推理
output = loaded_model(example_input)
print(output)
2. 使用 ONNX 部署
ONNX(Open Neural Network Exchange)是一种开放的模型格式,支持跨平台部署。
2.1 将 PyTorch 模型导出为 ONNX
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval() # 设置为评估模式
# 创建示例输入
example_input = torch.rand(1, 3, 224, 224)
# 导出模型为 ONNX 格式
torch.onnx.export(
model, # 模型
example_input, # 示例输入
"resnet18.onnx", # 保存路径
input_names=["input"], # 输入名称
output_names=["output"], # 输出名称
dynamic_axes={"input": {0: "batch_size"}, # 支持动态 batch size
opset_version=11 # ONNX 版本
)
2.2 使用 ONNX Runtime 进行推理
pip install onnxruntime
import onnxruntime
import numpy as np
# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("resnet18.onnx")
# 准备输入数据
ort_inputs = {ort_session.get_inputs()[0].name: example_input.numpy()}
# 进行推理
ort_outputs = ort_session.run(None, ort_inputs)
print(ort_outputs)
3. 使用 TorchServe 部署
TorchServe 是 PyTorch 官方提供的模型服务化工具,支持高性能的模型部署。
3.1 安装 TorchServe
pip install torchserve torch-model-archiver
3.2 打包模型
torch-model-archiver --model-name resnet18 --version 1.0 --model-file model.py --serialized-file resnet18.pth --handler image_classifier
3.3 启动 TorchServe
torchserve --start --model-store /path/to/model_store --models resnet18=resnet18.mar
3.4 发送请求
curl -X POST http://127.0.0.1:8080/predictions/resnet18 -T example.jpg
4. 使用 Flask/Django 部署
Flask 和 Django 是常用的 Python Web 框架,可以将 PyTorch 模型封装为 REST API。
4.1 使用 Flask 部署
pip install flask
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
app = Flask(__name__)
# 加载模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# 图像预处理
def preprocess_image(image):
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return transform(image).unsqueeze(0)
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return jsonify({"error": "No file provided"}), 400
file = request.files['file']
image = Image.open(file.stream).convert('RGB')
input_tensor = preprocess_image(image)
with torch.no_grad():
output = model(input_tensor)
_, predicted = torch.max(output, 1)
return jsonify({"class_id": predicted.item()})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
4.2 发送请求
curl -X POST -F "file=@example.jpg" http://127.0.0.1:5000/predict
5. 使用 Docker 部署
Docker 可以将模型和依赖打包成容器,方便部署到任何环境中。
5.1 创建 Dockerfile
FROM pytorch/pytorch:latest
# 安装依赖
RUN pip install flask
# 复制代码
COPY app.py /app/app.py
COPY resnet18.pth /app/resnet18.pth
# 设置工作目录
WORKDIR /app
# 暴露端口
EXPOSE 5000
# 启动应用
CMD ["python", "app.py"]
5.2 构建 Docker 镜像
docker build -t pytorch-flask-app .
5.3 运行 Docker 容器
docker run -p 5000:5000 pytorch-flask-app
6. 使用 FastAPI 部署
FastAPI 是一个高性能的 Python Web 框架,适合部署机器学习模型。
6.1 安装 FastAPI
pip install fastapi uvicorn
6.2 创建 FastAPI 应用
from fastapi import FastAPI, File, UploadFile
import torch
import torchvision.transforms as transforms
from PIL import Image
app = FastAPI()
# 加载模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# 图像预处理
def preprocess_image(image):
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return transform(image).unsqueeze(0)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
image = Image.open(file.file).convert('RGB')
input_tensor = preprocess_image(image)
with torch.no_grad():
output = model(input_tensor)
_, predicted = torch.max(output, 1)
return {"class_id": predicted.item()}
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
6.3 启动 FastAPI 服务
uvicorn app:app --reload
总结
PyTorch 模型部署有多种方式,包括 TorchScript、ONNX、TorchServe、Flask/Django 和 FastAPI 等。根据你的需求选择合适的部署方式:
- TorchScript:适合 C++ 或 Python 环境。
- ONNX:适合跨平台部署。
- TorchServe:适合高性能模型服务化。
- Flask/Django/FastAPI:适合封装为 REST API。
- Docker:适合容器化部署。