PyTorch 实现 DCGAN:在 CelebA 数据集上生成人脸

涉及的论文

GAN
https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
DCGAN
https://arxiv.org/pdf/1511.06434.pdf

测试用的数据集

Celeb-A Faces
数据集网站:
http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
下载链接:
百度网盘:https://pan.baidu.com/s/1eSNpdRG#list/path=%2F
谷歌网盘:https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg

下载数据集后,找到名为 img_align_celeba.zip 的文件。
创建文件夹 data,然后在 data 内创建文件夹 celeba。
将 img_align_celeba.zip 复制到 celeba 目录中,然后解压:

unzip img_align_celeba.zip

会生成这样的目录结构

./data/celeba/
		->img_align_celeba
			->188242.jpg
			->173822.jpg
			->284792.jpg
			...

这一步很重要,因为代码依赖这一目录结构(etc.py 中的 data_path 为 data/celeba)。

DCGAN 实现包含以下文件:

main.py
etc.py
graph.py
model.py
show.py
record.py
DCGAN_architecture.py
celeba_dataset.py

main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-01-25 14:07
# Modified date : 2019-01-27 22:36
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import celeba_dataset
from etc import config
from graph import NNGraph

def run():
    """Build the CelebA dataloader from ``config`` and train the DCGAN.

    ``celeba_dataset.get_dataloader`` builds the data pipeline and
    ``NNGraph.train`` runs the training loop; both live in this project's
    other modules.
    """
    dataloader = celeba_dataset.get_dataloader(config)
    g = NNGraph(dataloader, config)
    g.train()


# Only start training when executed as a script. The config asks for
# DataLoader worker processes (config["workers"]); presumably
# celeba_dataset.get_dataloader passes that as num_workers -- TODO confirm.
# On platforms that start workers with "spawn" (Windows, macOS on
# Python 3.8+), every worker re-imports the main module. Without this guard
# that import would call run() again while the worker is still
# bootstrapping, and PyTorch/multiprocessing aborts with a RuntimeError.
if __name__ == "__main__":
    run()

etc.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-24 17:02
# Modified date : 2019-01-28 23:37
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import torch

# Global training configuration; main.py passes this dict to the dataloader
# factory and to NNGraph.
config = {}

# ---- Dataset -------------------------------------------------------------
config["dataset"] = "celeba"
config["batch_size"] = 128
config["image_size"] = 64
config["num_epochs"] = 5
# Images are expected under data/<dataset>/, e.g. data/celeba/img_align_celeba/.
config["data_path"] = "data/%s" % config["dataset"]
config["workers"] = 2

# ---- Run control / checkpointing ----------------------------------------
config["print_every"] = 200
config["save_every"] = 500
config["manual_seed"] = 999
# NNGraph._load_train_model only restores a checkpoint when this is truthy.
config["train_load_check_point_file"] = False
# For different results on every run, use a random seed instead
# (requires ``import random`` at the top of this file):
# config["manual_seed"] = random.randint(1, 10000)

# ---- Model / optimisation -----------------------------------------------
config["number_channels"] = 3
config["size_of_z_latent"] = 100
config["number_gpus"] = 1
config["number_of_generator_feature"] = 64
config["number_of_discriminator_feature"] = 64
config["learn_rate"] = 0.0002
config["beta1"] = 0.5
# Labels are floats on purpose. Since PyTorch 1.7, torch.full() infers the
# tensor dtype from its fill value, so an int label yields a torch.long label
# tensor, which BCE-style criteria reject. (Assumes model.get_label builds its
# label tensor from these values -- TODO confirm.)
config["real_label"] = 1.0
config["fake_label"] = 0.0
# Fall back to CPU when CUDA is unavailable or no GPUs were requested.
config["device"] = torch.device(
    "cuda:0" if (torch.cuda.is_available() and config["number_gpus"] > 0) else "cpu"
)

graph.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-24 17:17
# Modified date : 2019-01-28 17:46
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import os
import time
import torch
import torchvision.utils as vutils
import model
import show
import record

class NNGraph(object):
    def __init__(self, dataloader, config):
        """Build (and optionally restore) the train model for ``dataloader``.

        Args:
            dataloader: The training data source, stored as ``self.dataloader``.
            config: The caller's configuration dict. After construction,
                ``self.config`` is replaced by the ``"config"`` entry of the
                train-model dict returned by ``model.init_train_model`` /
                ``model.load_model_dict``.
        """
        super(NNGraph, self).__init__()
        # _get_train_model -> _load_train_model reads self.config, so the
        # caller's config must be in place before the model is built.
        self.config = config
        self.dataloader = dataloader
        built = self._get_train_model(config)
        self.train_model = built
        # Pass both the caller's config and the train model's config to
        # record.record_dict, then adopt the train model's config.
        record.record_dict(self.config, built["config"])
        self.config = built["config"]

    def _get_train_model(self, config):
        """Create a fresh train-model dict and pass it through checkpoint loading.

        ``model.init_train_model(config)`` builds the initial dict. It is then
        handed to ``self._load_train_model``, whose return value (possibly a
        checkpoint-restored dict) is returned.
        """
        fresh = model.init_train_model(config)
        return self._load_train_model(fresh)

    def _save_train_model(self):
        """Write the current train model to the checkpoint path via ``torch.save``.

        The saved object is whatever ``model.get_model_dict`` extracts from
        ``self.train_model``; the destination comes from
        ``record.get_check_point_file_full_path(self.config)``.
        """
        # Arguments are evaluated left to right, so the model dict is built
        # before the checkpoint path is resolved.
        torch.save(
            model.get_model_dict(self.train_model),
            record.get_check_point_file_full_path(self.config),
        )

    def _load_train_model(self, train_model):
        file_full_path = record.get_check_point_file_full_path(self.config)
        if os.path.exists(file_full_path) and self.config["train_load_check_point_file"]:
            checkpoint = torch.load(file_full_path)
            train_model = model.load_model_dict(train_model, checkpoint)
        return train_model

    def _train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: A batch from the dataloader. Only ``data[0]`` (the image
                tensor) is used; it is moved to ``self.config["device"]``.

        Returns:
            ``(errD, errG, D_x, D_G_z1, D_G_z2)`` exactly as produced by
            ``model.get_Discriminator_loss`` and ``model.get_Generator_loss``.
        """
        tm = self.train_model
        generator, gen_opt = tm["netG"], tm["optimizerG"]
        discriminator, disc_opt = tm["netD"], tm["optimizerD"]
        loss_fn = tm["criterion"]
        cfg = self.config
        dev = cfg["device"]

        real_batch = data[0].to(dev)

        # get_noise receives the real batch (presumably to size the latent
        # batch -- TODO confirm in model.py); the generator maps it to images.
        latent = model.get_noise(real_batch, cfg)
        generated = generator(latent)
        targets = model.get_label(real_batch, cfg)

        # The discriminator step gets detached fakes so its backward pass does
        # not reach the generator; the generator step then uses the attached
        # fakes so gradients flow back into netG.
        errD, D_x, D_G_z1 = model.get_Discriminator_loss(
            discriminator, disc_opt, real_batch, generated.detach(),
            targets, loss_fn, cfg)
        errG, D_G_z2 = model.get_Generator_loss(
            generator, discriminator, gen_opt, generated,
            targets, loss_fn, cfg)

        return errD, errG, D_x, D_G_z1, D_G_z2

    def _train_a_step(self, data, i, epoch):
        start = time.time()
        errD, errG, D_x, D_G_z1, D_G_z2 = self._train_step(data)
        end = time.time()
        step_time 
(注:原文代码在此处被截断 —— graph.py 的其余部分,以及 model.py、show.py、record.py、DCGAN_architecture.py、celeba_dataset.py 均未包含在本页中。)