如何布署 PyTorch 深度学习模型
原文可直接运行代码👇
链接: 如何布署 PyTorch 深度学习模型
需要介绍一下 TorchServe,这是一个灵活且易于使用的工具,用于为 PyTorch 模型提供服务。
为什么要布署
有小伙伴问了:模型还要怎么布署呢,我直接 python run.py 运行不行吗?

设想一个最简单的刷脸进门的程序,如果每来一个人都要手动执行一次 python run.py,中午饭点的时候手可能就抽筋了。。。。。
所以我们需要一个能够随时监听请求的服务,来代替我们的双手。
怎么布署
简单来讲,通过网络接收各种协议(HTTP)发送过来的输入数据,调用提前存放的模型进行推理,再返回结果。
当然发送的数据会遵守一定的规范(REST),返回的数据也遵循一定的格式(json, xml),这些细节感兴趣的小伙伴可以自行了解学习。
TorchServe 架构
安装
安装 Torch Serve 以及必要的一些包
%pip install torchserve torch-model-archiver torch-workflow-archiver captum timm
创建一个用于存储模型的目录
!mkdir model_store
权重文件
真实环境中就使用自己训练好的权重,这里方便展示就使用了 imagenet 的预训练模型
import timm
import torch
import torch.nn.utils.prune as prune
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.backbone = timm.create_model('efficientnet_b0', pretrained=True)
def forward(self, x):
x = self.backbone(x)
return x
model = Net()
torch.save(model.state_dict(), 'model.pth.tar')
model.py
model.py 需要包含单个模型的类, 并且能成功加载(torch.load_state_dict)上面的 model.pth
!featurize dataset download ee8b0992-df2e-4ad6-a240-1f486b8eef8b
!cat /home/featurize/data/torchserve/model.py
import torch
import timm
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.backbone = timm.create_model('efficientnet_b0', pretrained=False)
def forward(self, x):
x = self.backbone(x)
return x
(可选)preprocess.py
可选的一些预处理,比如:flip、 resize 等等。
!cat /home/featurize/data/torchserve/preprocess.py
from torchvision.transforms.transforms import Grayscale
from ts.torch_handler.image_classifier import ImageClassifier
from torchvision import transforms
class CustomHandler(ImageClassifier):
image_processing = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
])
模型打包
!wget https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json
!torch-model-archiver \
--model-name efficientnetb0 \
--handler image_classifier \
--version 1.0 \
--model-file /home/featurize/data/torchserve/model.py \
--serialized-file model.pth.tar \
--export-path model_store \
--extra-files index_to_name.json
启动服务
import os
os.system("torchserve --start --ncs --model-store model_store --models efficientnetb0.mar")
下载测试图片
!curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
import cv2
import matplotlib.pyplot as plt
plt.imshow(cv2.cvtColor(cv2.imread('kitten_small.jpg'), cv2.COLOR_BGR2RGB));
发送推理请求
%%time
!curl http://127.0.0.1:8080/predictions/efficientnetb0 -T /home/featurize/kitten_small.jpg
{
"tabby": 0.45582714676856995,
"lynx": 0.2556627094745636,
"Egyptian_cat": 0.1583441197872162,
"tiger_cat": 0.04835391417145729,
"tiger": 0.003294622991234064
}CPU times: user 48.2 ms, sys: 52.8 ms, total: 101 ms
Wall time: 1.5 s
停止服务
os.system("torchserve --stop")