PyTorch入门学习小记
1、区分numpy
和torch
图像的颜色轴
image = image.transpose((2, 0, 1))
Pytorch
中使用的数据格式与plt.imshow()
函数的格式不一致,Pytorch
中为[Channels, H, W]
,而plt.imshow()
中则是[H, W, Channels]
,因此,要先转置一下。
# numpy image:H x W x C
# 颜色轴对应编号:(0, 1, 2)
########################################
# torch image:C x H x W
# 颜色轴对应编号:(2, 0, 1)
########################################
2、关于输出图片的显示格式问题
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
batch_size = len(images_batch)
im_size = images_batch.size(2) # 这里的size是shape
grid_border_size = 2
grid = utils.make_grid(images_batch) # 多张图变成一张图
plt.imshow(grid.numpy().transpose((1, 2, 0))) # reshape到能用plt显示
for i in range(batch_size):
# 第i张图片的所有点的x,所有点的y,后面 + i*im_size是由于所有图像水平显示,所以需要水平有个偏移
# 转numpy是因为torch类型的数据没办法scatter
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) *
grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r')
plt.title('Batch from dataloader')
3、参考
PyTorch官网教程+注释
PyTorch数据加载及处理
pytorch读入图片并显示np.transpose(np_image, [1, 2, 0])
PyTorch官方教程中文版