我尝试使用PIL.Image.fromarray来从torch.Tensor生成一张图像:
from PIL import Image
import numpy as np
img_ori = Image.fromarray(np.array(images[0].cpu()))
其中images的形状是[bs, 3, 224, 224].
出现这个报错的原因是```np.array(images[0].cpu())```与Image.fromarray所需的数组形式不符合。
Image.fromarray需要的是形状为[H, W, 3]数组,并且数组的元素应当是uint8类型。而PyTorch默认的用于模型训练的张量数据类型是介于0和1之间的浮点数,且通道维度在尺度维度之前。
所以需要进行变形:
from PIL import Image
import numpy as np
img_ori = np.array(images[0].permute(1,2,0).clamp(0,1).cpu()*255).astype(np.uint8)
img_ori = Image.fromarray(img_ori)
变形之后不会报错