Pytorch 测试模型的推理速度

该代码段展示了如何使用PyTorch测量不同模型(如AlexNet,VGG,ResNet等)在GPU上的推理时间。首先,它进行了预热以确保GPU达到最佳状态,然后利用cudaEvent进行精确的时间测量,以计算平均推理时间。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# TODO - 计算模型的推理时间
def calcTime():

    import numpy as np
    from torchvision.models import resnet50
    import torch
    from torch.backends import cudnn
    import tqdm

    '''  导入你的模型
    from module.amsnet import amsnet, anet, msnet, iresnet18, anet2, iresnet2, amsnet2
    from module.resnet import resnet18, resnet34
    from module.alexnet import AlexNet
    from module.vgg import vgg
    from module.lenet import LeNet
    from module.googLenet import GoogLeNet
    from module.ivgg import iVGG
    '''


    cudnn.benchmark = True

    device = 'cuda:0'
    model = anet().to(device)
    repetitions = 1000

    dummy_input = torch.rand(1, 3, 224, 224).to(device)

    # 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
    print('warm up ...\n')
    with torch.no_grad():
        for _ in range(100):
            _ = model(dummy_input)

    # synchronize 等待所有 GPU 任务处理完才返回 CPU 主线程
    torch.cuda.synchronize()

    # 设置用于测量时间的 cuda Event, 这是PyTorch 官方推荐的接口,理论上应该最靠谱
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    # 初始化一个时间容器
    timings = np.zeros((repetitions, 1))

    print('testing ...\n')
    with torch.no_grad():
        for rep in tqdm.tqdm(range(repetitions)):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            torch.cuda.synchronize()  # 等待GPU任务完成
            curr_time = starter.elapsed_time(ender)  # 从 starter 到 ender 之间用时,单位为毫秒
            timings[rep] = curr_time

    avg = timings.sum() / repetitions
    print('\navg={}\n'.format(avg))

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ray Song

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值