将 PyTorch 模型投入生产
1. Netron 模型可视化工具
Netron 最初是一个 ONNX 可视化工具,现在它已经可以接受所有流行框架导出的模型。目前,根据官方文档,Netron 支持的模型格式包括:ONNX、Keras、CoreML、Caffe2、MXNet、TensorFlow Lite、TensorFlow.js、TensorFlow、Caffe、PyTorch、Torch、CNTK、PaddlePaddle、Darknet 和 scikit - learn。
2. MXNet 模型服务器
2.1 选择 MXNet 模型服务器的原因
与其他服务模块相比,MXNet 表现更优。在编写本文时,TensorFlow 与 Python 3.7 不兼容,而 MXNet 的服务模块与内置的 ONNX 模型集成,这使得开发人员只需几行命令就能轻松部署模型,无需学习分布式或高度可扩展部署的复杂性。
其他模型服务器,如 TensorRT 和 Clipper,在设置和管理方面不如 MXNet 服务器简单。此外,MXNet 还附带了一个名为 MXNet archiver 的实用工具,它可以将所有必需的文件打包成一个独立的包,无需担心其他依赖项。MXNet 模型服务器的最大优点是能够自定义预处理和后处理步骤。
2.2 整体流程
整个过程的流程如下:
graph LR
A[创建 .mar 格式的存档文件] --> B[调用模型服务器并传入存档文件位置]
B --> C[模型由高性能服务器提供服务]
具体步骤如下:
1.
创建 .mar 存档文件
:使用模型归档器创建一个 .mar 格式的单个存档文件。这个文件需要 ONNX 模型文件、
signature.json
(提供输入大小、名称等信息,可视为配置文件,可随时更改,也可将值硬编码到代码中而不使用该文件)以及服务文件(定义预处理、推理函数、后处理函数和其他实用函数)。
2.
启动模型服务器
:调用模型服务器并将模型存档的位置作为输入传递。
2.3 安装 MXNet 模型归档器和服务器
2.3.1 安装 MXNet 模型归档器
默认随 MXNet 模型服务器附带的模型归档器不支持 ONNX,因此需要单独安装。ONNX 模型归档器依赖于协议缓冲区和 MXNet 包本身。每个操作系统的 protobuf 编译器安装指南可在官方文档中找到。MXNet 包可以通过 pip 安装:
pip install mxnet
pip install model-archiver[onnx]
2.3.2 安装 MXNet 模型服务器
MXNet 模型服务器基于 Java 虚拟机(JVM)构建,因此可以从 JVM 调用多个运行模型实例的线程。在 JVM 的支持下,MXNet 服务器可以扩展到多个进程来处理数千个请求。由于模型服务器在 JVM 上运行,需要安装 Java 8。此外,MXNet 模型服务器在 Windows 上仍处于实验模式,但在 Linux 和 Mac 上是稳定的。
pip install mxnet-model-server
2.4 编写生产就绪的 PyTorch 模型代码
2.4.1 创建目录并移动文件
安装完所有先决条件后,我们可以开始使用 MXNet 模型服务器编写生产就绪的 PyTorch 模型代码。首先,创建一个新目录来保存模型归档器所需的所有文件,然后将上一步生成的 .onnx 文件移动到该目录。
2.4.2 编写服务文件
MMS 要求必须有一个包含服务类的服务文件,MMS 会执行该文件中唯一类的
initialize()
和
handle()
函数。以下是服务文件的骨架:
class MXNetModelService(object):
def __init__(self):
...
def initialize(self, context):
...
def preprocess(self, batch):
...
def inference(self, model_input):
...
def postprocess(self, inference_output):
...
def handle(self, data, context):
...
2.4.3 编写签名文件
签名文件是一个配置文件,用于描述数据名称、输入形状和输入类型。以下是 fizbuz 网络的最小签名文件示例:
{
"inputs": [
{
"data_name": "input.1",
"data_shape": [
1,
10
]
}
],
"input_type": "application/json"
}
2.4.4 实现
initialize()
函数
def initialize(self, context):
properties = context.system_properties
model_dir = properties.get("model_dir")
gpu_id = properties.get("gpu_id")
self._batch_size = properties.get('batch_size')
signature_file_path = os.path.join(
model_dir, "signature.json")
if not os.path.isfile(signature_file_path):
raise RuntimeError("Missing signature.json file.")
with open(signature_file_path) as f:
self.signature = json.load(f)
data_names = []
data_shapes = []
input_data = self.signature["inputs"][0]
data_name = input_data["data_name"]
data_shape = input_data["data_shape"]
data_shape[0] = self._batch_size
data_names.append(data_name)
data_shapes.append((data_name, tuple(data_shape)))
self.mxnet_ctx = mx.cpu() if gpu_id is None else
mx.gpu(gpu_id)
sym, arg_params, aux_params = mx.model.load_checkpoint(
checkpoint_prefix, self.epoch)
self.mx_model = mx.mod.Module(
symbol=sym, context=self.mxnet_ctx,
data_names=data_names, label_names=None)
self.mx_model.bind(
for_training=False, data_shapes=data_shapes)
self.mx_model.set_params(
arg_params, aux_params,
allow_missing=True, allow_extra=True)
self.has_initialized = True
2.4.5 实现
handle()
函数
def handle(self, data, context):
try:
if not self.has_initialized:
self.initialize()
preprocess_start = time.time()
data = self.preprocess(data)
inference_start = time.time()
data = self.inference(data)
postprocess_start = time.time()
data = self.postprocess(data)
end_time = time.time()
metrics = context.metrics
metrics.add_time(self.add_first())
metrics.add_time(self.add_second())
metrics.add_time(self.add_third())
return data
except Exception as e:
request_processor = context.request_processor
request_processor.report_status(
500, "Unknown inference error")
return [str(e)] * self._batch_size
2.4.6 实现其他函数
-
preprocess()函数 :
def preprocess(self, batch):
param_name = self.signature['inputs'][0]['data_name']
data = batch[0].get('body').get(param_name)
if data:
self.input = data + 1
tensor = mx.nd.array(
[self.binary_encoder(self.input, input_size=10)])
return tensor
self.error = 'InvalidData'
-
inference()函数 :
def inference(self, model_input):
if self.error is not None:
return None
self.mx_model.forward(DataBatch([model_input]))
model_output = self.mx_model.get_outputs()
return model_output
-
postprocess()函数 :
def postprocess(self, inference_output):
if self.error is not None:
return [self.error] * self._batch_size
prediction = self.get_readable_output(
self.input,
int(inference_output[0].argmax(1).asscalar()))
out = [{'next_number': prediction}]
return out
2.5 创建 .mar 存档文件并启动模型服务器
model-archiver \
--model-name fizbuz_package \
--model-path fizbuz_package \
--handler fizbuz_service -f
mxnet-model-server \
--start \
--model-store FizBuz_with_ONNX \
--models fizbuz_package.mar
现在模型服务器在端口 8080 上运行(如果未更改端口)。可以使用与 Flask 应用程序相同的 curl 命令(当然要更改端口号)来测试模型,结果应与 Flask 应用程序相同,但现在可以根据需要动态调整工作线程的数量。
2.6 管理 API 操作
模型服务器运行时,会在端口 8081 上提供一个管理 API 服务,可通过该服务控制配置。例如,将工作线程数量设置为 1:
curl -v -X PUT "http://localhost:8081/models/fizbuz_package?max_worker=1"
检查服务器状态:
curl "http://localhost:8081/models/fizbuz_package"
3. 负载测试
可以使用 Locust 进行负载测试。以下是一个示例 Locust 脚本,应保存为
locust.py
:
import random
from locust import HttpLocust, TaskSet, task
class UserBehavior(TaskSet):
def on_start(self):
self.url = "/predictions/fizbuz_package"
self.headers = {"Content-Type": "application/json"}
@task(1)
def success(self):
data = {'input.1': random.randint(0, 1000)}
self.client.post(self.url, headers=self.headers,
json=data)
class WebsiteUser(HttpLocust):
task_set = UserBehavior
host = "http://localhost: 8081"
安装 Locust(使用 pip 安装)后,运行
locust
会启动 Locust 服务器并打开 UI,可在其中输入要测试的规模。逐渐增加规模,检查服务器在何时开始出现问题,然后调用管理 API 增加工作线程,确保服务器能够承受负载。
4. TorchScript 实现高效服务
4.1 TorchScript 简介
我们已经搭建了简单的 Flask 应用服务器来提供模型服务,也使用 MXNet 模型服务器实现了相同的模型。但如果要脱离 Python 环境,用 C++ 或 Go 等高效语言构建高性能服务器,PyTorch 推出了 TorchScript。它能将模型转换为 C++ 可读的高效形式。
TorchScript 和使用 ONNX 创建中间表示(IR)的过程类似,但 ONNX 通过跟踪(tracing)创建优化的 IR,即向模型传入一个虚拟输入,在模型执行时记录 PyTorch 操作,再将这些操作转换为中间 IR。然而,这种方法存在问题:如果模型依赖于数据,如 RNN 中的循环或基于输入的 if/else 条件判断,跟踪就无法准确处理。TorchScript 正是为解决此类问题而引入的。
4.2 TorchScript 的优势
- 支持数据依赖操作 :TorchScript 支持一部分 Python 控制流,只需将现有程序转换为 TorchScript 支持的控制流形式即可。它创建的中间阶段可以被 LibTorch 读取。
- 优化模型并可脱离 Python 运行 :TorchScript 可以序列化和优化用 PyTorch 编写的模型,将其保存为 IR 格式。与 ONNX 不同的是,这种 IR 经过优化,适合在生产环境中运行。保存的 TorchScript 模型可以在不依赖 Python 的环境中加载。由于性能和多线程的原因,Python 一直是生产部署的瓶颈,而 TorchScript 解决了这个问题,它引入了基于 C++ 的运行时和高级 API,方便开发者使用 C++ 进行编码。
4.3 创建 TorchScript IR 的方法
PyTorch 提供了两种创建 TorchScript IR 的方法:
-
跟踪(Tracing)
:和 ONNX 类似,将模型(甚至是函数)与虚拟输入一起传递给
torch.jit.trace
。PyTorch 会将虚拟输入通过模型/函数,并在运行时跟踪操作。跟踪得到的函数(PyTorch 操作)可以转换为优化的 IR,也称为静态单赋值 IR。但这种方法存在和 ONNX 相同的问题,无法处理依赖于数据的模型结构变化。
-
脚本模式(Scripting)
:对于普通函数,可以使用
torch.jit.script
装饰器;对于 PyTorch 模型的方法,可以使用
torch.jit.script_method
装饰器。使用
torch.jit.script_method
装饰模型类时,需要从
torch.jit.ScriptModule
继承,以避免使用无法转换为 TorchScript 的纯 Python 方法。目前,TorchScript 虽然不支持所有 Python 特性,但具备支持数据依赖张量操作的必要特性。
4.4 C++ 实现 fizbuz 模型
以下是将 fizbuz 模型导出为 ScriptModule IR 的代码:
net = FizBuzNet(input_size, hidden_size, output_size)
traced = torch.jit.trace(net, dummy_input)
traced.save('fizbuz.pt')
保存的模型可以使用
torch.load()
方法在 Python 中加载,我们将使用 LibTorch 在 C++ 中加载该模型。首先,导入所需的头文件:
#include <torch/script.h>
#include <iostream>
#include <memory>
#include <string>
torch/script.h
是最重要的头文件,它包含了 LibTorch 的所有必要方法和函数。我们通过命令行参数传递模型名称和样本输入,以下是
main
函数的部分代码,用于读取和解析命令行参数:
std::string arg = argv[2];
int x = std::stoi(arg);
float array[10];
4.5 总结
本文介绍了将 PyTorch 模型投入生产的多种方法,包括使用 Netron 可视化模型、MXNet 模型服务器提供服务、Locust 进行负载测试以及 TorchScript 实现高效服务。通过这些方法,我们可以将 PyTorch 模型部署到生产环境中,并根据实际需求进行性能优化和资源管理。具体步骤总结如下表:
| 步骤 | 方法 | 操作 |
| — | — | — |
| 1 | 模型可视化 | 使用 Netron 可视化多种框架的模型 |
| 2 | 模型服务 | 安装 MXNet 模型归档器和服务器,编写服务文件和签名文件,创建 .mar 存档文件并启动模型服务器 |
| 3 | 负载测试 | 使用 Locust 脚本进行负载测试,根据测试结果调整服务器资源 |
| 4 | 高效服务 | 使用 TorchScript 将模型转换为 C++ 可读的高效形式,在 C++ 中加载和运行模型 |
通过这些步骤,我们可以将 PyTorch 模型从开发环境顺利过渡到生产环境,实现高效、稳定的服务。
超级会员免费看
1467

被折叠的 条评论
为什么被折叠?



