一、原理
二、数据集下载以及加载
CycleGAN 训练数据集：苹果与橙子数据集（apple2orange，即 APPLE 2 ORANGE）
# Load the training image paths: domain A (apples) and domain B (oranges).
# glob.glob returns matches in arbitrary, filesystem-dependent order; sorting
# makes the dataset's index -> file mapping reproducible across machines/runs.
apples_path = sorted(glob.glob('data/trainA/*.jpg'))
oranges_path = sorted(glob.glob('data/trainB/*.jpg'))

# Image preprocessing pipeline:
#   ToTensor  : PIL image -> float tensor scaled to [0, 1]
#   Normalize : (x - 0.5) / 0.5, mapping [0, 1] -> [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),          # scale to [0, 1]
    transforms.Normalize(0.5, 0.5),  # shift/scale to [-1, 1]
])
class AppleOrangeDataset(data.Dataset):
    """Map-style dataset that loads one image per index from a list of file paths.

    Args:
        img_path: sequence of image file paths (e.g. a list from ``glob.glob``).
        transform: optional callable applied to each loaded PIL image. When
            ``None`` (the default), the module-level ``transform`` is used,
            which matches the original behaviour.

    ``__getitem__`` returns whatever the transform returns for the image.
    """

    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.img_path[index]
        # Open inside a context manager so the file handle is released promptly.
        # Force 3-channel RGB: grayscale ('L') or CMYK JPEGs would otherwise
        # produce tensors with a different channel count, which breaks default
        # batch collation (torch.stack needs equal shapes) and any model that
        # expects 3 input channels.
        with Image.open(img_path) as img:
            pil_img = img.convert('RGB')
        # Fall back to the module-level pipeline when no transform was given.
        tf = self.transform if self.transform is not None else transform
        return tf(pil_img)

    def __len__(self):
        return len(self.img_path)
# One dataset per domain: apples (domain A) and oranges (domain B).
apple_dataset = AppleOrangeDataset(img_path=apples_path)
oranges_dataset = AppleOrangeDataset(img_path=oranges_path)
三、基于 U-Net 结构定义上采样 / 下采样模块
生成器和判别器的结构与 pix2pixGAN 中完全相同，具体可参考：
GAN实战之Pytorch使用pix2pixGAN生成建筑物Label to Facade
class Downsample(nn.Module):
    """Downsampling block: Conv2d(3x3, stride 2) -> LeakyReLU -> optional InstanceNorm.

    Maps ``in_channels`` to ``out_channels``; with kernel 3, stride 2 and
    padding 1 the spatial size becomes ceil(H / 2) x ceil(W / 2).
    Note that normalization is applied *after* the activation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute names and Sequential layout determine the state_dict
        # keys ("conv_relu.0.weight", ...), so they are kept unchanged.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        # Despite the attribute name, this is instance normalization
        # (InstanceNorm2d has no learnable affine parameters by default).
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Run conv + LeakyReLU, then normalize unless ``is_bn`` is False."""
        activated = self.conv_relu(x)
        return self.bn(activated) if is_bn else activated
class Upsample(nn.Module):
    """Upsampling block: ConvTranspose2d(3x3, stride 2) -> LeakyReLU -> InstanceNorm -> optional dropout.

    With kernel 3, stride 2, padding 1 and output_padding 1 the output spatial
    size is exactly twice the input's (H_out = 2 * H_in, W_out = 2 * W_in).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names / Sequential layout fix the state_dict keys; keep them.
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
        # Instance normalization despite the "bn" attribute name.
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample, activate and normalize; drop whole channels when ``is_drop``."""
        normed = self.bn(self.upconv_relu(x))
        if not is_drop:
            return normed
        # These are F.dropout2d's defaults, written out: training=True means
        # channels are dropped even after model.eval().
        # NOTE(review): possibly intentional (pix2pix-style test-time dropout as
        # a noise source) — confirm; use training=self.training for a
        # deterministic eval mode.
        return F.dropout2d(normed, p=0.5, training=True)
四、生成器结构
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.down1 = D