torchvision.transforms.Resize() explained

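This article explains a transform for resizing PIL Image objects. It scales an image to a specified size, can scale proportionally so the aspect ratio is preserved, and offers several interpolation options that control the quality of the resized image.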


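  • What it does
    Resizes a PIL Image object. A sequence size (h, w) sets the exact output height and width; an int size is matched by the shorter edge, and the longer edge is scaled so that the aspect ratio is kept.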
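  • Source code (from an older torchvision release; newer releases also accept tensors and take interpolation as an InterpolationMode value)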
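```python
from collections.abc import Iterable

from PIL import Image
from torchvision.transforms import functional as F

# Module-level lookup used by __repr__; torchvision's transforms.py
# defines an equivalent table next to Resize.
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # size must be a single int or an (h, w) pair
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        # The actual work is delegated to torchvision.transforms.functional.resize
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # e.g. "Resize(size=(50, 50), interpolation=PIL.Image.BILINEAR)"
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
```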
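The size rule in the docstring can be made concrete with a small pure-Python helper. This is a sketch that assumes the output-size logic of the older torchvision.transforms.functional.resize: an int size is matched by the shorter edge, the longer edge is truncated with int(), and a sequence is read as (h, w). The name output_size is hypothetical and not part of torchvision.

```python
def output_size(img_w, img_h, size):
    """Hypothetical helper: return the (h, w) that Resize(size) would produce
    for an image of width img_w and height img_h (older torchvision logic)."""
    if isinstance(size, int):
        # Match the shorter edge to `size`; scale the longer edge by the
        # same factor (truncated with int()) so the aspect ratio is kept.
        if img_w <= img_h:
            return int(size * img_h / img_w), size
        return size, int(size * img_w / img_h)
    # A sequence is taken as (h, w) and used as-is, so the aspect ratio may
    # change. PIL's Image.resize expects (w, h), which is why the old
    # F.resize reversed the pair (size[::-1]) before calling it.
    out_h, out_w = size
    return out_h, out_w


if __name__ == "__main__":
    print(output_size(640, 480, 256))       # landscape -> (256, 341)
    print(output_size(480, 640, 256))       # portrait  -> (341, 256)
    print(output_size(640, 480, (50, 50)))  # fixed     -> (50, 50)
```

With an int size the result depends on the image's orientation, whereas a tuple such as (50, 50) always yields the same shape, which is what batching usually requires.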

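The interpolation argument decides how the new pixel values are computed. The sketch below contrasts textbook nearest-neighbour and bilinear resampling on a tiny grayscale grid (a list of lists). It assumes the half-pixel-centre coordinate mapping and only illustrates the idea: Pillow's own NEAREST and BILINEAR filters differ in details (for instance, Pillow's bilinear filter widens its support when downscaling, which acts as antialiasing), so the numbers are not guaranteed to match Image.resize. The function names are hypothetical.

```python
def _src_coord(dst, in_len, out_len):
    # Map the centre of output pixel `dst` back to input coordinates
    # (half-pixel convention), clamped to the valid range.
    src = (dst + 0.5) * in_len / out_len - 0.5
    return min(max(src, 0.0), float(in_len - 1))


def nearest_resize(grid, out_h, out_w):
    """Textbook nearest-neighbour resize: copy the closest input pixel."""
    in_h, in_w = len(grid), len(grid[0])
    out = []
    for y in range(out_h):
        sy = min(int((y + 0.5) * in_h / out_h), in_h - 1)
        out.append([grid[sy][min(int((x + 0.5) * in_w / out_w), in_w - 1)]
                    for x in range(out_w)])
    return out


def bilinear_resize(grid, out_h, out_w):
    """Textbook bilinear resize: weighted average of the 4 nearest inputs."""
    in_h, in_w = len(grid), len(grid[0])
    out = []
    for y in range(out_h):
        sy = _src_coord(y, in_h, out_h)
        y0 = int(sy)
        y1 = min(y0 + 1, in_h - 1)
        fy = sy - y0
        row = []
        for x in range(out_w):
            sx = _src_coord(x, in_w, out_w)
            x0 = int(sx)
            x1 = min(x0 + 1, in_w - 1)
            fx = sx - x0
            # Interpolate horizontally on two rows, then vertically between them.
            top = grid[y0][x0] * (1 - fx) + grid[y0][x1] * fx
            bottom = grid[y1][x0] * (1 - fx) + grid[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out


if __name__ == "__main__":
    img = [[0, 10]]                     # a 1x2 "image"
    print(nearest_resize(img, 1, 4))    # [[0, 0, 10, 10]]
    print(bilinear_resize(img, 1, 4))   # [[0.0, 2.5, 7.5, 10.0]]
```

Nearest-neighbour produces blocky steps, while bilinear blends neighbouring values into a smooth ramp; this is the kind of quality difference the interpolation parameter controls.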