涉及的论文
GAN
https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
DCGAN
https://arxiv.org/pdf/1511.06434.pdf
测试用的数据集
Celeb-A Faces
数据集网站:
http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
下载链接:
百度网盘: https://pan.baidu.com/s/1eSNpdRG#list/path=%2F
谷歌网盘(Google Drive): https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
数据集下载后,找到名为 img_align_celeba.zip 的文件。
在运行 main.py 的当前工作目录下创建文件夹 data,再在 data 内创建文件夹 celeba(代码中的数据路径是相对路径 data/celeba)。
将 img_align_celeba.zip 拷贝到 celeba 文件夹中,然后解压:
unzip img_align_celeba.zip
解压后会生成如下目录结构:
./data/celeba/
    -> img_align_celeba/
        -> 188242.jpg
        -> 173822.jpg
        -> 284792.jpg
        ...
这一步很重要,因为代码按照上述目录结构读取数据。
实现 DCGAN 的代码包含以下文件:
main.py
etc.py
graph.py
model.py
show.py
record.py
DCGAN_architecture.py
celeba_dataset.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-01-25 14:07
# Modified date : 2019-01-27 22:36
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function
import celeba_dataset
from etc import config
from graph import NNGraph
def run():
    """Build the CelebA dataloader, construct the training graph and train.

    Uses the module-level ``config`` dict imported from ``etc``.
    """
    dataloader = celeba_dataset.get_dataloader(config)
    graph = NNGraph(dataloader, config)
    graph.train()


# Guard the entry point so importing this module does not start training.
# This also matters for data loading: etc.py sets config["workers"] = 2, and
# get_dataloader presumably passes that to DataLoader as num_workers
# (TODO confirm in celeba_dataset.py). Under the "spawn" start method
# (default on Windows and macOS) each worker re-imports the main module, and
# an unguarded run() would try to start training again inside every worker.
if __name__ == "__main__":
    run()
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-24 17:02
# Modified date : 2019-01-28 23:37
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function
import torch
# Training configuration for the DCGAN example. Keys are inserted in a fixed
# order; keep it stable in case a consumer (e.g. a recorder) relies on it.
config = {}
config["dataset"] = "celeba"
config["batch_size"] = 128
config["image_size"] = 64
config["num_epochs"] = 5
# Relative path, resolved against the current working directory.
config["data_path"] = "data/%s" % config["dataset"]
config["workers"] = 2
config["print_every"] = 200
config["save_every"] = 500
# Fixed seed for reproducible runs. To get new results on every run, use
# instead:
#     import random
#     config["manual_seed"] = random.randint(1, 10000)
config["manual_seed"] = 999
config["train_load_check_point_file"] = False
config["number_channels"] = 3
config["size_of_z_latent"] = 100
config["number_gpus"] = 1
config["number_of_generator_feature"] = 64
config["number_of_discriminator_feature"] = 64
config["learn_rate"] = 0.0002
config["beta1"] = 0.5
# Label values for real/fake samples, stored as floats. If a label tensor is
# built from them with torch.full(...), an int fill value yields an int64
# tensor in recent PyTorch, and BCE-style criteria require floating-point
# targets. 1.0 == 1 and 0.0 == 0, so numeric comparisons are unchanged.
config["real_label"] = 1.0
config["fake_label"] = 0.0
# Use the first CUDA device when available and GPUs are requested.
config["device"] = torch.device(
    "cuda:0"
    if (torch.cuda.is_available() and config["number_gpus"] > 0)
    else "cpu"
)
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-24 17:17
# Modified date : 2019-01-28 17:46
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function
import os
import time
import torch
import torchvision.utils as vutils
import model
import show
import record
class NNGraph(object):
def __init__(self, dataloader, config):
    """Set up the GAN training graph.

    Args:
        dataloader: source of training batches; stored as-is.
        config: configuration dict. After the train model is built (and
            possibly restored from a checkpoint), ``self.config`` is
            replaced by the config stored inside the train model.
    """
    super(NNGraph, self).__init__()
    self.dataloader = dataloader
    # Must be assigned before _get_train_model(): _load_train_model() reads
    # self.config to locate the checkpoint and decide whether to load it.
    self.config = config
    self.train_model = self._get_train_model(config)
    model_config = self.train_model["config"]
    record.record_dict(self.config, model_config)
    self.config = model_config
def _get_train_model(self, config):
    """Initialise a train-model dict and optionally overlay a checkpoint.

    Returns the result of passing ``model.init_train_model(config)``
    through ``self._load_train_model``.
    """
    return self._load_train_model(model.init_train_model(config))
def _save_train_model(self):
    """Serialize the current training state to the checkpoint file.

    The state comes from ``model.get_model_dict`` and is written with
    ``torch.save`` to the path given by
    ``record.get_check_point_file_full_path``.
    """
    state = model.get_model_dict(self.train_model)
    torch.save(state, record.get_check_point_file_full_path(self.config))
def _load_train_model(self, train_model):
    """Optionally restore training state from the checkpoint file.

    The checkpoint is loaded only when the file exists AND
    ``self.config["train_load_check_point_file"]`` is truthy; otherwise
    ``train_model`` is returned unchanged.

    Args:
        train_model: the freshly initialised train-model object (as passed
            by ``_get_train_model``).

    Returns:
        ``train_model`` itself, or the result of ``model.load_model_dict``
        when a checkpoint was loaded.
    """
    file_full_path = record.get_check_point_file_full_path(self.config)
    if not (os.path.exists(file_full_path)
            and self.config["train_load_check_point_file"]):
        return train_model
    # Remap stored tensors onto the configured device, so a checkpoint
    # written on a GPU can be resumed on a CPU-only machine (and vice versa).
    # Without a "device" entry this is map_location=None, the torch.load
    # default used previously.
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # this program wrote itself.
    checkpoint = torch.load(file_full_path,
                            map_location=self.config.get("device"))
    return model.load_model_dict(train_model, checkpoint)
def _train_step(self, data):
    """Run one discriminator update followed by one generator update.

    Args:
        data: one batch from the dataloader; only ``data[0]`` is used
            (presumably the image tensor).

    Returns:
        Tuple ``(errD, errG, D_x, D_G_z1, D_G_z2)`` as produced by
        ``model.get_Discriminator_loss`` and ``model.get_Generator_loss``.
    """
    parts = self.train_model
    cfg = self.config
    generator, gen_optimizer = parts["netG"], parts["optimizerG"]
    discriminator, disc_optimizer = parts["netD"], parts["optimizerD"]
    loss_fn = parts["criterion"]

    device = cfg["device"]
    real_batch = data[0].to(device)
    # Keep this call order: noise -> fake batch -> labels -> D loss -> G loss.
    fake_batch = generator(model.get_noise(real_batch, cfg))
    targets = model.get_label(real_batch, cfg)

    # The discriminator receives a detached fake batch, so gradients of its
    # loss do not flow back into the generator through that tensor.
    errD, D_x, D_G_z1 = model.get_Discriminator_loss(
        discriminator, disc_optimizer, real_batch, fake_batch.detach(),
        targets, loss_fn, cfg)
    # NOTE(review): the same `targets` object is handed to both loss helpers;
    # confirm each one sets the label values it needs rather than relying on
    # in-place changes made by the other.
    errG, D_G_z2 = model.get_Generator_loss(
        generator, discriminator, gen_optimizer, fake_batch, targets,
        loss_fn, cfg)
    return errD, errG, D_x, D_G_z1, D_G_z2
def _train_a_step(self, data, i, epoch):
start = time.time()
errD, errG, D_x, D_G_z1, D_G_z2 = self._train_step(data)
end = time.time()
step_time