cGAN/cDCGAN,MNIST数据集初体验(内含原理,代码)

本文介绍了cGAN/cDCGAN在MNIST数据集上的应用,探讨了条件GAN(cGAN)的概念,以及其与深度卷积GAN(DCGAN)的区别。通过代码示例展示了网络训练、输入输出、网络结构和损失函数设计,帮助读者理解生成器和判别器的相互作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

​生成式对抗网络(Generative Adversarial Networks, GAN),简称GAN网络。有人说这是21世纪最让人激动的“发明”,虽然我忘了我是从哪看到的这句话,貌似是发明了卷积神经网络那位大佬说的。我试过以后,对于AI兴趣爱好者来说

确实挺激动的!

对于标题中的cGAN/cDCGAN,小c,全称是conditional,条件的。DC,全称是Deep Convolution,深度卷积。都是GAN网络的一个变种。对于DCGAN与GAN的关系,也很简单,因为最开始GAN网络是用神经网络设计的,而后来出现了计算能力更强的卷积(CNN),训练逻辑相同,只是计算操作不同,当然可以相互替换。

对于原理,网传:一个生成器(Generator),一个判别器(Discriminator),他两相互博弈,相爱相杀,最后产生一个好的结果。。。

What?还要动手吗?

对于此种高端解释,我等菜鸡无法领会,我只想知道网络是怎么训练的?两个部分的输入输出分别是什么?网络如何搭建?Loss如何设计?有了这些,你的程序就可以跑了

还是从代码中理解啥是相爱相杀吧。

先放一张整体原理图,来个大致印象
图片来源:github一个老哥的仓库
那个G,就是生成器,那个D,就是判别器。其余就是常规表示网络的结构了,是如何设计的。各位应该发现图中还有个小y,这就是cGAN网络中的c

较常规GAN网络,多了个条件标签

这里想啰嗦一句,这个版本的cGAN在条件标签的处理上,用的是concatenate操作,也就是在某个维度上,直接叠加相关数据,一会在代码中也有显现。其余的还可不可以用别的操作来改善效果,本人很菜,还没有试过。

如图所示,因为用的是MNIST(手写数字体)数据集,每张图片的shape是[1, 28, 28],也就是单通道,分辨率是28x28。又因为是使用神经网络提取特征,所以需要将图片打平操作,所以生成器(G)最后生成的本来应该是一张图片的shape,这里的话就是784,这个数字各位应该不陌生,不多废话。

可以看到,G的输入就是100维的一个随机数,shape是[100, 100],这里生成100张假的数字体图像,对应的label,小y,也就是[100, 10],做了One-Hot编码。然而输出就是[100, 784],经过一些类似imshow等显示图片的函数的时候,在reshape成[100, 1, 28, 28], 就可以显示啦

再看D,判别器,这个就相对简单一些,就是平常看到的分类网络的结构。输入是由G生成的假的图像数据,输出只有两个,真or假,real or fake,用1和0代替结果,shape为[batch, 1],只有一堆0或1作为label。在代码中一看便知,就理解了。

ok,少废话,上代码(下面是完整的,来源也是GitHub的那位老哥的仓库,稍做了些修改,要不在我的环境下直接跑不了)

import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
​
# G(z)
class generator(nn.Module):
    # initializers
    def __init__(self):
        super(generator, self).__init__()
        self.fc1_1 = nn.Linear(100, 256)
        self.fc1_1_bn = nn.BatchNorm1d(256)
        self.fc1_2 = nn.Linear(10, 256)
        self.fc1_2_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(512, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc3_bn = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, 784)# weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)# forward method
    def forward(self, input, label):
        x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
        y = F.relu(self.fc1_2_bn(self.fc1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        x = F.tanh(self.fc4(x))return x
​
class discriminator(nn.Module):
    # initializers
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1_1 = nn.Linear(784, 1024)
        self.fc1_2 = nn.Linear(10, 1024)
        self.fc2 = nn.Linear(2048, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)# weight_init
    def weight_init(self, mean,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值