使用trochvision.transforms进行数据预处理,以pix2代码为例

在训练模型的时候产生了思考,使用trochvision.transforms做数据增强的时候如何保证输入img与gt做相同的变换,重新看了看pix2pix代码,现做笔记如下:

代码论文引用:

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
[Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)\, [Taesung Park](https://taesung.me/)\Phillip IsolaAlexei A. Efros. In ICCV 2017. (* equal contributions) [Bibtex]

Image-to-Image Translation with Conditional Adversarial Networks.
Phillip IsolaJun-Yan ZhuTinghui ZhouAlexei A. Efros. In CVPR 2017. [Bibtex]

代码中在AlignedDataset模式下,pytorch数据生成器的__getitem__部分如下:

    def __getitem__(self, index):
        AB_path = self.AB_paths[index]
        A = Image.open(AB_path[0]).convert('RGB')
        B = Image.open(AB_path[1]).convert('RGB')
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params, 
        grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, 
        grayscale=(self.output_nc == 1))
        A = A_transform(A)
        B = B_transform(B)
        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

其中A,B分别为input img以及gt,get_params为参数生成函数

#对源代码进行了修改
def get_params(opt):
    new_h = new_w = opt.load_size#opt.load_size=256
    #随机生成(0到crop_size-load_size)之间的整数
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))#opt.crop_size=286
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
    #水平翻转标志
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}

接下来是 get_transform图片转换增强函数

def get_transform(opt, params=None, method=Image.BICUBIC):
    transform_list = []
    osize = [opt.load_size, opt.load_size]
    #将图片resize为286*286
    transform_list.append(transforms.Resize(osize, method))
    if params is None:
        transform_list.append(transforms.RandomCrop(opt.crop_size))
    else:
    #根据get_params中生成的整数将图片裁剪为256*256大小
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
    if params is None:
        transform_list.append(transforms.RandomHorizontalFlip())
    #根据get_params中生成的翻转标志是否将图片进行水平翻转
    elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    #输入数据的标准化
    transform_list += [transforms.ToTensor()]
    transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

保证输入img与gt做相同的变换的关键是transforms.Lambda方法,具体讲解见如下链接:

pytorch transforms.Lambda的使用 - 慢行厚积 - 博客园 (cnblogs.com)

__crop:

def __crop(img, pos, size):
    ow, oh = img.size#ow, oh为上一步resize的后的尺寸此时为286*286
    x1, y1 = pos#为随机生成的0到(286-256)之间的整数
    tw = th = size#为裁剪后的尺寸,为256*256
    if (ow > tw or oh > th):
    #crop为图片裁剪方法具体见下链接
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

crop方法讲解: pillow模块Image.crop()函数切割图片方法,参数说明 - 正态分个布 - 博客园 (cnblogs.com)

 __flip:

def __flip(img, flip):
    if flip:
        #对img进行水平镜像处理
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

img.transpose(Image.FLIP_LEFT_RIGHT)讲解:

Python-图像处理库PIL图像变换transpose和transforms函数 - lovejobs - 博客园 (cnblogs.com)

完整的示意代码如下 

pathA="./dio.png"
pathB="./dio.png"
params=get_params(opt)
print(params)
img_A=Image.open(pathA)
img_B=Image.open(pathB)
A_transform=get_transform(opt,params)
B_transform=get_transform(opt,params)
img_A_t=A_transform(img_A)
img_B_t=B_transform(img_B)
img_A_t.show()
img_B_t.show()

 原始图片:

 params:

{'crop_pos': (12, 18), 'flip': True}

img_A_to: 

 

img_B_to:  

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值