Pytorch 实现 GAN 对抗网络

GAN 对抗网络

GAN(Generative Adversarial Network)对抗网络指的是神经网络中包括两个子网络,一个用于生成信息,一个用于验证信息。下面的例子是生成图片的对抗网络,一个网络用于生成图片,另一个网络用于验证。G 网络用于生成图片,不断的学习并生成更接近于训练数据的图像,D 网络用于鉴别图片,通过学习更准确的识别出图片的真假,最终通过学习让 G 网络能够生成高质量的目标图片。下面通过代码实现两种不同的 GAN,图片为自动生成手写图片,采用 MNIST数据集。

  • DCGAN (Deep Convolution)深度卷积生成对抗网络
  • SAGAN(Self Attention)自注意力生成对抗网络

安装依赖

本文将使用 sklearn,首先安装 sklearn。

pip install -U scikit-learn

DCGAN 深度卷积对抗网络

数据准备

采用 MNIST 数据,并只选用 7、8 两个数字,MINST 中 7、8 数字各有 200 张。

def make_datapath_list():
    """创建用于学习和验证的图像数据及标注数据的文件路径列表。 """

    train_img_list = list() #保存图像文件的路径

    for img_idx in range(200):
        img_path = "./data/img_78/img_7_" + str(img_idx)+'.jpg'
        train_img_list.append(img_path)

        img_path = "./data/img_78/img_8_" + str(img_idx)+'.jpg'
        train_img_list.append(img_path)

    return train_img_list

class ImageTransform():
    """图像的预处理类"""

    def __init__(self, mean, std):
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        return self.data_transform(img)

class GAN_Img_Dataset(data.Dataset):
    """图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''返回图像的张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''获取经过预处理后的图像的张量格式的数据'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  #[ 高度 ][ 宽度 ] 黑白

        #图像的预处理
        img_transformed = self.transform(img)

        return img_transformed

#创建DataLoader并确认执行结果

#创建文件列表
train_img_list=make_datapath_list()

#创建Dataset
mean = (0.5,)
std = (0.5,)
train_dataset = GAN_Img_Dataset(
    file_list=train_img_list, transform=ImageTransform(mean, std))

#创建DataLoader
batch_size = 64

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

#确认执行结果
batch_iterator = iter(train_dataloader)  #转换为迭代器
imges = next(batch_iterator)  #取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

生成网络实现

需要根据输入的随机数生成图像,对数据的维度进行放大,并增加维度中的元素数量,通过 nn.ConvTranspose2d 转置卷积进行实现。转置卷积是卷积的反向操作,卷积输出特征通常比输入数据小,反向卷积输出比输入大,可以看做数据放大操作。

# 导入软件包
import random
import math
import time
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms

# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)
class Generator(nn.Module):

    def __init__(self, z_dim=20, image_size=64):
        super(Generator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size * 8,
                               kernel_size=4, stride=1),
            nn.BatchNorm2d(image_size * 8),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 8, image_size * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 4),
            nn.ReLU(inplace=True))

        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 4, image_size * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 2),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 2, image_size,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size),
            nn.ReLU(inplace=True))

        self.last = nn.Sequential(
            nn.ConvTranspose2d(image_size, 1, kernel_size=4,
                               stride=2, padding=1),
            nn.Tanh())
        # 注意:因为是黑白图像,所以只有一个输出通道

    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.last(out)

        return out

根据输入,生成图片

#动作确认
import matplotlib.pyplot as plt
%matplotlib inline

G = Generator(z_dim=20, image_size=64)

# 输入的随机数
input_z = torch.randn(1, 20)

# 将张量尺寸变形为(1,20,1,1)
input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)

#输出假图像
fake_images = G(input_z)

img_transformed = fake_images[0][0].detach().numpy()
plt.imshow(img_transformed, 'gray')
plt.show()

没有经过学习生成的图片,目标是通过学习生成手写数字的效果。
在这里插入图片描述

鉴别网络实现

<
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值