pytorch gans

本文介绍生成对抗网络(GANs)的基本原理及其实现方法,并通过手写数字生成实例展示了GANs的强大能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

GANs

GANs(生成对抗网络),顾名思义,这个网络第一部分是生成网络,第二部分对抗模型严格来讲是一个判别器;简单来说,就是让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

可以用下图来简单的看一看这两个过程。

 

下面我们就来依次介绍。

Discriminator Network

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的label没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label是1表示真实的;而生成的假的图片的label是0表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

Generative Network

接着我们要看看如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是xw+b将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、池化、激活函数处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片,这个时候我们如何去训练这个生成器呢?就是通过判别器来得到结果,然后希望增大判别器判别这个结果为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

如下图所示

 

以上的过程已经简单的阐述了生成对抗网络的学习过程,如果仍然不太清楚这个过程,下面我们会通过代码来更清晰地展示整个过程。

 

Code

我们会使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,我们不使用卷积网络来做判别器,我们使用简单的多层网络来进行判别。

Discriminator Network


 
  1. class discriminator(nn.Module):

  2. def __init__(self):

  3. super(discriminator, self).__init__()

  4. self.dis = nn.Sequential(

  5. nn.Linear(784, 256),

  6. nn.LeakyReLU(0.2),

  7. nn.Linear(256, 256),

  8. nn.LeakyReLU(0.2),

  9. nn.Linear(256, 1),

  10. nn.Sigmoid()

  11. )

  12.  
  13. def forward(self, x):

  14. x = self.dis(x)

  15. return x

以上这个网络是一个简单的多层神经网络,将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。之所以使用LeakyRelu而不是用ReLU激活函数是因为经过实验LeakyReLU的表现更好。

Generative Network


 
  1. class generator(nn.Module):

  2. def __init__(self, input_size):

  3. super(generator, self).__init__()

  4. self.gen = nn.Sequential(

  5. nn.Linear(input_size, 256),

  6. nn.ReLU(True),

  7. nn.Linear(256, 256),

  8. nn.ReLU(True),

  9. nn.Linear(256, 784),

  10. nn.Tanh()

  11. )

  12.  
  13. def forward(self, x):

  14. x = self.gen(x)

  15. return x

输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,然后通过ReLU激活函数,接着进行一个线性变换,再经过一个ReLU激活函数,然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。

Discriminator Train

判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。

首先我们需要定义loss的度量方式和优化器,loss度量使用二分类的交叉熵,优化器注意使用的学习率是0.0003


 
  1. criterion = nn.BCELoss()

  2. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  3. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

接着进入训练


 
  1. img = img.view(num_img, -1) # 将图片展开乘28x28=784

  2. real_img = Variable(img).cuda() # 将tensor变成Variable放入计算图中

  3. real_label = Variable(torch.ones(num_img)).cuda() # 定义真实label为1

  4. fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的label为0

  5.  
  6. # compute loss of real_img

  7. real_out = D(real_img) # 将真实的图片放入判别器中

  8. d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss

  9. real_scores = real_out # 真实图片放入判别器输出越接近1越好

  10.  
  11. # compute loss of fake_img

  12. z = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成一些噪声

  13. fake_img = G(z) # 放入生成网络生成一张假的图片

  14. fake_out = D(fake_img) # 判别器判断假的图片

  15. d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的loss

  16. fake_scores = fake_out # 假的图片放入判别器越接近0越好

  17.  
  18. # bp and optimize

  19. d_loss = d_loss_real + d_loss_fake # 将真假图片的loss加起来

  20. d_optimizer.zero_grad() # 归0梯度

  21. d_loss.backward() # 反向传播

  22. d_optimizer.step() # 更新参数

我已经把每一步都注释在了代码上,这样更加便于大家阅读,这是一个判别器的训练过程,我们希望判别器能够正确辨别出真假图片。

Generative Train

在生成网络的训练中,我们希望生成一张假的图片,然后经过判别器之后希望他能够判断为真的图片,在这个过程中,我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过跟新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。


 
  1. # compute loss of fake_img

  2. z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声

  3. fake_img = G(z) # 生成假的图片

  4. output = D(fake_img) # 经过判别器得到结果

  5. g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss

  6.  
  7. # bp and optimize

  8. g_optimizer.zero_grad() # 归0梯度

  9. g_loss.backward() # 反向传播

  10. g_optimizer.step() # 更新生成网络的参数

这样我们就写好了一个简单的生成网络,通过不断地训练我们希望能够生成很真的图片。

Result

通过不断训练,我们可以得到下面的图片

这是真实图片

 

第1幅为第一次生成的噪声图片,之后分别是跑完15次生成的图片,跑完30次,跑完50次,跑完70次,最后一个是跑完100次生成的图片

 

 

怎么样,是不是特别神奇,我们居然可以生成一副看着很真的图片,这里我们只是用了简单的多层感知器来生成和判别模型,我们可以用更复杂的卷积神经网络来做同样的事情,代码将和本文的代码放在一起,有兴趣的同学可以自己去看看,然后放几张卷积网络生成的图片

 

 

可以发现产生的噪声更少了,训练也更加稳定,主要是里面引入了Batchnormalization,另外gan的训练过程是特别困难的,两个对偶网络相互学习,这个时候有一些训练技巧可以使得训练生成更加稳定。

 

最后我们来说一下为何Gans能够成为最近20年来机器学习以及深度学习界革命性的发现。这是因为不管是深度学习还是机器学习仍然很大一部分是监督学习,但是创建这么多有label的数据集所需要的人力物力是极大的,同时遇到的新的任务时我们很容易得到原始的没有label的数据集,这是我们需要花大量的时间去给其标定label,所以很多人都认为无监督学习才是机器学习的未来,这个时候Gans的出现为无监督学习提供了有力的支持,这当然引起了学界的大量关注,同时基于Gans的应用也越来越多,业界对其也非常狂热。

最后引用Yan Lecun的话:”它(Gans)为创建无监督学习模型提供了强有力的算法框架,有望帮助我们为 AI 加入常识(common sense)。我们认为,沿着这条路走下去,有不小的成功机会能开发出更智慧的 AI 。”

以上我们简单的介绍了Gans,通过网络实现了手写字体的生成,当然还有更多的变形和应用,有兴趣的同学可以自己阅读相关论文深入了解。

全部代码

简单网络(非卷积),训练快


 
  1. import torch

  2. import torchvision

  3. import torch.nn as nn

  4. import torch.nn.functional as F

  5. from torchvision import datasets

  6. from torchvision import transforms

  7. from torchvision.utils import save_image

  8. from torch.autograd import Variable

  9. import os

  10.  
  11. if not os.path.exists('./img'):

  12. os.mkdir('./img')

  13.  
  14.  
  15. def to_img(x):

  16. out = 0.5 * (x + 1)

  17. out = out.clamp(0, 1)

  18. out = out.view(-1, 1, 28, 28)

  19. return out

  20.  
  21.  
  22. batch_size = 128

  23. num_epoch = 100

  24. z_dimension = 100

  25.  
  26. # Image processing

  27. img_transform = transforms.Compose([

  28. transforms.ToTensor(),

  29. transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

  30. ])

  31. # MNIST dataset

  32. mnist = datasets.MNIST(

  33. root='./data/', train=True, transform=img_transform, download=True)

  34. # Data loader

  35. dataloader = torch.utils.data.DataLoader(

  36. dataset=mnist, batch_size=batch_size, shuffle=True)

  37.  
  38.  
  39. # Discriminator

  40. class discriminator(nn.Module):

  41. def __init__(self):

  42. super(discriminator, self).__init__()

  43. self.dis = nn.Sequential(

  44. nn.Linear(784, 256),

  45. nn.LeakyReLU(0.2),

  46. nn.Linear(256, 256),

  47. nn.LeakyReLU(0.2),

  48.             nn.Linear(256, 1),

  49.             nn.Sigmoid())

  50.  
  51. def forward(self, x):

  52. x = self.dis(x)

  53. return x

  54.  
  55.  
  56. # Generator

  57. class generator(nn.Module):

  58. def __init__(self):

  59. super(generator, self).__init__()

  60. self.gen = nn.Sequential(

  61. nn.Linear(100, 256),

  62. nn.ReLU(True),

  63. nn.Linear(256, 256),

  64.             nn.ReLU(True),

  65.             nn.Linear(256, 784),

  66.             nn.Tanh())

  67.  
  68. def forward(self, x):

  69. x = self.gen(x)

  70. return x

  71.  
  72.  
  73. D = discriminator()

  74. G = generator()

  75. if torch.cuda.is_available():

  76. D = D.cuda()

  77. G = G.cuda()

  78. # Binary cross entropy loss and optimizer

  79. criterion = nn.BCELoss()

  80. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  81. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

  82.  
  83. # Start training

  84. for epoch in range(num_epoch):

  85. for i, (img, _) in enumerate(dataloader):

  86. num_img = img.size(0)

  87. # =================train discriminator

  88. img = img.view(num_img, -1)

  89. real_img = Variable(img).cuda()

  90. real_label = Variable(torch.ones(num_img)).cuda()

  91. fake_label = Variable(torch.zeros(num_img)).cuda()

  92.  
  93. # compute loss of real_img

  94. real_out = D(real_img)

  95. d_loss_real = criterion(real_out, real_label)

  96. real_scores = real_out # closer to 1 means better

  97.  
  98. # compute loss of fake_img

  99. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  100. fake_img = G(z)

  101. fake_out = D(fake_img)

  102. d_loss_fake = criterion(fake_out, fake_label)

  103. fake_scores = fake_out # closer to 0 means better

  104.  
  105. # bp and optimize

  106. d_loss = d_loss_real + d_loss_fake

  107. d_optimizer.zero_grad()

  108. d_loss.backward()

  109. d_optimizer.step()

  110.  
  111. # ===============train generator

  112. # compute loss of fake_img

  113. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  114. fake_img = G(z)

  115. output = D(fake_img)

  116. g_loss = criterion(output, real_label)

  117.  
  118. # bp and optimize

  119. g_optimizer.zero_grad()

  120. g_loss.backward()

  121. g_optimizer.step()

  122.  
  123. if (i + 1) % 100 == 0:

  124. print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '

  125. 'D real: {:.6f}, D fake: {:.6f}'.format(

  126. epoch, num_epoch, d_loss.data[0], g_loss.data[0],

  127. real_scores.data.mean(), fake_scores.data.mean()))

  128. if epoch == 0:

  129. real_images = to_img(real_img.cpu().data)

  130. save_image(real_images, './img/real_images.png')

  131.  
  132. fake_images = to_img(fake_img.cpu().data)

  133. save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

  134.  
  135. torch.save(G.state_dict(), './generator.pth')

  136. torch.save(D.state_dict(), './discriminator.pth')

卷积网络版


 
  1. import torch

  2. import torch.nn as nn

  3. from torch.autograd import Variable

  4. from torch.utils.data import DataLoader

  5. from torchvision import transforms

  6. from torchvision import datasets

  7. from torchvision.utils import save_image

  8. import os

  9.  
  10. if not os.path.exists('./dc_img'):

  11. os.mkdir('./dc_img')

  12.  
  13.  
  14. def to_img(x):

  15. out = 0.5 * (x + 1)

  16. out = out.clamp(0, 1)

  17. out = out.view(-1, 1, 28, 28)

  18. return out

  19.  
  20.  
  21. batch_size = 128

  22. num_epoch = 100

  23. z_dimension = 100 # noise dimension

  24.  
  25. img_transform = transforms.Compose([

  26. transforms.ToTensor(),

  27. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

  28. ])

  29.  
  30. mnist = datasets.MNIST('./data', transform=img_transform)

  31. dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,

  32. num_workers=4)

  33.  
  34.  
  35. class discriminator(nn.Module):

  36. def __init__(self):

  37. super(discriminator, self).__init__()

  38. self.conv1 = nn.Sequential(

  39. nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28

  40. nn.LeakyReLU(0.2, True),

  41. nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14

  42. )

  43. self.conv2 = nn.Sequential(

  44. nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14

  45. nn.LeakyReLU(0.2, True),

  46. nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7

  47. )

  48. self.fc = nn.Sequential(

  49. nn.Linear(64*7*7, 1024),

  50. nn.LeakyReLU(0.2, True),

  51. nn.Linear(1024, 1),

  52. nn.Sigmoid()

  53. )

  54.  
  55. def forward(self, x):

  56. '''

  57. x: batch, width, height, channel=1

  58. '''

  59. x = self.conv1(x)

  60. x = self.conv2(x)

  61. x = x.view(x.size(0), -1)

  62. x = self.fc(x)

  63. return x

  64.  
  65.  
  66. class generator(nn.Module):

  67. def __init__(self, input_size, num_feature):

  68. super(generator, self).__init__()

  69. self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56

  70. self.br = nn.Sequential(

  71. nn.BatchNorm2d(1),

  72. nn.ReLU(True)

  73. )

  74. self.downsample1 = nn.Sequential(

  75. nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56

  76. nn.BatchNorm2d(50),

  77. nn.ReLU(True)

  78. )

  79. self.downsample2 = nn.Sequential(

  80. nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56

  81. nn.BatchNorm2d(25),

  82. nn.ReLU(True)

  83. )

  84. self.downsample3 = nn.Sequential(

  85. nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28

  86. nn.Tanh()

  87. )

  88.  
  89. def forward(self, x):

  90. x = self.fc(x)

  91. x = x.view(x.size(0), 1, 56, 56)

  92. x = self.br(x)

  93. x = self.downsample1(x)

  94. x = self.downsample2(x)

  95. x = self.downsample3(x)

  96. return x

  97.  
  98.  
  99. D = discriminator().cuda() # discriminator model

  100. G = generator(z_dimension, 3136).cuda() # generator model

  101.  
  102. criterion = nn.BCELoss() # binary cross entropy

  103.  
  104. d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)

  105. g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

  106.  
  107. # train

  108. for epoch in range(num_epoch):

  109. for i, (img, _) in enumerate(dataloader):

  110. num_img = img.size(0)

  111. # =================train discriminator

  112. real_img = Variable(img).cuda()

  113. real_label = Variable(torch.ones(num_img)).cuda()

  114. fake_label = Variable(torch.zeros(num_img)).cuda()

  115.  
  116. # compute loss of real_img

  117. real_out = D(real_img)

  118. d_loss_real = criterion(real_out, real_label)

  119. real_scores = real_out # closer to 1 means better

  120.  
  121. # compute loss of fake_img

  122. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  123. fake_img = G(z)

  124. fake_out = D(fake_img)

  125. d_loss_fake = criterion(fake_out, fake_label)

  126. fake_scores = fake_out # closer to 0 means better

  127.  
  128. # bp and optimize

  129. d_loss = d_loss_real + d_loss_fake

  130. d_optimizer.zero_grad()

  131. d_loss.backward()

  132. d_optimizer.step()

  133.  
  134. # ===============train generator

  135. # compute loss of fake_img

  136. z = Variable(torch.randn(num_img, z_dimension)).cuda()

  137. fake_img = G(z)

  138. output = D(fake_img)

  139. g_loss = criterion(output, real_label)

  140.  
  141. # bp and optimize

  142. g_optimizer.zero_grad()

  143. g_loss.backward()

  144. g_optimizer.step()

  145.  
  146. if (i+1) % 100 == 0:

  147. print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '

  148. 'D real: {:.6f}, D fake: {:.6f}'

  149. .format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],

  150. real_scores.data.mean(), fake_scores.data.mean()))

  151. if epoch == 0:

  152. real_images = to_img(real_img.cpu().data)

  153. save_image(real_images, './dc_img/real_images.png')

  154.  
  155. fake_images = to_img(fake_img.cpu().data)

  156. save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch+1))

  157.  
  158. torch.save(G.state_dict(), './generator.pth')

  159. torch.save(D.state_dict(), './discriminator.pth')

参考:

1.https://zhuanlan.zhihu.com/p/27386749

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值