PyTorch学习笔记-8.PyTorch深度体验

8.PyTorch深度体验

8.1.图像分类预测

模型如何完成图像分类?

将图像转换为tensor --> 模型 --> 输出向量 --> 取向量的最大值作为预测结果

 

代码基本步骤:
1. 获取数据与模型
2. 数据变换,如RGB → 4D-Tensor
3. 前向传播
4. 输出保存预测结果

注意事项:
1. 确保 model处于eval状态而非training
2. 设置torch.no_grad(),减少内存消耗
3. 数据预处理需保持一致, RGB or BGR

 

代码实现:

以模型微调中的分类模型为例,进行预测

# -*- coding: utf-8 -*-
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# 配置可视化开关
vis = True
# vis = False
# 设置可视化时每行有几张图片
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# 对数据预处理,注意和模型的预处理方式保持一致
inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 标签的分类
classes = ["ants""bees"]

# 将图片进行预处理得到张量,即数据转换为模型读取的形式
def img_transform(img_rgb, transform=None):

    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")

    img_t = transform(img_rgb)
    return img_t

# 获取文件夹下format格式的文件名
def get_img_name(img_dir, format="jpg"):

    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):

    # 创建resnet18模型
    resnet18 = models.resnet18()
    # 获取全连接层的输入参数
    num_ftrs = resnet18.fc.in_features
    # 自定义全连接层,设置输入为2,即二分类
    resnet18.fc = nn.Linear(num_ftrs, 2)

    # 根据模型路径加载模型,为检查点
    checkpoint = torch.load(m_path)
    # 将参数加载到模型中
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    # 打印模型的信息,如每一层的类型、shape 和 参数量等
    if vis_model:
        # 需要导入torchsummary包,安装:pip install torchsummary
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":
    # 指定数据路径
    img_dir = os.path.join("..""..""data/hymenoptera_data/val/ants")
    # 指定模型路径,为蚂蚁蜜蜂之前保存的检查点
    model_path = "./checkpoint_24_epoch.pkl"
    # 用来统计总预测时间
    time_total = 0
    # 定义两个list,分别用来存放图片和对应的预测类型
    img_list, img_pred = list(), list()

    # 1. data
    # 获取指定路径下所有文件名
    img_names = get_img_name(img_dir)
    # 获取文件的数量
    num_img = len(img_names)

    # 2. model
    # 获取模型
    resnet18 = get_model(model_path, True)
    # 将模型放到指定设备上
    resnet18.to(device)
    # 将模型设置为测试模式
    resnet18.eval()

    # 下面的所有运算无需保存梯度
    with torch.no_grad():
        # 遍历文件名
        for idx, img_name in enumerate(img_names):

            # 拼接每个文件的全路径
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img 根据路径读取rgb图片
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor 将rgb图像转换为张量
            img_tensor = img_transform(img_rgb, inference_transform)
            # 3d --> 4d
            img_tensor.unsqueeze_(0)
            # 将数据放到指定设备上
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            # 统计运行时间
            time_tic = time.time()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值