项目及代码来源:
在看代码前可以简单理解vae的基本概念上,推荐一篇知乎文章:https://zhuanlan.zhihu.com/p/55557709
还有https://yuanxiaosc.github.io/2018/08/26/%E5%8F%98%E5%88%86%E8%87%AA%E7%BC%96%E7%A0%81%E5%99%A8/
尤其第二篇,我就是在第二篇中理解通透了VAE的具体流程,在理解过程中我发现了模型中的几个小trick,我在注释中详细标注出来了并加上了自己的理解。方便之前没接触过vae的同志们直接上手,从程序中理解。
中文注释均为我个人的理解,对vae新手非常友好,看了注释后应该会对vae有了更深刻的认识,如果代码注释中有问题或错误请大家指出。
以下为一个mnist手写体识别图片的生成项目代码,可直接运行,并且最后展示了多个均匀分布的z解压后生成的图片间的关系即变化流程。
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST/", one_hot=True)
# Parameters
learning_rate = 0.001
num_steps = 3000 #迭代次数 30000
batch_size = 64
print(num_steps)
# Network Parameters
image_dim = 784 # MNIST images are 28x28 pixels
hidden_dim = 512
latent_dim = 2#潜在向量的维度
# A custom initialization (see Xavier Glorot init)
def glorot_init(shape):
return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))
#当我们在初始化网络的权重时,需要设置一个合理的随机值,避免出