Vae 变分自编码
由两部分网络构成 第一部分是编码网络提取物体特征,第二部分为解码网络生成物体
思路:控制分布
一般为判别网络+生成网络
判别为条件概率
生成为联合概率
编码网络 输入物体 输出方差和均值假设为正太分布
解码 输入正太分布 分布*方差+均值 在输出生成样本
loss 由两部分组成第一部分自身损失,第二部分控制分布损失
用kl距离公式优化loss
推荐文章https://www.jianshu.com/p/43318a3dc715
代码实现
编码网络
class EncoderNet:
def __init__(self):
# w的格式和b的格式
self.in_w=tf.Variable(tf.truncated_normal(shape=[784,100],stddev=0.1))
self.in_b=tf.Variable(tf.zeros([100]))
# 输入方差的权重和均值的权重
self.logvar_w=tf.Variable(tf.truncated_normal(shape=[100,128],stddev=0.1))
self.mean_w = tf.Variable(tf.truncated_normal(shape=[100,128],stddev=0.1))
def forward(self,x):
# 全连接
y = tf.nn.relu(tf.matmul(x,self.in_w)+self.in_b)
# 均值全连接(没定范围)
mean =tf.matmul(y,self.mean_w)
# 方差全连接(没定范围)
logvar = tf.matmul(y,self.logvar_w)
return mean,logvar
解码网络
# 解码网络
class DecoderNet:
def __init__(self):
self.in_w = tf.Variable(tf.truncated_normal(shape=[128,100],stddev=0.1))
self.in_b = tf.Variable(tf.zeros([100]))
self.out_w=tf.Variable(tf.truncated_normal(shape=[100,784],stddev=0.1))
def forward(self,x):
# 全连接
y = tf.nn.relu(tf.matmul(x,self.in_w)+self.in_b)
return tf.matmul(y,self.out_w)
主网络
class Net:
def __init__(self):
# 输入样本
self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28 * 28])
# 初始化编码器
self.encoderNet = EncoderNet()
# 初始化解码器
self.decoderNet = DecoderNet()
# 调用forward()
# 调用backward()
self.forward()
self.backward()
def forward(self):
# --------------------------------编码网络--------------------------------
# 输出均值和方差
self.mean, self.logVar = self.encoderNet.forward(self.x)
# 随机正态分布(NV)广播
normal_y = tf.random_normal(shape=[128])
# 方差
self.var = tf.exp(self.logVar)
# 标准差
std = tf.sqrt(self.var)
# 随机正态分布*标准差+均值
y = normal_y * std + self.mean
# --------------------------------解码网络--------------------------------
self.output = self.decoderNet.forward(y)
#随机传个正态分布进去
def decode(self):
normal_x=tf.random_normal(shape=[1,128])
return self.decoderNet.forward(normal_x)
def backward(self):
#样本-输出 再平方 再 求均值
# 输出loss
#kl loss
# 总loss
self.loss =kl_loss+output_loss
# 优化器
self.opt = tf.train.AdamOptimizer().minimize(self.loss)
训练
if __name__ == '__main__':
net =Net()
test_output = net.decode()
init = tf.global_variables_initializer()
with tf.Session() as sess:
# 初始化所有
sess.run(init)
for epoch in range(1000000):
xs, _ = mnist.train.next_batch(100)
_loss,_=sess.run([net.loss,net.opt],feed_dict={net.x:xs})
if epoch%100==0:
print(_loss)
test_img_data = sess.run([test_output])
test_img = np.reshape(test_img_data,[28,28])
plt.imshow(test_img)
plt.pause(0.1)
提供了部分代码及思路,根据不同的项目可以对编码解码网络处进行替换,比如图片用Cnn网络等灵活运用效果更佳