OpenCV的一些操作转pytorch,从而有助于使用GPU加速,甚至导出onnx和转TensorRT
需要注意opencv的输入是numpy tensor,format是HW的2D张量或者HWC的3D张量,而pytorch一般是NCHW的4D或者CHW的3D张量。
Dilation 膨胀
https://blog.51cto.com/u_16175442/8629546
import cv2
import torch.nn.functional as F
def dilate_cv(img, dilate_factor=10):
    """Dilate a mask with OpenCV using a square structuring element.

    Args:
        img: numpy array, 2D (HW) or 3D (HWC).
        dilate_factor: side length of the square all-ones kernel.

    Returns:
        uint8 numpy array of the same shape, dilated once.
    """
    kernel = np.ones((dilate_factor, dilate_factor), dtype=np.uint8)
    return cv2.dilate(img.astype(np.uint8), kernel, iterations=1)
def dilate_torch(img, dilate_factor=10):
    """Morphological dilation via max-pooling (GPU/ONNX friendly).

    Max-pooling with a square kernel equals dilation with a square
    all-ones structuring element, mirroring cv2.dilate above.

    Args:
        img: float tensor, 3D CHW or 4D NCHW.
        dilate_factor: side length of the square structuring element.

    Returns:
        Dilated tensor with the same spatial size (H, W) as the input.
    """
    h, w = img.shape[-2:]
    img1 = F.max_pool2d(img, kernel_size=dilate_factor, stride=1,
                        padding=dilate_factor // 2)
    if dilate_factor % 2 == 0:
        # An even kernel with padding k//2 grows each spatial dim by one
        # pixel; crop back to the input size. Use `...` so 3D CHW inputs
        # work too (the original `[:, :, :h, :w]` raised IndexError on 3D).
        img1 = img1[..., :h, :w]
    return img1
贴一个DeepSeek转换的版本:
import torch
import numpy as np
def dilate_torch(img, dilate_factor=10):
    """Dilate a 2D numpy mask via PyTorch max-pooling.

    Args:
        img: 2D numpy array (HW).
        dilate_factor: side length of the square structuring element.

    Returns:
        uint8 numpy array with the same shape as the input.
    """
    # To NCHW float tensor (add batch and channel dims).
    img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    # Max-pool with a square kernel implements dilation.
    padding = dilate_factor // 2
    dilated = torch.nn.functional.max_pool2d(
        img_tensor, kernel_size=dilate_factor, stride=1, padding=padding
    )
    if dilate_factor % 2 == 0:
        # Bug fix: an even kernel with padding k//2 yields one extra pixel
        # per spatial dim, so the output no longer matched the input size.
        # Crop back to (H, W).
        h, w = img.shape[-2:]
        dilated = dilated[..., :h, :w]
    # Back to a 2D numpy array with the original dtype convention.
    dilated_img = dilated.squeeze().cpu().numpy()
    return dilated_img.astype(np.uint8)
Resize
import cv2
from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode
# NOTE: illustrative snippet — img_hwc, img_chw, scale, H, W are assumed defined.
# cv2.resize takes the target size as (width, height) on an HWC numpy array,
img_cv = cv2.resize(img_hwc, (scale*W, scale*H), interpolation=cv2.INTER_NEAREST)
# while torchvision's resize takes (height, width) on a CHW tensor.
img_torch = resize(img_chw, (scale*H, scale*W), interpolation=InterpolationMode.NEAREST)
需要注意的是opencv的resize和torch的resize结果不是完全对齐的,因为align方式的原因。
颜色转换
# OpenCV channel-order conversion on HWC numpy arrays (RGB <-> BGR swap).
bgr_cv = cv2.cvtColor(data_np, cv2.COLOR_RGB2BGR)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
def bgr2rgb_torch_nchw(bgr_nchw):
    """Swap BGR -> RGB on the channel dim of a CHW/NCHW torch tensor.

    Fix: the original called .numpy() on the result, which returned a
    numpy array (inconsistent with rgb2bgr_torch_nchw below) and fails
    for CUDA / autograd tensors, defeating the GPU/ONNX purpose.

    Args:
        bgr_nchw: tensor with exactly 3 channels at dim -3.

    Returns:
        Tensor of the same shape with channels reordered to RGB.
    """
    b, g, r = bgr_nchw.split(split_size=1, dim=-3)
    return torch.cat([r, g, b], dim=-3)
def rgb2bgr_torch_nchw(rgb_nchw):
    """Swap RGB -> BGR on the channel dim of a CHW/NCHW torch tensor.

    Args:
        rgb_nchw: tensor with exactly 3 channels at dim -3.

    Returns:
        Tensor of the same shape with channels reordered to BGR.
    """
    r, g, b = torch.unbind(rgb_nchw, dim=-3)
    return torch.stack((b, g, r), dim=-3)
Blur
import torch
import numpy as np
import cv2
# Random float32 test image in HWC layout, plus a CHW torch copy.
img_hwc = np.random.rand(*[256, 256, 3]).astype("float32")
img_chw = img_hwc.transpose([2, 0, 1])
img_chw_tc = torch.from_numpy(img_chw)
kernel_size = 3
# Reference result: OpenCV box (mean) blur, then transposed to CHW
# so it can be compared element-wise against the torch output below.
img_blur_cv = cv2.blur(img_hwc, (kernel_size, kernel_size))
img_blur_cv_chw = img_blur_cv.transpose([2, 0, 1])
def mean_blur_torch(img_chw, kernel_size):
    """Mean (box) blur matching cv2.blur, via a depthwise conv2d.

    Args:
        img_chw: float tensor in CHW (or NCHW) layout.
        kernel_size: side length of the box filter.

    Returns:
        Blurred tensor with the same shape as the input.
    """
    device = img_chw.device
    dtype = img_chw.dtype
    # Generalized: derive the channel count instead of hard-coding 3.
    channels = img_chw.shape[-3]
    pad_l = kernel_size // 2
    pad_r = kernel_size // 2
    if kernel_size % 2 == 0:
        # Even kernels are anchored asymmetrically (as in cv2.blur).
        pad_r = pad_r - 1
    img_chw1 = torch.nn.functional.pad(
        img_chw, pad=[pad_l, pad_r, pad_l, pad_r], mode='reflect')
    # One averaging kernel per channel; groups=channels keeps it depthwise.
    weight = torch.ones(channels, 1, kernel_size, kernel_size,
                        dtype=dtype, device=device) / (kernel_size * kernel_size)
    return torch.nn.functional.conv2d(img_chw1, weight, padding=0, groups=channels)
# Run the torch implementation and compare against the OpenCV reference.
img_blur_torch_chw = mean_blur_torch(img_chw_tc, kernel_size)
img_blur_torch_chw = img_blur_torch_chw.numpy()
# Max/mean absolute difference between the two blurred images.
error = np.abs(img_blur_cv_chw - img_blur_torch_chw)
print("error:", np.max(error), np.mean(error))
def to_tensor_torch(tensor):
    """Convert an HW/HWC/NHWC torch tensor to CHW/NCHW float32.

    Mirrors torchvision's ToTensor but takes a torch tensor instead of a
    numpy array: transposes to channels-first, casts to float32, and
    rescales uint8 inputs to [0, 1].

    Raises:
        ValueError: if the input is not 2D, 3D, or 4D.
    """
    if tensor.dim() == 2:
        # HW -> HW1 so the HWC branch below applies uniformly.
        tensor = tensor[..., None]
    if tensor.dim() == 3:
        chw = tensor.permute(2, 0, 1)
    elif tensor.dim() == 4:
        chw = tensor.permute(0, 3, 1, 2)
    else:
        raise ValueError("unsupported")
    rescale = chw.dtype == torch.uint8
    chw = chw.to(torch.float32)
    # Only uint8 inputs carry the [0, 255] convention.
    return chw / 255.0 if rescale else chw