使用Flask框架部署pytorch模型Rest接口是最简答快捷的方式,但如果服务要求服务具有更高的处理性能,比如高并发、低时延等,使用Flask部署可能就不太合适了。
比如直接使用Flask框架提供的web服务启动Rest接口时,会有如下提示:
提示说这个web服务可以用来开发测试,不要用来进行生产部署,生产部署可以使用WSGI服务进行替代。实际压测中也会发现,如果用Flask自带的web服务,性能极不稳定。
对于上述问题,可以使用TorchScript对模型进行转换,然后使用C++进行调用,可以提升服务处理性能。但是C++调用有一定技术门槛,我一般采用一种折中的方法,使用gunicorn框架提供的WSGI服务来部署Flask框架的Rest接口,也就是将Flask代码部署在gunicorn框架提供的WSGI服务上,可以显著提升Rest接口的并发处理能力和稳定性,但是并不能提升模型的推理速度,提升推理速度还是进行模型的量化压缩、转换成C++调用这些方法靠谱。
1、创建app.py
import io
import json
import os
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
os.environ["TORCH_HOME"] = "./models"
app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
2、创建Gunicorn配置文件gunicorn.conf
bind = '127.0.0.1:8088'
workers = 2
threads = 8
backlog = 16
worker_class = 'gevent'
debug = False
# chdir为项目完整的绝对路径
chdir = 'XXX'
accesslog = "./gunicorn_access.log"
errorlog = "./gunicorn_error.log"
3、使用Gunicorn提供的WSGI服务启动Flask代码
gunicorn -c ./gunicorn.conf app:app
注意:gunicorn与args命令行参数共用会报错
注意:gunicorn与args命令行参数共用会报错
注意:gunicorn与args命令行参数共用会报错
解决办法:
args = parser.parse_args()改为args = parser.parse_args(args=[])