19、将 PyTorch 模型投入生产

将 PyTorch 模型投入生产

1. Netron 模型可视化工具

Netron 最初是一个 ONNX 可视化工具,现在它已经可以接受所有流行框架导出的模型。目前,根据官方文档,Netron 支持的模型格式包括:ONNX、Keras、CoreML、Caffe2、MXNet、TensorFlow Lite、TensorFlow.js、TensorFlow、Caffe、PyTorch、Torch、CNTK、PaddlePaddle、Darknet 和 scikit - learn。

2. MXNet 模型服务器

2.1 选择 MXNet 模型服务器的原因

与其他服务模块相比,MXNet 表现更优。在编写本文时,TensorFlow 与 Python 3.7 不兼容,而 MXNet 的服务模块与内置的 ONNX 模型集成,这使得开发人员只需几行命令就能轻松部署模型,无需学习分布式或高度可扩展部署的复杂性。

其他模型服务器,如 TensorRT 和 Clipper,在设置和管理方面不如 MXNet 服务器简单。此外,MXNet 还附带了一个名为 MXNet archiver 的实用工具,它可以将所有必需的文件打包成一个独立的包,无需担心其他依赖项。MXNet 模型服务器的最大优点是能够自定义预处理和后处理步骤。

2.2 整体流程

整个过程的流程如下:

graph LR
    A[创建 .mar 格式的存档文件] --> B[调用模型服务器并传入存档文件位置]
    B --> C[模型由高性能服务器提供服务]

具体步骤如下:
1. 创建 .mar 存档文件 :使用模型归档器创建一个 .mar 格式的单个存档文件。这个文件需要 ONNX 模型文件、 signature.json (提供输入大小、名称等信息,可视为配置文件,可随时更改,也可将值硬编码到代码中而不使用该文件)以及服务文件(定义预处理、推理函数、后处理函数和其他实用函数)。
2. 启动模型服务器 :调用模型服务器并将模型存档的位置作为输入传递。

2.3 安装 MXNet 模型归档器和服务器

2.3.1 安装 MXNet 模型归档器

默认随 MXNet 模型服务器附带的模型归档器不支持 ONNX,因此需要单独安装。ONNX 模型归档器依赖于协议缓冲区和 MXNet 包本身。每个操作系统的 protobuf 编译器安装指南可在官方文档中找到。MXNet 包可以通过 pip 安装:

pip install mxnet
pip install model-archiver[onnx]
2.3.2 安装 MXNet 模型服务器

MXNet 模型服务器基于 Java 虚拟机(JVM)构建,因此可以从 JVM 调用多个运行模型实例的线程。在 JVM 的支持下,MXNet 服务器可以扩展到多个进程来处理数千个请求。由于模型服务器在 JVM 上运行,需要安装 Java 8。此外,MXNet 模型服务器在 Windows 上仍处于实验模式,但在 Linux 和 Mac 上是稳定的。

pip install mxnet-model-server

2.4 编写生产就绪的 PyTorch 模型代码

2.4.1 创建目录并移动文件

安装完所有先决条件后,我们可以开始使用 MXNet 模型服务器编写生产就绪的 PyTorch 模型代码。首先,创建一个新目录来保存模型归档器所需的所有文件,然后将上一步生成的 .onnx 文件移动到该目录。

2.4.2 编写服务文件

MMS 要求必须有一个包含服务类的服务文件,MMS 会执行该文件中唯一类的 initialize() handle() 函数。以下是服务文件的骨架:

class MXNetModelService(object):
    def __init__(self):
        ...
    def initialize(self, context):
        ...
    def preprocess(self, batch):
        ...
    def inference(self, model_input):
        ...
    def postprocess(self, inference_output):
        ...
    def handle(self, data, context):
        ...
2.4.3 编写签名文件

签名文件是一个配置文件,用于描述数据名称、输入形状和输入类型。以下是 fizbuz 网络的最小签名文件示例:

{
  "inputs": [
    {
      "data_name": "input.1",
      "data_shape": [
        1,
        10
      ]
    }
  ],
  "input_type": "application/json"
}
2.4.4 实现 initialize() 函数
def initialize(self, context):
    properties = context.system_properties
    model_dir = properties.get("model_dir")
    gpu_id = properties.get("gpu_id")
    self._batch_size = properties.get('batch_size')
    signature_file_path = os.path.join(
        model_dir, "signature.json")
    if not os.path.isfile(signature_file_path):
        raise RuntimeError("Missing signature.json file.")
    with open(signature_file_path) as f:
        self.signature = json.load(f)
    data_names = []
    data_shapes = []
    input_data = self.signature["inputs"][0]
    data_name = input_data["data_name"]
    data_shape = input_data["data_shape"]
    data_shape[0] = self._batch_size
    data_names.append(data_name)
    data_shapes.append((data_name, tuple(data_shape)))
    self.mxnet_ctx = mx.cpu() if gpu_id is None else 
        mx.gpu(gpu_id)
    sym, arg_params, aux_params = mx.model.load_checkpoint( 
        checkpoint_prefix, self.epoch)
    self.mx_model = mx.mod.Module(
        symbol=sym, context=self.mxnet_ctx,
        data_names=data_names, label_names=None)
    self.mx_model.bind(
        for_training=False, data_shapes=data_shapes)
    self.mx_model.set_params(
        arg_params, aux_params,
        allow_missing=True, allow_extra=True)
    self.has_initialized = True
2.4.5 实现 handle() 函数
def handle(self, data, context):
    try:
        if not self.has_initialized:
            self.initialize()
        preprocess_start = time.time()
        data = self.preprocess(data)
        inference_start = time.time()
        data = self.inference(data)
        postprocess_start = time.time()
        data = self.postprocess(data)
        end_time = time.time()
        metrics = context.metrics
        metrics.add_time(self.add_first())
        metrics.add_time(self.add_second())
        metrics.add_time(self.add_third())
        return data
    except Exception as e:
        request_processor = context.request_processor
        request_processor.report_status(
            500, "Unknown inference error")
        return [str(e)] * self._batch_size
2.4.6 实现其他函数
  • preprocess() 函数
def preprocess(self, batch):
    param_name = self.signature['inputs'][0]['data_name']
    data = batch[0].get('body').get(param_name)
    if data:
        self.input = data + 1
        tensor = mx.nd.array(
            [self.binary_encoder(self.input, input_size=10)])
        return tensor
    self.error = 'InvalidData'
  • inference() 函数
def inference(self, model_input):
    if self.error is not None:
        return None
    self.mx_model.forward(DataBatch([model_input]))
    model_output = self.mx_model.get_outputs()
    return model_output
  • postprocess() 函数
def postprocess(self, inference_output):
    if self.error is not None:
        return [self.error] * self._batch_size
    prediction = self.get_readable_output(
        self.input,
        int(inference_output[0].argmax(1).asscalar()))
    out = [{'next_number': prediction}]
    return out

2.5 创建 .mar 存档文件并启动模型服务器

model-archiver \
        --model-name fizbuz_package \
        --model-path fizbuz_package \
        --handler fizbuz_service -f
mxnet-model-server \
        --start \
        --model-store FizBuz_with_ONNX \
        --models fizbuz_package.mar

现在模型服务器在端口 8080 上运行(如果未更改端口)。可以使用与 Flask 应用程序相同的 curl 命令(当然要更改端口号)来测试模型,结果应与 Flask 应用程序相同,但现在可以根据需要动态调整工作线程的数量。

2.6 管理 API 操作

模型服务器运行时,会在端口 8081 上提供一个管理 API 服务,可通过该服务控制配置。例如,将工作线程数量设置为 1:

curl -v -X PUT  "http://localhost:8081/models/fizbuz_package?max_worker=1"

检查服务器状态:

curl "http://localhost:8081/models/fizbuz_package"

3. 负载测试

可以使用 Locust 进行负载测试。以下是一个示例 Locust 脚本,应保存为 locust.py

import random
from locust import HttpLocust, TaskSet, task

class UserBehavior(TaskSet):
    def on_start(self):
        self.url = "/predictions/fizbuz_package"
        self.headers = {"Content-Type": "application/json"}

    @task(1)
    def success(self):
        data = {'input.1': random.randint(0, 1000)}
        self.client.post(self.url, headers=self.headers,  
                         json=data)

class WebsiteUser(HttpLocust):
    task_set = UserBehavior
    host = "http://localhost: 8081"

安装 Locust(使用 pip 安装)后,运行 locust 会启动 Locust 服务器并打开 UI,可在其中输入要测试的规模。逐渐增加规模,检查服务器在何时开始出现问题,然后调用管理 API 增加工作线程,确保服务器能够承受负载。

4. TorchScript 实现高效服务

4.1 TorchScript 简介

我们已经搭建了简单的 Flask 应用服务器来提供模型服务,也使用 MXNet 模型服务器实现了相同的模型。但如果要脱离 Python 环境,用 C++ 或 Go 等高效语言构建高性能服务器,PyTorch 推出了 TorchScript。它能将模型转换为 C++ 可读的高效形式。

TorchScript 和使用 ONNX 创建中间表示(IR)的过程类似,但 ONNX 通过跟踪(tracing)创建优化的 IR,即向模型传入一个虚拟输入,在模型执行时记录 PyTorch 操作,再将这些操作转换为中间 IR。然而,这种方法存在问题:如果模型依赖于数据,如 RNN 中的循环或基于输入的 if/else 条件判断,跟踪就无法准确处理。TorchScript 正是为解决此类问题而引入的。

4.2 TorchScript 的优势

  • 支持数据依赖操作 :TorchScript 支持一部分 Python 控制流,只需将现有程序转换为 TorchScript 支持的控制流形式即可。它创建的中间阶段可以被 LibTorch 读取。
  • 优化模型并可脱离 Python 运行 :TorchScript 可以序列化和优化用 PyTorch 编写的模型,将其保存为 IR 格式。与 ONNX 不同的是,这种 IR 经过优化,适合在生产环境中运行。保存的 TorchScript 模型可以在不依赖 Python 的环境中加载。由于性能和多线程的原因,Python 一直是生产部署的瓶颈,而 TorchScript 解决了这个问题,它引入了基于 C++ 的运行时和高级 API,方便开发者使用 C++ 进行编码。

4.3 创建 TorchScript IR 的方法

PyTorch 提供了两种创建 TorchScript IR 的方法:
- 跟踪(Tracing) :和 ONNX 类似,将模型(甚至是函数)与虚拟输入一起传递给 torch.jit.trace 。PyTorch 会将虚拟输入通过模型/函数,并在运行时跟踪操作。跟踪得到的函数(PyTorch 操作)可以转换为优化的 IR,也称为静态单赋值 IR。但这种方法存在和 ONNX 相同的问题,无法处理依赖于数据的模型结构变化。
- 脚本模式(Scripting) :对于普通函数,可以使用 torch.jit.script 装饰器;对于 PyTorch 模型的方法,可以使用 torch.jit.script_method 装饰器。使用 torch.jit.script_method 装饰模型类时,需要从 torch.jit.ScriptModule 继承,以避免使用无法转换为 TorchScript 的纯 Python 方法。目前,TorchScript 虽然不支持所有 Python 特性,但具备支持数据依赖张量操作的必要特性。

4.4 C++ 实现 fizbuz 模型

以下是将 fizbuz 模型导出为 ScriptModule IR 的代码:

net = FizBuzNet(input_size, hidden_size, output_size)
traced = torch.jit.trace(net, dummy_input)
traced.save('fizbuz.pt')

保存的模型可以使用 torch.load() 方法在 Python 中加载,我们将使用 LibTorch 在 C++ 中加载该模型。首先,导入所需的头文件:

#include <torch/script.h>
#include <iostream>
#include <memory>
#include <string>

torch/script.h 是最重要的头文件,它包含了 LibTorch 的所有必要方法和函数。我们通过命令行参数传递模型名称和样本输入,以下是 main 函数的部分代码,用于读取和解析命令行参数:

std::string arg = argv[2];
int x = std::stoi(arg);
float array[10];

4.5 总结

本文介绍了将 PyTorch 模型投入生产的多种方法,包括使用 Netron 可视化模型、MXNet 模型服务器提供服务、Locust 进行负载测试以及 TorchScript 实现高效服务。通过这些方法,我们可以将 PyTorch 模型部署到生产环境中,并根据实际需求进行性能优化和资源管理。具体步骤总结如下表:
| 步骤 | 方法 | 操作 |
| — | — | — |
| 1 | 模型可视化 | 使用 Netron 可视化多种框架的模型 |
| 2 | 模型服务 | 安装 MXNet 模型归档器和服务器,编写服务文件和签名文件,创建 .mar 存档文件并启动模型服务器 |
| 3 | 负载测试 | 使用 Locust 脚本进行负载测试,根据测试结果调整服务器资源 |
| 4 | 高效服务 | 使用 TorchScript 将模型转换为 C++ 可读的高效形式,在 C++ 中加载和运行模型 |

通过这些步骤,我们可以将 PyTorch 模型从开发环境顺利过渡到生产环境,实现高效、稳定的服务。

基于可靠性评估序贯蒙特卡洛模拟法的配电网可靠性评估研究(Matlab代码实现)内容概要:本文围绕“基于可靠性评估序贯蒙特卡洛模拟法的配电网可靠性评估研究”,介绍了利用Matlab代码实现配电网可靠性的仿真分析方法。重点采用序贯蒙特卡洛模拟法对配电网进行长时间段的状态抽样与统计,通过模拟系统元件的故障与修复过程,评估配电网的关键可靠性指标,如系统停电频率、停电持续时间、负荷点可靠性等。该方法能够有效处理复杂网络结构与设备时序特性,提升评估精度,适用于含分布式电源、电动汽车等新型负荷接入的现代配电网。文中提供了完整的Matlab实现代码与案例分析,便于复现和扩展应用。; 适合人群:具备电力系统基础知识和Matlab编程能力的高校研究生、科研人员及电力行业技术人员,尤其适合从事配电网规划、运行与可靠性分析相关工作的人员; 使用场景及目标:①掌握序贯蒙特卡洛模拟法在电力系统可靠性评估中的基本原理与实现流程;②学习如何通过Matlab构建配电网仿真模型并进行状态转移模拟;③应用于含新能源接入的复杂配电网可靠性定量评估与优化设计; 阅读建议:建议结合文中提供的Matlab代码逐段调试运行,理解状态抽样、故障判断、修复逻辑及指标统计的具体实现方式,同时可扩展至不同网络结构或加入更多不确定性因素进行深化研究。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值