GAN基础代码

GAN

对抗生成网络

 

生成器:生成假的数据,骗过判别器

判别器:将假数据判别为假,真实数据判别为真。

流程:

1.导入数据

2.网络架构:

输入层:待生成图像(噪音)和真实数据

生成网络:将噪音图像进行生成

判别网络

(1)判断真实图像输出结果 1

(2)判断生成图像输出结果 0

3.目标函数:

(1):对于生成网络,要使得 生成结果 通过判别网络为真

(2):对于判别网络,要使得输入为真实图像时判别为真,输入为生成图像时判别为假


判别模型:log(D1(x)) + log(1-D2(G(z)))

生成网络:log(D2(G(z)))
 

代码

BCELoss:

import torch    #交叉熵和BCEloss差不多
from torch import autograd
#随机生成一个数据作为预测值
input = autograd.Variable(torch.tensor([[ 1.9072,  1.1079,  1.4906],
        [-0.6584, -0.0512,  0.7608],
        [-0.0614,  0.6583,  0.1095]]), requires_grad=True)
print(input)
print('-'*100)
#因为得到的数据有正有负,用sigmoid函数映射一下,变成规范一点的0-1之间的预测结果
from torch import nn
m = nn.Sigmoid()
print(m(input))
print('-'*100)

#构建标签  想办法让判别器判断真时预测1,判别假时预测0
target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)
print('-'*100)

import math
#损失函数   左边当作正确预测类别,右边为错误预测类别
r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))

r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))

r31 = 0 * math.log(0.4847) + (1-0) * math.log((1 - 0.4847))
r32 = 0 * math.log(0.6589) + (1-0) * math.log((1 - 0
<think>好的,用户现在想要了解GAN的源代码。之前已经讨论过GAN的基本概念、应用场景、训练难点以及cGAN的特点。现在用户的需求更具体,涉及到代码层面,可能需要实际操作的指导。 首先,我需要确定用户的背景。他们可能是刚开始学习GAN的学生,或者有一定编程经验但想实现GAN的开发者。考虑到用户之前的问题,可能已经对GAN有基本了解,现在需要代码示例来实践。 接下来,用户的需求可能不仅仅是获取代码,而是理解代码的结构和实现细节。所以,我应该提供一个简单但完整的代码示例,同时解释关键部分,比如生成器和判别器的架构、损失函数、训练循环等。使用PyTorch框架是合理的选择,因为它在研究领域广泛使用,代码相对简洁。 需要考虑用户可能遇到的困难,例如环境配置、依赖库安装,或者代码中的某些部分不易理解。因此,在代码注释中详细说明每一步的作用,并在解释部分强调关键点,比如优化器的选择、噪声向量的生成、损失函数的计算等。 另外,用户可能希望知道如何运行和测试这段代码,所以需要给出训练过程的简要说明,比如需要准备数据集(这里用MNIST作为示例),调整训练轮数,以及如何观察训练结果。可能还需要提到常见的调试方法,比如检查损失值的变化,生成样本的可视化等。 同时,应该预见到用户可能对代码中的某些参数有疑问,比如为什么选择特定的隐藏层大小、学习率,或者优化器的参数。在解释中需要提到这些是常见配置,但可以根据实际情况调整,并引导用户参考更详细的教程或文档。 最后,确保代码格式正确,符合PyTorch的最佳实践,避免常见的错误,比如忘记清零梯度、错误的损失函数应用等。检查代码是否有语法错误,并测试过是否可以运行,以增加可靠性。</think>以下是使用 **PyTorch** 框架实现一个基础GAN的简化代码示例,包含生成器(Generator)和判别器(Discriminator)的核心逻辑: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.datasets as datasets import torchvision.transforms as transforms # 定义生成器(Generator) class Generator(nn.Module): def __init__(self, latent_dim=100, img_shape=(28, 28)): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1024), nn.LeakyReLU(0.2), nn.Linear(1024, img_shape[0] * img_shape[1]), # 输出与图像尺寸匹配 nn.Tanh() # 输出范围[-1, 1] ) def forward(self, z): return self.model(z).view(-1, 1, 28, 28) # 假设生成MNIST图像(1x28x28) # 定义判别器(Discriminator) class Discriminator(nn.Module): def __init__(self, img_shape=(28, 28)): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(img_shape[0] * img_shape[1], 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(256, 1), nn.Sigmoid() # 输出概率值[0,1] ) def forward(self, img): img_flat = img.view(img.size(0), -1) return self.model(img_flat) # 超参数设置 latent_dim = 100 lr = 0.0002 batch_size = 64 epochs = 200 # 初始化模型与优化器 generator = Generator(latent_dim) discriminator = Discriminator() optimizer_G = optim.Adam(generator.parameters(), lr=lr) optimizer_D = optim.Adam(discriminator.parameters(), lr=lr) loss_fn = nn.BCELoss() # 二分类交叉熵损失 # 数据加载(以MNIST为例) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1,1] ]) dataset = datasets.MNIST(root='./data', download=True, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) # 训练循环 for epoch in range(epochs): for i, (real_imgs, _) in enumerate(dataloader): # 生成假数据 z = torch.randn(real_imgs.size(0), latent_dim) # 随机噪声 fake_imgs = generator(z) # 训练判别器 optimizer_D.zero_grad() real_labels = torch.ones(real_imgs.size(0), 1) # 真实样本标签为1 fake_labels = torch.zeros(real_imgs.size(0), 1) # 假样本标签为0 loss_real = loss_fn(discriminator(real_imgs), real_labels) loss_fake = loss_fn(discriminator(fake_imgs.detach()), fake_labels) loss_D = (loss_real + loss_fake) / 2 loss_D.backward() optimizer_D.step() # 训练生成器 optimizer_G.zero_grad() gen_labels = torch.ones(real_imgs.size(0), 1) # 欺骗判别器 loss_G = loss_fn(discriminator(fake_imgs), gen_labels) loss_G.backward() optimizer_G.step() ``` --- ### 关键代码逻辑说明: 1. **生成器架构**: - 输入:随机噪声向量(`latent_dim`维) - 输出:与真实数据同维度的假样本(如MNIST为28x28) - 使用全连接层逐步放大维度,激活函数为`LeakyReLU`(缓解梯度消失) 2. **判别器架构**: - 输入:真实或生成的样本(展开为向量) - 输出:概率值(0=假,1=真) - 使用`Dropout`防止过拟合,最后通过`Sigmoid`输出概率 3. **训练步骤**: - **判别器训练**:同时用真实数据和生成数据计算损失,最大化对真假的区分能力 - **生成器训练**:固定判别器参数,最小化生成数据被判别为假的概率 --- ### 运行与调试建议: 1. **依赖安装**: ```bash pip install torch torchvision ``` 2. **可视化生成结果**: - 在训练过程中定期保存生成的图像(使用`torchvision.utils.save_image`) 3. **常见改进方向**: - 使用卷积结构(DCGAN)提升图像生成质量 - 调整学习率或使用`Adam`优化器的参数(beta1=0.5, beta2=0.999) - 添加梯度惩罚(如WGAN-GP)解决模式崩溃问题 如需更完整的代码实现(含可视化),可参考GitHub上的开源项目如[PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN)。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值