CGAN的超简单实现,基于pytorch 0.4。
刚开始搭建了一个原始GAN网络,没多久就遇到模型崩溃的问题,生成的样本丰富性很少,所以索性直接改成CGAN ,整个原理还是很简单的,改起来很快,主要是参数调整真的让人头大。
GAN 训练了35个epoch的效果,几乎只生成3和5的样本。
#CGAN训练效果,第8个epoch , 可以看到生成样本丰富性很高,而且质量很不错。
代码
代码有点点乱,将就能用就行~
#CGANnets
import torch
import torch.nn as nn
import torch.functional as F
#变成CGAN 在fc层嵌入 one-ho编码
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1,32,5),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d(2,stride = 2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5,padding=2),
nn.LeakyReLU(0.2, True),
nn.MaxPool2d(2, stride=2),
)
self.fc = nn.Sequential(
nn.Linear(64*6*6+10,1024),
nn.LeakyReLU(0.2,True),
nn.Linear(1024,1),
nn.Sigmoid()
)
def forward(self, x,labels):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0),-1)
x = torch.cat((x,labels),1)
x = self.fc(x)
return x
class generator(nn.Module):
def __init__(self, input_size, num_feature):
super(generator, self).__init__()
self.fc = nn.Linear(input_size+10, num_feature) # batch, 3136=1x56x56
self.br = nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.downsample1 = nn.Sequential(
nn.Conv2d(1,50,3,stride=1,padding=1),
nn.BatchNorm2d(50),
nn.ReLU(True)
)
self.downsample2 = nn.Sequential(
nn.Conv2d(50,25,3,stride=1,padding=1),
nn.BatchNorm2d(25),
nn.ReLU(True)
)
self.downsample3 = nn.Sequential(
nn.Conv2d(25,1,2,stride = 2),
nn.Tanh()
)
def forward(self,z,labels):
'''
:param x: (batchsize,100)的随机噪声
:param label: (batchsize,10) 的one-hot 标签编码
:return:
'''
x = torch.cat((z,labels),1) #沿1维拼接
x = self.fc(x)
x = x.view(x.size(0),1,56,56)
x = self.br(x)
x = self.downsample1(x)
x = self.downsample2(x)
x = self.downsample3(x)
return x
##train.py
import torc