对抗神经网络(GAN)

本文介绍了对抗神经网络(GAN)的原理。GAN由生成式模型和判别式模型构成,生成式模型产生模拟数据,判别式模型判断数据真伪。二者目标相反形成对抗,经过迭代训练使整个GAN达到纳什均衡,在无监督及半监督领域效果良好。

       对抗神经网络其实是两个网络的组合,可以理解为一个网络生成模拟数据,另一个网络判断生成的数据是真实的还是模拟的。生成模拟数据的网络要不断优化自己让判别的网络判断不出来,判别的网络也要优化自己让自己判断得更准确。二者关系形成对抗,因此叫对抗神经网络。实验证明,利用这种网络间的对抗关系所形成的网络,在无监督及半监督领域取得了很好的效果,可以算是用网络来监督网络的一个自学习过程。

1、理论知识

       GAN由generator(生成式模型)和discriminator(判别式模型)两部分构成。

  • generator:主要是从训练数据中产生相同分布的samples,对于输入x,类别标签y,在生成式模型中估计其联合概率分布(两个及以上随机变量组成的随机向量的概率分布)。
  • discriminator:判断输入时真实数据还是generator生成的数据,即估计样本属于某类的条件概率分布。它采用传统的监督学习的方法。

      二者结合后,经过大量次数的迭代训练会使generator尽可能模拟出以假乱真的样本,而discriminator会有更精确的鉴别真伪数据的能力,最终整个GAN会达到所谓的纳什均衡,即discriminator对于generator的数据鉴别结果为正确率和错误率各占50%。

      网络结构如图所示:

  • 生成式模型又叫生成器。它先用一个随机编码向量来输出一个模拟样本。
  • 判别式模型又叫判别器。它的输入是一个样本(可以是真实样本也可以是模拟样本),输出一个判断该样本是真样本还是模拟样本(假样本)的结果。

    判别器的目标是区分真假样本,生成器的目标是让判别器区分不出真假样本,两者目标相反,存在对抗。

 

 

### 对抗神经网络 (GAN) 的代码实现 对抗神经网络GAN)的实现通常涉及定义生成器和判别器两个部分,并设置它们之间的对抗训练机制。下面是一个基于 PyTorch 实现的手写数字图片生成的 GAN 示例[^1]。 #### 导入必要的库 ```python import torch from torch import nn, optim from torchvision import datasets, transforms from torch.utils.data import DataLoader ``` #### 定义生成器 生成器的任务是从随机噪声向量生成看似真实的图像。 ```python class Generator(nn.Module): def __init__(self, input_size=100, hidden_dim=128, output_size=(1, 28, 28)): super(Generator, self).__init__() # 构建多层感知机结构 self.model = nn.Sequential( nn.Linear(input_size, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim * 2), nn.ReLU(), nn.Linear(hidden_dim * 2, int(torch.prod(torch.tensor(output_size)))), nn.Tanh() ) def forward(self, z): img_flat = self.model(z) img = img_flat.view(img_flat.size(0), *output_size) return img ``` #### 定义判别器 判别器负责评估传入的数据是否为真实样本或是来自生成器的伪造品。 ```python class Discriminator(nn.Module): def __init__(self, input_size=(1, 28, 28), hidden_dim=128): super(Discriminator, self).__init__() # 设计一个多层感知机架构 self.model = nn.Sequential( nn.Linear(int(torch.prod(torch.tensor(input_size))), hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(hidden_dim*2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(hidden_dim, 1), nn.Sigmoid() ) def forward(self, imgs): logits = self.model(imgs.view(imgs.size(0), -1)) return logits.squeeze() ``` #### 训练循环 此段代码展示了如何交替更新生成器和判别器参数,以促进两者的竞争关系。 ```python def train_gan(generator, discriminator, dataloader, num_epochs=200, lr=0.0002): criterion = nn.BCELoss() # 使用二元交叉熵损失函数 optimizer_G = optim.Adam(generator.parameters(), lr=lr) optimizer_D = optim.Adam(discriminator.parameters(), lr=lr) fixed_noise = torch.randn((64, 100)) for epoch in range(num_epochs): for i, (imgs, _) in enumerate(dataloader): real_labels = torch.ones(imgs.shape[0]) fake_labels = torch.zeros(imgs.shape[0]) # 更新 D 网络: max log(D(x)) + log(1 - D(G(z))) optimizer_D.zero_grad() outputs = discriminator(imgs).squeeze() d_loss_real = criterion(outputs, real_labels) real_score = outputs noise = torch.randn(imgs.shape[0], 100) fake_images = generator(noise) outputs = discriminator(fake_images.detach()).squeeze() d_loss_fake = criterion(outputs, fake_labels) fake_score = outputs d_loss = d_loss_real + d_loss_fake d_loss.backward() optimizer_D.step() # 更新 G 网络: min log(1 - D(G(z))) <-> max log(D(G(z)) optimizer_G.zero_grad() outputs = discriminator(fake_images).squeeze() g_loss = criterion(outputs, real_labels) g_loss.backward() optimizer_G.step() print(f'Epoch [{epoch}/{num_epochs}], ' f'd_loss: {d_loss.item():.4f}, ' f'g_loss: {g_loss.item():.4f}') ``` 上述代码片段提供了完整的 GAN 模型构建流程以及基本的训练逻辑[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值