生成对抗网络(GAN)入门

一、理论基础

1.什么是GAN
GAN(生成对抗网络)是一种神经网络架构,它的设计灵感来自于博弈论的思想。这种网络由两个关键组件组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是从随机噪声中生成与真实数据相似的合成样本,而判别器则负责辨别给定的样本是真实数据还是生成器产生的人工样本。这两个组件通过反复对抗性的训练相互影响和提升。

在训练过程中,生成器努力生成越来越逼真的样本,而判别器则不断提高辨别真实与生成样本的能力。这种博弈过程推动着两者的不断进步,直到最终生成器生成的样本足够逼真,使得判别器无法有效地区分真实和生成的样本。这就达到了一个状态,即生成器能够以惊人的逼真度产生伪造的真实样本。

GAN的应用广泛,包括图像生成、风格转换、图像编辑等领域,其独特的博弈训练方法使其在生成高质量数据方面表现出色。

2.什么是生成器
在生成对抗网络(GANs)中,生成器(G)是一个关键组件,其任务是利用随机噪声(通常表示为z)作为输入,并通过不断的学习和拟合过程生成一个与真实样本在尺寸和分布上相似的伪造样本G(z)。生成器本质上是一种生成式方法的模型,它通过学习数据的分布和分布参数来生成新的样本。

从数学的角度来看,生成式方法首先对数据的显式或隐含变量进行分布假设,然后通过将真实数据输入模型并训练变量和参数,最终得到一个学习后的近似分布。这个学习后的分布可以被用来生成新的数据。与传统的数学方法不同,生成器使用机器学习的方法,通过不断学习真实数据并修正模型,最终得到一个可以执行样本生成任务的学习后模型,这一过程可能相对较为直观。

生成器通过借助现有数据生成新数据,例如从随机产生的一组数字向量(称为潜在空间 latent space)中生成图像、音频等数据。在构建生成器时,首先需要明确生成的目标,然后将生成的结果传递给判别器网络进行进一步处理。这协同的过程在训练中推动生成器生成越来越逼真的样本。

3.什么是判别器
在生成对抗网络(GANs)中,判别器(D)是另一个关键的组件,其任务是对于输入的样本x输出一个介于[0,1]之间的概率数值D(x)。这个样本x可能是来自原始数据集中的真实样本,也可能是由生成器G生成的人工样本G(z)。通常的约定是,概率值D(x)越接近于1,表示此样本为真实样本的可能性越大;反之,概率值越小,则代表此样本为伪造样本的可能性越大。判别器因此是一个二分类的神经网络分类器,其目标是区分输入样本的真伪,而非判定样本的原始类别。这表明GAN是一个无监督学习过程,没有使用样本的类别信息。

判别器的任务是尝试区分接收到的数据是真实数据还是由生成网络生成的数据。它根据预定义的类别对输入进行分类,通常在GAN中是进行二分类。判别器的输出结果是一个介于0和1之间的数字,用来表示当前输入被认为是真实数据的可能性。当判别结果为1时,判别器认为输入来自真实数据;反之,如果判别结果接近0,则判别器将其视为生成数据。判别器的目标是不断优化以在这个对抗性的过程中更准确地区分真实和生成的样本。

4.原理
生成对抗网络(GAN)是博弈论和机器学习相结合的创新性产物,由Ian Goodfellow于2014年提出。这一算法的问世引起了广泛的研究热潮,表明人们对这种算法的认可和热切的研究兴趣。

研究者最初尝试通过计算机来实现自动生成数据的功能。早期的生成算法通常采用均方误差作为损失函数来衡量生成图片和真实图片之间的差距。然而,研究者发现,有时两张生成图片的均方误差相同,但它们的视觉效果却迥然不同。鉴于这种不足,Ian Goodfellow提出了生成对抗网络。

GAN的核心思想是由两个模型组成:生成模型(G)和判别模型(D)。生成模型首先接收随机噪声z作为输入,生成一张初级图片。然后,训练一代判别模型(D)进行二分类操作,将生成的图片判别为0,真实图片判别为1。为了欺骗一代判别器,生成模型开始优化。随着一代生成模型的进步,它成功欺骗判别模型1D,然后判别模型也会优化更新,升级为2D。这样的迭代过程不断进行,生成模型和判别模型相互对抗、相互优化,生成越来越逼真的样本,形成了一个动态的博弈过程。

GAN的原理在于通过这种对抗性训练,生成器学习生成逼真的数据,而判别器学习更好地区分真实和生成的数据。这种博弈过程使得GAN在生成高质量、逼真的数据方面取得了显著的成功。

二、模型搭建与训练 

import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms

# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)
os.makedirs('./output/', exist_ok=True)
os.makedirs('./data/MNIST/', exist_ok=True)

# 超参数配置
n_epochs = 50
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

# 图像的尺寸:(1, 28, 28),和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

# 设定设备为 CPU
device = torch.device('cpu')

# mnist数据集下载
mnist = datasets.MNIST(root='./data/',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize([0.5], [0.5])
                       ]))
# 配置数据到加载器
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid())

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# 定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh())

    def forward(self, z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs

# 创建生成器和判别器对象
generator = Genera
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值