torchvision.transforms
是一个图像转换工具包,不同的图像转换组件可以通过Compose
来连接从而形成一个流水线(类似于torch.nn.Sequential
),以实现更复杂的图像转换功能。
Image\Tensor\ndarray 之间的相互转换
绝大多数转换都同时支持 PIL 的Image
对象和张量图像。当然也可以使用ToTensor()
、ToPILImage()
等工具实现 PIL 对象和张量图像之间的相互转换。
还支持批量地转换张量图像。一般batch
图像是一个具有(B, C, H, W)
尺寸地张量,其中 B 代表每个 batch 中地图像个数。
ToTensor()
ToTensor()可以将一个PIL Image
或一个具有(H, W, C)
尺寸且数值范围在[0, 255]之间的ndarray
转换成一个形状为(C, H, W)
且数值范围在[0, 1]之间的浮点型张量。
转换的前提条件:对于 PIL Image
,图像模式必须为(L, LA, P, I, F. RGB, YCbCr, RGBA, CMYL, 1)中的一种,对于ndarray
,它的数据类型必须为np.uint8
。
事实上,ToTensor 中的归一化操作均是通过除以数组中的最大元进行实现的。PIL Image 和 ndarray 转化成 Tensor 后内容的顺序不一致。PIL Image 转化成 Tensor 后,排列格式为 [R, G, B],即 img[0] 代表 R 通道;而 opencv ndarray 转化成 Tensor 后,排列格式为 [B, G, R]。
import cv2
import torchvision
from PIL import Image
imgcv = cv2.imread('/path/to/fig') # shape: H W C uint8 [0-255] B G R
image = Image.open('/path/to/fig')
totenor = torchvision.transform.ToTensor()
imgcv_tensor = totensor(imgcv) # shape: C H W torch.float32 [0-1] B G R
image_tensor = totensor(image) # shape: C H W torch.float32 [0-1] R G B
PILToTensor()
不进行归于化的 Image to Tensor
import torchvision
from PIL import Image
img = Image.open('/path/to/img')
totensor = torchvision.transform.PILToTensor()
tensor = totensor(img) # shape: C H W torch.uint8 [0-255]
ToPILImage()
将 Tensor 或 ndarray 转为 Image.
其中 Tensor 要求形状为 (C, H, W),ndarray 要求形状为 (H, W, C)
import torchvision
import cv2
from PIL import Image
imgcv = cv2.imread('/path/to/img') # B G R
imgcv = imgcv[:, :, -1::-1]
totensor = torchvision.transform.ToTensor()
tensor = totensor(imgcv.copy())
topilimage = torchvision.transforms.ToPILImage()
image1 = topilimage(imgcv)
image2 = topilimage(tensor)
Compose
很多时候,我们需要对大量的图片完成一系列的图像转换操作,这时候我们就能用 Compose() 将这些操作组合成一道流水线,以简化我们的代码。
trans_pipeline = transforms.Compose([
transforms.RandomHorizontalFlip(1.0),
transforms.Resize((300, 300)),
transforms.ToTensor(),
])
img = Image.open('./pics/1.jpg')
img = trans_pipeline(img)
# 按需自定义Compose, 当然Compose_imglabel里的transforms也需做出自己的定义
class Compose_imglabel(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, label):
for t in self.transforms:
img, label = t(img, label)
return img, label