在训练模型的时候产生了思考,使用trochvision.transforms做数据增强的时候如何保证输入img与gt做相同的变换,重新看了看pix2pix代码,现做笔记如下:
代码论文引用:
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
[Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)\, [Taesung Park](https://taesung.me/)\, Phillip Isola, Alexei A. Efros. In ICCV 2017. (* equal contributions) [Bibtex]Image-to-Image Translation with Conditional Adversarial Networks.
Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, Alexei A. Efros. In CVPR 2017. [Bibtex]
代码中在AlignedDataset模式下,pytorch数据生成器的__getitem__部分如下:
def __getitem__(self, index):
AB_path = self.AB_paths[index]
A = Image.open(AB_path[0]).convert('RGB')
B = Image.open(AB_path[1]).convert('RGB')
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params,
grayscale=(self.input_nc == 1))
B_transform = get_transform(self.opt, transform_params,
grayscale=(self.output_nc == 1))
A = A_transform(A)
B = B_transform(B)
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
其中A,B分别为input img以及gt,get_params为参数生成函数
#对源代码进行了修改
def get_params(opt):
new_h = new_w = opt.load_size#opt.load_size=256
#随机生成(0到crop_size-load_size)之间的整数
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))#opt.crop_size=286
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
#水平翻转标志
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
接下来是 get_transform图片转换增强函数
def get_transform(opt, params=None, method=Image.BICUBIC):
transform_list = []
osize = [opt.load_size, opt.load_size]
#将图片resize为286*286
transform_list.append(transforms.Resize(osize, method))
if params is None:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
#根据get_params中生成的整数将图片裁剪为256*256大小
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
#根据get_params中生成的翻转标志是否将图片进行水平翻转
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
#输入数据的标准化
transform_list += [transforms.ToTensor()]
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
保证输入img与gt做相同的变换的关键是transforms.Lambda方法,具体讲解见如下链接:
__crop:
def __crop(img, pos, size):
ow, oh = img.size#ow, oh为上一步resize的后的尺寸此时为286*286
x1, y1 = pos#为随机生成的0到(286-256)之间的整数
tw = th = size#为裁剪后的尺寸,为256*256
if (ow > tw or oh > th):
#crop为图片裁剪方法具体见下链接
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
crop方法讲解: pillow模块Image.crop()函数切割图片方法,参数说明 - 正态分个布 - 博客园 (cnblogs.com)
__flip:
def __flip(img, flip):
if flip:
#对img进行水平镜像处理
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
img.transpose(Image.FLIP_LEFT_RIGHT)讲解:
Python-图像处理库PIL图像变换transpose和transforms函数 - lovejobs - 博客园 (cnblogs.com)
完整的示意代码如下
pathA="./dio.png"
pathB="./dio.png"
params=get_params(opt)
print(params)
img_A=Image.open(pathA)
img_B=Image.open(pathB)
A_transform=get_transform(opt,params)
B_transform=get_transform(opt,params)
img_A_t=A_transform(img_A)
img_B_t=B_transform(img_B)
img_A_t.show()
img_B_t.show()
原始图片:
params:
{'crop_pos': (12, 18), 'flip': True}
img_A_to:
img_B_to: