pytorch中torchvision.utils包下的save_image函数

雷郭出品

函数的用途:
将NCHW的tensor以网格图的形式存储到硬盘中,该图也叫做雪碧图sprite image
如下图所示:
在这里插入图片描述
将多张图以网格的形式拼凑起来,每张图的大小是28*28,单通道
那宽高如何确定?
我们可以来看看该函数的源码

def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    format: Optional[str] = None,
) -> None:
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional):  If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    from PIL import Image
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)

可以看到nrow默认为8
padding默认为2
当我们的tensor形状为96* 1* 28 * 28的时候
网格的行和列对应的格子数分别为(N/nrow,nrow)
即(12,8)
对应的就是第一张图

但是实际当我去查看图片的像素大小时,由于padding的存在
像素大小并不是(12 * 28,8 * 28)
而是(12 * 28+13 * 2,8 * 28+9 * 2)

还有一点要注意,当你存储图片的时候由于总的图片数可能不能被batchsize整除
所以当雪碧图的格子数跟batchsize不对应的时候
不要犯愁
这是正常
我也是看了好几个小时才突然从下面的打印中得到的灵感

real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([128, 1, 28, 28])
real_images的shape: torch.Size([96, 1, 28, 28])
real_img的shape: torch.Size([96, 784])

可以看到一开始的形状都是128
到了最后一个就变成了96
然后再次使用还是96
我就立刻想到了余数
然后我再验证6000=128 * 468+96
完美符合验证

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值