什么是GAN?

一、基本概念

        生成对抗网络(Generative Adversarial Network,GAN)是一种由两个神经网络共同组成深度学习模型:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗的方式进行训练,生成器尝试伪造逼真的样本数据,而判别器则负责判断输入的数据是真实数据还是生成器伪造出来的数据。理想情况下,判别器对真实样本和生成样本的判断概率都是1/2,意味着判别器已经无法判断生成器生成的数据真假。

二、模型原理

        GAN的模型原理并不复杂。首先,GAN由以下两个子模型组成:

  • 生成器(Generator)从随机噪声中生成数据,目标是欺骗判别器,使其认为生成的数据是真实的。
  • 判别器(Discriminator):判断输入数据是来自真实数据分布还是生成器,目标是正确区分真实数据和生成数据。

        然后,GAN的损失函数是训练的核心,我们需要构建一个合适的损失函数用于衡量生成器和判别器的表现:

  • 生成器损失(G_loss):通常表示为最大化判别器对其生成样本的错误分类概率,也就是判别器判定所有生成数据均为真。
  • 判别器损失(D_loss):由两部分组成,一部分是真实样本的损失(标签为1),另一部分是生成样本的损失(标签为0)。

        最后,我们通过算法设计来交替训练生成器和判别器,例如生成器每训练5个Epoch,我们就训练一次判别器:

  • 训练判别器:提高其区分真实样本和生成样本的能力。
  • 训练生成器:提高其生成真实样本的能力,目标是最大化判别器将其生成样本识别为真实样本的概率。

三、python实现

1、导库

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

2、数据处理

        这里我们的目标是训练一个生成对抗网络来生成iris数据,使用sklearn的iris数据集训练。这意味着,我们输入给生成器的信息中需要包含类别信息,这样生成器才能生成对应类别的数据样本。当然,这一步不是必要的,在类别不敏感的任务中,只需要生成符合要求的数据即可。

# 加载Iris数据集
iris = load_iris()
data = iris.data
labels = iris.target

# 标准化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)

# One-hot编码标签
encoder = OneHotEncoder(sparse=False)
# torch.Size([100, 3])
labels = encoder.fit_transform(labels.reshape(-1, 1))

# 转换为PyTorch张量
data = torch.FloatTensor(data)
labels = torch.FloatTensor(labels)

# 创建数据加载器
batch_size = 32
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

3、构建生成器

        这里,我们构建一个全连接神经网络。生成器的输入包括随机初始化的x,以及x对应的期望类别,期望类别是可以真实标签,表示生成对应类别下的数据样本。

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, label_dim, output_dim):
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值