PyTorch入门学习小记

PyTorch入门学习小记

1、区分numpytorch图像的颜色轴

image = image.transpose((2, 0, 1))

Pytorch中使用的数据格式与plt.imshow()函数的格式不一致,Pytorch中为[Channels, H, W],而plt.imshow()中则是[H, W, Channels],因此,要先转置一下。

# numpy image:H x W x C
# 颜色轴对应编号:(0, 1, 2)
########################################
#  torch image:C x H x W
# 颜色轴对应编号:(2, 0, 1)
########################################

2、关于输出图片的显示格式问题

def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
	images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
	batch_size = len(images_batch)
	im_size = images_batch.size(2)	# 这里的size是shape
	grid_border_size = 2
	
	grid = utils.make_grid(images_batch)	# 多张图变成一张图
	plt.imshow(grid.numpy().transpose((1, 2, 0)))	# reshape到能用plt显示
	
	for i in range(batch_size):
	    # 第i张图片的所有点的x,所有点的y,后面 + i*im_size是由于所有图像水平显示,所以需要水平有个偏移
        # 转numpy是因为torch类型的数据没办法scatter
		plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) *
		grid_border_size,
					landmarks_batch[i, :, 1].numpy() + grid_border_size,
					s=10, marker='.', c='r')
		plt.title('Batch from dataloader')

图 2.1
图 2.2

3、参考

PyTorch官网教程+注释
PyTorch数据加载及处理
pytorch读入图片并显示np.transpose(np_image, [1, 2, 0])
PyTorch官方教程中文版

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值