Pytorch 测试模型的推理速度

# TODO - 计算模型的推理时间
def calcTime():

    import numpy as np
    from torchvision.models import resnet50
    import torch
    from torch.backends import cudnn
    import tqdm

    '''  导入你的模型
    from module.amsnet import amsnet, anet, msnet, iresnet18, anet2, iresnet2, amsnet2
    from module.resnet import resnet18, resnet34
    from module.alexnet import AlexNet
    from module.vgg import vgg
    from module.lenet import LeNet
    from module.googLenet import GoogLeNet
    from module.ivgg import iVGG
    '''


    cudnn.benchmark = True

    device = 'cuda:0'
    model = anet().to(device)
    repetitions = 1000

    dummy_input = torch.rand(1, 3, 224, 224).to(device)

    # 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(100):
            _ = model(dummy_input)

    # synchronize 等待所有 GPU 任务处理完才返回 CPU 主线程
    torch.cuda.synchronize()

    # 设置用于测量时间的 cuda Event, 这是PyTorch 官方推荐的接口,理论上应该最靠谱
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    # 初始化一个时间容器
    timings = np.zeros((repetitions, 1))

    print('testing ...\n')
    with torch.no_grad():
        for rep in tqdm.tqdm(range(repetitions)):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            torch.cuda.synchronize()  # 等待GPU任务完成
            curr_time = starter.elapsed_time(ender)  # 从 starter 到 ender 之间用时,单位为毫秒
            timings[rep] = curr_time

    avg = timings.sum() / repetitions
    print('\navg={}\n'.format(avg))

### PyTorch 模型推理部署流程 #### 一、模型准备阶段 确保 PyTorch 模型已经过充分训练并能在本地环境中正常运行,完成必要的测试验证工作[^1]。 #### 二、模型转换至ONNX格式 为了提高跨平台兼容性和优化性能,需先将经过验证的 PyTorch 模型导出为 ONNX (Open Neural Network Exchange) 格式的文件。此过程涉及调用 `torch.onnx.export` 函数,并指定输入张量形状和其他参数配置。 ```python import torch dummy_input = torch.randn(1, 3, 224, 224) # 假设图像尺寸为224x224 model.eval() # 切换到评估模式 torch.onnx.export(model, dummy_input, "model.onnx", export_params=True) ``` #### 三、集成ONNX Runtime环境 安装适用于目标操作系统的 ONNX Runtime 库版本,在 Python 中可通过 pip 安装 onnxruntime 包实现快速设置。之后加载之前保存下来的 .onnx 文件作为新的计算图执行对象。 ```python import onnxruntime as ort sess = ort.InferenceSession('model.onnx') outputs = sess.run(None, {'input': input_array}) ``` #### 四、构建Web服务接口 采用 Flask 或 FastAPI 等轻量化框架搭建 RESTful API 接口,用于接收来自前端或其他应用程序发送来的待预测数据请求。当接收到 POST 请求时,解析上传的数据流,预处理成适合喂入神经网络的形式;接着利用已初始化好的 Session 对象来进行前向传播运算得到最终输出结果;最后整理好响应体结构反馈回去给调用方[^2]。 ```python from flask import Flask, request, jsonify app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): file = request.files['image'] img_bytes = file.read() # 图像预处理逻辑... output = model.predict(img_tensor) return jsonify({'prediction': str(output)}) if __name__ == '__main__': app.run(debug=True) ``` #### 五、发布与维护 将上述开发完毕的服务程序打包上线至云服务器或容器化集群内,持续监控其健康状态以及日志记录情况以便及时发现潜在问题并作出相应调整措施。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ray Song

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值