Pytorch框架flask部署简单例子—图像识别分类
写在前面
Flask是一种用Python编写的轻量级Web框架,可以帮助您快速构建Web应用程序。
如果我们正在使用PyTorch框架开发深度学习应用程序,并希望将其部署到Web服务器上,则可以使用Flask框架实现。本文将介绍如何使用Flask对前一篇博客中所编写的基于PyTorch框架的图像分类模型进行本地部署,共包含两个py文件(flask_server.py和flask_predict.py),分别表示服务端和客户端,以实现对该模型的远程访问和使用,下文将会详细介绍。(点击这里:基于PyTorch实现经典网络架构的花卉图像分类模型)
在使用Flask部署PyTorch应用程序之前,需要在本地计算机上安装Flask库,若pip install flask下载速度过慢,可换成conda install flask(安装了anaconda3),就能很快下载完毕。

1.flask_server服务端
1.1 初始化flask app
创建一个名为app的Flask对象,并将__name__作为参数传递给它(__name__是一个特殊变量,它表示当前模块的名称,通常用于确定应用程序根目录的位置)。接着创建一个名为model的变量,并将其初始化为None,该变量将用于存储训练好的PyTorch模型。再创建一个名为use_gpu的布尔变量,并将其初始化为False,这个变量将用于控制是否使用GPU加速模型的计算(GPU不错的小伙伴建议为True)。
初始化的流程较为固定,可作为模板进行使用,代码如下:
app = flask.Flask(__name__)
model = None
use_gpu = False
1.2 加载模型
定义一个load_model函数,传入训练模型model、相应结构和参数。需要注意的是,model的值需与训练时所用模型相同(重要!!),同时将model声明为全局变量。
接着重新定义全连接层(102表示最后输出的类别,需根据自身任务来确定),再加载best.pth文件(best.pth存储着训练时效果最好的参数,与前篇博客是同一文件),再使用model.load_state_dict()函数将保存的模型参数加载到我们定义的模型中.
最后使用model.eval()函数将模型设置为验证模式,这将禁用例如dropout和batch normalization等一些训练时的策略,输出分类的概率值。
def load_model():
"""Load the pre-trained model, you can use your model just as easily.
"""
global model
#这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 类别数自己根据自己任务来
#print(model)
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
#将模型指定为测试格式
model.eval()
#是否使用gpu
if use_gpu:
model.cuda()
1.3 数据预处理
数据预处理部分大致与验证集相似。不同之处在于添加了一个格式转换,有可能请求端所给image的格式不同,因此我们需要将其统一至RGB格式(训练时所用格式)。
def prepare_image(image, target_size):
#针对不同模型,image的格式不同,但需要统一至RGB格式
if image

本文介绍了如何使用Flask将基于PyTorch的图像分类模型部署到Web服务器上,包括服务端的初始化、模型加载、数据预处理和开启服务的步骤,以及客户端如何发送预测请求并获取结果。示例提供了flask_server.py和flask_predict.py两个文件,分别用于服务端和客户端的实现。
最低0.47元/天 解锁文章
112

被折叠的 条评论
为什么被折叠?



