# PVT图像二分类推理v1.0 (PVT binary image-classification inference, v1.0)

from PIL import Image
import cv2
import numpy as np
import onnxruntime as ort
import torchvision.transforms as transforms


# Default per-channel normalization statistics.
# These are the mean and standard deviation computed over the ImageNet dataset
# (RGB channel order). They are the standard values used with models pretrained
# on ImageNet. Stored as tuples so the shared module-level constants cannot be
# mutated by accident.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def build_transform(is_train, args):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Args:
        is_train: if truthy, return the random-augmentation training pipeline;
            otherwise return the deterministic evaluation pipeline.
        args: any object exposing an integer ``input_size`` attribute, the side
            length in pixels of the square crop fed to the model.

    Returns:
        A ``transforms.Compose`` that turns a PIL image into a tensor normalized
        with the ImageNet mean/std.
    """
    side = args.input_size
    # Sizes above 32 px use resize-based cropping; smaller sizes are handled
    # with a padded random crop (training) or no resizing at all (evaluation).
    large_input = side > 32
    normalize = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    if is_train:
        if large_input:
            crop = transforms.RandomResizedCrop(side)
        else:
            crop = transforms.RandomCrop(side, padding=4)
        return transforms.Compose([
            crop,
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    steps = []
    if large_input:
        # Resize to 256/224 of the crop size, then center-crop back to `side`:
        # the same resize-to-crop ratio used for 224-px ImageNet evaluation.
        steps += [
            transforms.Resize(
                int((256 / 224) * side),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(side),
        ]
    steps += [transforms.ToTensor(), normalize]
    return transforms.Compose(steps)


def preprocess_image(image_path, target_size=(224, 224), is_train=False, args=None):
    """Load an image from disk and turn it into a batched model input.

    The image is read with OpenCV, converted from BGR to RGB, and wrapped in a
    PIL image, because the torchvision transforms used here reject NumPy arrays.
    It is then run through ``build_transform`` and returned as a NumPy array
    with a leading batch axis of size 1.

    Args:
        image_path: path of the image file to read.
        target_size: fallback model input size, used only when ``args`` is None.
            The pipeline produces square crops, so only ``target_size[0]`` is used.
        is_train: forwarded to ``build_transform``; False selects the
            deterministic evaluation pipeline.
        args: object with an ``input_size`` attribute. When None, one is built
            from ``target_size``.

    Returns:
        ``np.float32`` array of shape ``(1, 3, H, W)``.

    Raises:
        OSError: if OpenCV cannot read or decode ``image_path``.
    """
    if args is None:
        import types

        # build_transform only needs ``input_size``; its crops are square.
        args = types.SimpleNamespace(input_size=target_size[0])

    # cv2.imread returns None (it does not raise) when the file is missing or
    # cannot be decoded. Fail here with a clear message instead of inside cvtColor.
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        raise OSError(f"Could not read image: {image_path}")

    # OpenCV loads BGR, while the ImageNet normalization statistics are in RGB order.
    # The transforms need a PIL image, not a NumPy array.
    pil_image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    tensor = build_transform(is_train, args)(pil_image)

    # (C, H, W) tensor -> float32 NumPy array with a leading batch axis.
    return np.expand_dims(np.asarray(tensor, dtype=np.float32), axis=0)


def run_inference(onnx_model_path, image_path, is_train=False, args=None, *, providers=None):
    """Classify one image with an ONNX model and return the predicted class index.

    Args:
        onnx_model_path: path to the exported ``.onnx`` model.
        image_path: path of the image to classify.
        is_train: forwarded to ``preprocess_image``. Keep False for inference so
            the deterministic evaluation transform is used.
        args: forwarded to ``preprocess_image``; needs an ``input_size`` attribute.
        providers: ONNX Runtime execution providers in priority order. Defaults
            to every provider available in the installed onnxruntime build.

    Returns:
        Index of the largest value in the model's first output, as a Python int.
    """
    if providers is None:
        # Since ONNX Runtime 1.9, builds that ship non-CPU providers (e.g.
        # onnxruntime-gpu) raise ValueError unless ``providers`` is given.
        # Passing all available providers restores the pre-1.9 default.
        providers = ort.get_available_providers()
    session = ort.InferenceSession(onnx_model_path, providers=providers)

    # The model is fed through its first declared input.
    input_name = session.get_inputs()[0].name

    batch = preprocess_image(image_path, is_train=is_train, args=args)
    outputs = session.run(None, {input_name: batch})

    # argmax over the flattened first output. With a single-image batch this
    # is the class index.
    return int(np.argmax(outputs[0]))

if __name__ == "__main__":
    import argparse
    import types

    parser = argparse.ArgumentParser(
        description="Run PVT image classification on one image with ONNX Runtime."
    )
    parser.add_argument(
        "--model",
        default='checkpoints/pvt_v2_b5/checkpoint_0.onnx',
        help="path to the ONNX model",
    )
    parser.add_argument(
        "--image",
        default='/home/nvidia/aigc/classify/PVT/classification/AIGC/val/aigc_0/0002.png',
        help="path to the input image",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        default=224,
        help="side length of the square model input",
    )
    cli = parser.parse_args()

    # build_transform reads only ``input_size`` from this object; no
    # training-augmentation settings are needed for evaluation.
    args = types.SimpleNamespace(input_size=cli.input_size)

    prediction = run_inference(cli.model, cli.image, is_train=False, args=args)

    print(f"The predicted class is: {prediction}")


"""
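[BUG1] (注: 以下 traceback 中的行号对应脚本的早期版本 / line numbers below refer to an earlier revision of this script; the fix is the Image.fromarray call in preprocess_image)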
  File "infer.py", line 101, in <module>
    prediction = run_inference(onnx_model_path, image_path, is_train=False, args=args)
  File "infer.py", line 74, in run_inference
    preprocessed_image = preprocess_image(image_path, is_train=is_train, args=args)
  File "infer.py", line 52, in preprocess_image
    image = transform(image)
报错:    raise TypeError(f"img should be PIL Image. Got {type(img)}")
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

[debug]
问题在于 preprocess_image 函数中的图像数据类型不匹配。
torchvision.transforms 中的变换函数期望输入的是一个 PIL Image 对象,而您的代码中使用 cv2.imread 加载的图像是一个 NumPy 数组。

为了修复这个问题,您需要将 NumPy 数组转换为 PIL Image 对象,然后再应用变换。您可以使用 PIL.Image.fromarray 方法来实现这一转换。

"""
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

之群害马

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值