import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
print("TensorFlow Version: {}".format(tf.__version__))
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/')
# ## Train
# 定义参数
batch_size = 64
noise_size = 100
epochs = 5
n_samples = 25
learning_rate = 0.001
beta1 = 0.4
class DCGAN():
@staticmethod
def get_inputs(noise_dim, image_height, image_width, image_depth):
inputs_real = tf.placeholder(tf.float32, [None, image_height, image_width, image_depth], name='inputs_real')
inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')
return inputs_real, inputs_noise
def get_generator(self, noise_img, output_dim, is_train=True, alpha=0.01):
"""
@Author: Nelson Zhao
--------------------
:param noise_img: 噪声信号,tensor类型
:param output_dim: 生成图片的depth
:param is_train: 是否为训练状态,该参数主要用于作为batch_normalization方法中的参数使用
:param alpha: Leaky ReLU系数
"""
with tf.variable_scope("generator") as scope0:
if not is_train:
scope0.reuse_variables()
# none*100 to none*4 x 4 x 512
# 全连接层
layer1 = tf.layers.dense(noise_img, 4*4*512)
layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
# batch normalization
layer1 = tf.layers.batch_normalization(layer1, training=is_train)
# Leaky ReLU
layer1 = tf.maximum(alpha * layer1, layer1)
# dropout
layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
tensorflow学习——DCGAN手写体生成
最新推荐文章于 2022-11-05 21:51:14 发布
本文详细介绍了使用TensorFlow实现Deep Convolutional Generative Adversarial Networks (DCGAN)进行手写体生成的过程。通过训练,DCGAN能够学习到手写数字的特征,并生成新的、逼真的手写数字图像。内容包括DCGAN的模型结构、训练策略以及生成结果的展示。

最低0.47元/天 解锁文章
5521

被折叠的 条评论
为什么被折叠?



