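1. Image to Tensor (3-D to 3-D):
transforms.ToTensor: converts a PIL Image or a uint8 numpy array from the [0, 255] range to a float tensor in [0, 1]; in practice it simply divides the raw values by 255. It also changes the layout: the input has shape (H x W x C), and after transforms.ToTensor the shape becomes (C x H x W).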
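from PIL import Image
from torchvision import transforms

img = Image.open(img_path)  # img_path: path to an image file
tran = transforms.ToTensor()
img_tensor = tran(img)  # float tensor, shape (C, H, W), values in [0, 1]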
Output of img_tensor:
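The printed tensor depends on the image, so as a rough illustration of the arithmetic described above (divide by 255, move the channel axis first), here is a NumPy-only sketch. The helper name `to_tensor_like` is hypothetical and only mimics the behaviour for uint8 input; it is not torchvision's implementation.

```python
import numpy as np

def to_tensor_like(img_hwc_uint8):
    """Hypothetical helper: mimic what ToTensor does to a uint8 (H, W, C) array."""
    arr = np.asarray(img_hwc_uint8)
    chw = np.transpose(arr, (2, 0, 1))       # (H, W, C) -> (C, H, W)
    return chw.astype(np.float32) / 255.0    # [0, 255] -> [0, 1]

# A synthetic 4 x 6 image whose first channel is fully saturated.
img = np.zeros((4, 6, 3), dtype=np.uint8)
img[..., 0] = 255

out = to_tensor_like(img)
print(out.shape)  # (3, 4, 6)
```

The first channel of `out` is all 1.0 and the other two are all 0.0, which matches the scaling described in the text.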
2. Tensor to Image (3-D to 3-D):
from torchvision import transforms
img_original = transforms.ToPILImage()(img_tensor)  # back to a PIL Image
img_original.show()
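3. Adding a dimension: 3-D to 4-D (adding the batch_size dimension)
When to use: a network expects a batch of images, so a single img needs an extra leading batch_size dimension before it is passed to the net.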
import torch

img_tensor_with_batchdim = torch.unsqueeze(img_tensor, dim=0)
print(img_tensor.shape)  # torch.Size([C, H, W])
print(img_tensor_with_batchdim.shape)  # torch.Size([1, C, H, W])
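As a self-contained check, the sketch below uses a random tensor as a stand-in for a real image (an assumption made so the snippet runs without a file) and shows three equivalent ways of adding the batch dimension.

```python
import torch

# Stand-in for an image tensor of shape (C, H, W).
img_tensor = torch.rand(3, 4, 6)

a = torch.unsqueeze(img_tensor, dim=0)  # function form, as used above
b = img_tensor.unsqueeze(0)             # method form
c = img_tensor[None]                    # indexing with None also adds an axis

print(tuple(a.shape))  # (1, 3, 4, 6)
```

All three results have the same shape and the same values; which form to use is mostly a matter of taste.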
4. Removing a dimension: 4-D to 3-D (dropping the batch_size dimension)
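import numpy as np
import matplotlib.pyplot as plt

img = img[0]  # drop the batch dimension: (1, C, H, W) -> (C, H, W)
img = img.detach().numpy()  # FloatTensor to ndarray
x = np.transpose(img, (1, 2, 0))  # move the channel dimension last: (H, W, C)
plt.imshow(x)
plt.show()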
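The same steps can be checked without opening a plot window. This sketch assumes a random 4-D tensor in place of a real network output; `requires_grad=True` is set only to show why `detach()` is needed before calling `numpy()`. For a tensor that lives on a GPU, a `.cpu()` call would also be needed before `.numpy()`.

```python
import torch

# Stand-in for a network output of shape (batch, C, H, W).
batch = torch.rand(1, 3, 4, 6, requires_grad=True)

img = batch[0]                  # (3, 4, 6): batch dimension removed
arr = img.detach().numpy()      # detach from the autograd graph, then convert
hwc = arr.transpose(1, 2, 0)    # (H, W, C), the layout plt.imshow expects

print(hwc.shape)  # (4, 6, 3)
```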