OpenCV转pytorch

OpenCV的一些操作转pytorch,从而有助于使用GPU加速,甚至导出onnx和转TensorRT

需要注意opencv的输入是numpy tensor,format是HW的2D张量或者HWC的3D张量,而pytorch一般是NCHW的4D或者CHW的3D张量。

Dilation腐蚀与膨胀

12: 腐蚀与膨胀 | 陌上见花开

https://blog.51cto.com/u_16175442/8629546

import cv2
import torch.nn.functional as F


def dilate_cv(img, dilate_factor=10):
    """
    input img is np 2D, HWC 3D
    """
    img = img.astype(np.uint8)
    img1 = cv2.dilate(
        img,
        np.ones((dilate_factor, dilate_factor), np.uint8),
        iterations=1
    )
    return img1


def dilate_torch(img, dilate_factor=10):
    """
    input img should be 3D CHW, or 4D NCHW
    """
    h, w = img.shape[-2:]
    img1 = F.max_pool2d(img, kernel_size=dilate_factor, stride=1, padding=dilate_factor//2)
    if dilate_factor % 2 == 0:
        img1 = img1[:, :, :h, :w]
    return img1

贴一个DeepSeek转换的版本:

import torch
import numpy as np

def dilate_torch(img, dilate_factor=10):
    """
    input img is np 2D array
    """
    # 转换为PyTorch张量并添加batch和channel维度
    img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    
    # 创建最大池化层实现膨胀效果
    kernel_size = dilate_factor
    padding = dilate_factor // 2  # 保持输出尺寸与输入一致
    max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=padding)
    
    # 应用池化操作
    dilated_tensor = max_pool(img_tensor)
    
    # 转换回numpy数组并恢复原始维度
    dilated_img = dilated_tensor.squeeze().cpu().numpy()
    
    return dilated_img.astype(np.uint8)

Resize

import cv2
from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode

img_cv = cv2.resize(img_hwc, (scale*W, scale*H), interpolation=cv2.INTER_NEAREST)
img_torch = resize(img_chw, (scale*H, scale*W), interpolation=InterpolationMode.NEAREST)

需要注意的是opencv的resize和torch的resize结果不是完全对齐的,因为align方式的原因。

颜色转换

bgr_cv = cv2.cvtColor(data_np, cv2.COLOR_RGB2BGR)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

def bgr2rgb_torch_nchw(bgr_nchw):
    b, g, r = bgr_nchw.split(split_size=1, dim=-3)
    rgb = torch.cat([r, g, b], dim=-3).numpy()
    return rgb

def rgb2bgr_torch_nchw(rgb_nchw):
    r, g, b = rgb_nchw.split(split_size=1, dim=-3)
    bgr = torch.cat([b, g, r], dim=-3)
    return bgr

Blur

import torch
import numpy as np
import cv2

img_hwc = np.random.rand(*[256, 256, 3]).astype("float32")
img_chw = img_hwc.transpose([2, 0, 1])
img_chw_tc = torch.from_numpy(img_chw)

kernel_size = 3

img_blur_cv = cv2.blur(img_hwc, (kernel_size, kernel_size))
img_blur_cv_chw = img_blur_cv.transpose([2, 0, 1])

def mean_blur_torch(img_chw, kernel_size):
    device = img_chw.device
    dtype = img_chw.dtype

    pad_l = kernel_size // 2
    pad_r = kernel_size // 2
    if kernel_size % 2 == 0:
        pad_r = pad_r-1

    img_chw1 = torch.nn.functional.pad(img_chw, pad=[pad_l, pad_r, pad_l, pad_r], mode='reflect')
    weight = torch.ones(*(3, 1, kernel_size, kernel_size), dtype=dtype, device=device)/kernel_size/kernel_size
    img_blur_chw = torch.nn.functional.conv2d(img_chw1, weight, padding=0, groups=3)
    return img_blur_chw

img_blur_torch_chw = mean_blur_torch(img_chw_tc, kernel_size)
img_blur_torch_chw = img_blur_torch_chw.numpy()

error = np.abs(img_blur_cv_chw - img_blur_torch_chw)
print("error:", np.max(error), np.mean(error))
def to_tensor_torch(tensor):
    # use torch tensor but not numpy as input
    # hwc to chw / nhwc to nchw transpose, dtype conversion, rescale
    if len(tensor.shape) == 2:
        tensor = tensor.unsqueeze(dim=-1)

    # transpose
    if len(tensor.shape) == 3:
        tensor = tensor.permute([2, 0, 1])
    elif len(tensor.shape) == 4:
        tensor = tensor.permute([0, 3, 1, 2])
    else:
        raise ValueError("unsupported")

    is_uint8 = tensor.dtype == torch.uint8
    # dtype conversion
    tensor = tensor.to(torch.float32)
    # rescale
    if is_uint8:
        tensor = tensor / 255.0
    return tensor

<think>好的,用户的问题是如何使用OpenCVPyTorch进行图像预处理,需要方法和示例代码。首先,我需要明确图像预处理通常包括哪些步骤。常见的步骤有尺寸调整、归一化、数据增强(如翻、旋)、颜色空间换等。然后,用户提到了OpenCVPyTorch,所以需要结合这两个工具来处理。 接下来,回忆一下OpenCVPyTorch各自的优势。OpenCV擅长图像处理,比如读取、调整大小、颜色换等,而PyTorch在张量操作和数据增强方面更强大。可能需要先用OpenCV进行基础处理,再换为PyTorch张量,并应用PyTorch的增强方法。 然后,查看用户提供的引用内容,引用[4]提到PyTorch C++和OpenCV数据预处理的问题,其中镜像翻在Python中使用的是random模块和numpy操作。这可能涉及到如何在PyTorch中自定义数据增强,或者如何结合OpenCV函数进行处理。此外,引用[3]和[1]提到将PyTorch模型换为ONNX并用OpenCV加载,这可能涉及到预处理步骤需要与模型训练时的处理一致。 用户需要示例代码,因此需要考虑分步骤的代码示例。例如,先用OpenCV读取图像,调整大小,归一化,然后换为PyTorch张量,再应用其他增强。同时,可能需要展示自定义的数据增强,比如随机翻,可以参考引用[4]中的镜像代码,但需要适应PyTorch的张量格式。 还要注意数据预处理流程中的顺序问题,比如先调整大小,再裁剪,然后归一化。此外,颜色通道的换,因为OpenCV默认是BGR,而PyTorch通常需要RGB,可能需要换颜色空间。 另外,归一化的处理在PyTorch中通常使用transforms.Normalize,需要计算均值和标准差。可能需要提醒用户根据自己数据集的情况调整这些参数。 最后,根据用户提供的引用,可能需要提到如何确保预处理与模型训练时一致,特别是在部署模型时,如引用[3]所述,将模型换为ONNX并用OpenCV加载时,预处理步骤需要保持一致,否则会影响模型性能。 综合以上思考,回答的结构应该分步骤,结合OpenCVPyTorch的处理,提供代码示例,并引用相关的参考资料。同时,生成的问题需要围绕预处理步骤、数据增强方法、处理差异等展开,帮助用户进一步了解相关主题。</think>### 使用OpenCVPyTorch进行图像预处理的方法 图像预处理是深度学习模型训练和部署的关键步骤,结合OpenCVPyTorch可实现高效的图像处理流程。以下是具体方法及示例代码: --- #### **1. OpenCV基础预处理** OpenCV适合完成图像读取、尺寸调整、颜色空间换等操作: ```python import cv2 import numpy as np # 读取图像(BGR格式) image = cv2.imread("image.jpg") # 调整尺寸为224x224 resized = cv2.resize(image, (224, 224)) # 换为RGB格式(PyTorch默认使用RGB) rgb_image = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) # 归一化到[0,1]范围并换为浮点型 normalized = rgb_image.astype(np.float32) / 255.0 ``` --- #### **2. PyTorch张量换与增强** 使用PyTorch的`transforms`模块进行数据增强: ```python import torch from torchvision import transforms # 定义预处理流水线 transform = transforms.Compose([ transforms.ToTensor(), # 将numpy数组换为PyTorch张量(自动归一化到[0,1]) transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet标准化 transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻 ]) # 应用预处理 tensor_image = transform(normalized) ``` --- #### **3. 自定义OpenCVPyTorch混合处理** 对于复杂需求(如仿射变换),可结合两者: ```python # 使用OpenCV实现随机裁剪 def random_crop(image, size=(200, 200)): h, w = image.shape[:2] x = np.random.randint(0, w - size[0]) y = np.random.randint(0, h - size[1]) return image[y:y+size[1], x:x+size[0]] cropped = random_crop(rgb_image) tensor_cropped = transform(cropped) ``` --- #### **注意事项** - **颜色通道一致性**:OpenCV默认使用BGR格式,而PyTorch模型通常需要RGB输入,需通过`cv2.cvtColor`换[^3]。 - **归一化参数**:`transforms.Normalize`的均值和标准差需与模型训练时使用的参数一致[^1]。 - **数据增强同步**:若在训练和推理时使用不同框架(如部署时用OpenCV),需确保预处理逻辑完全一致[^2]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Luchang-Li

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值