GAN生成手写数字识别
生成对抗网络介绍:
生成对抗网络就是两个网络进行相互对抗,相互进行的过程。GAN的主要部分为生成器Generator和判别器Discriminator。
生成器:输入一个随机向量,生成一个图片。希望生成的图片越像真的越好。
判别器:输出一个图片,判别图片的真伪。希望生成的图片判别为假,数据集中的图片判别为真。
判别器:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc1=nn.Linear(28*28,256)
self.lr1=nn.LeakyReLU(0.2)
self.fc2=nn.Linear(256,256)
self.lr2=nn.LeakyReLU(0.2)
self.fc3=nn.Linear(256,1)
self.sigmoid=nn.Sigmoid()
def forward(self,input):
x=input.view(-1,28*28)
x=self.fc1(x)
x=self.lr1(x)
x=self.fc2(x)
x=self.lr2(x)
x=self.fc3(x)
y=self.sigmoid(x)
return y
生成器:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1=nn.Linear(100,256)
self.relu1=nn.ReLU()
self.fc2=nn.Linear(256,256)
self.relu2=nn.ReLU()
self.fc3=nn.Linear(256,28*28)
self.tanh=nn.Tanh()
def forward(self,input):

该博客介绍了如何利用生成对抗网络(GAN)训练一个手写数字识别模型。文章详细阐述了GAN的组成,包括生成器和判别器的结构,并提供了PyTorch实现的代码示例。通过训练,生成器试图生成逼真的手写数字图像,而判别器则努力区分真实图像和生成的图像。最终,经过多个epoch的训练,模型能够生成接近MNIST数据集的手写数字图像。
最低0.47元/天 解锁文章
783

被折叠的 条评论
为什么被折叠?



