Vae总结

Vae 变分自编码

由两部分网络构成  第一部分是编码网络提取物体特征,第二部分为解码网络生成物体

思路:控制分布

一般为判别网络+生成网络

判别为条件概率

生成为联合概率

编码网络  输入物体  输出方差和均值假设为正太分布

解码  输入正太分布  分布*方差+均值  在输出生成样本

loss 由两部分组成第一部分自身损失,第二部分控制分布损失

用kl距离公式优化loss

推荐文章https://www.jianshu.com/p/43318a3dc715

代码实现

编码网络

class EncoderNet:
    def __init__(self):
        # w的格式和b的格式
        self.in_w=tf.Variable(tf.truncated_normal(shape=[784,100],stddev=0.1))
        self.in_b=tf.Variable(tf.zeros([100]))
        # 输入方差的权重和均值的权重
        self.logvar_w=tf.Variable(tf.truncated_normal(shape=[100,128],stddev=0.1))
        self.mean_w = tf.Variable(tf.truncated_normal(shape=[100,128],stddev=0.1))
    def forward(self,x):
        # 全连接
        y = tf.nn.relu(tf.matmul(x,self.in_w)+self.in_b)
        # 均值全连接(没定范围)
        mean =tf.matmul(y,self.mean_w)
        # 方差全连接(没定范围)
        logvar = tf.matmul(y,self.logvar_w)
        return mean,logvar

解码网络

# 解码网络
class DecoderNet:
    def __init__(self):
        self.in_w = tf.Variable(tf.truncated_normal(shape=[128,100],stddev=0.1))
        self.in_b = tf.Variable(tf.zeros([100]))
        self.out_w=tf.Variable(tf.truncated_normal(shape=[100,784],stddev=0.1))
    def forward(self,x):
        # 全连接
        y = tf.nn.relu(tf.matmul(x,self.in_w)+self.in_b)
        return tf.matmul(y,self.out_w)

 主网络

class Net:
    def __init__(self):
        # 输入样本
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28 * 28])
        # 初始化编码器
        self.encoderNet = EncoderNet()
        # 初始化解码器
        self.decoderNet = DecoderNet()
        # 调用forward()
        # 调用backward()
        self.forward()
        self.backward()
    def forward(self):
        # --------------------------------编码网络--------------------------------
        # 输出均值和方差
        self.mean, self.logVar = self.encoderNet.forward(self.x)
        # 随机正态分布(NV)广播
        normal_y = tf.random_normal(shape=[128])
        # 方差
        self.var = tf.exp(self.logVar)
        # 标准差
        std = tf.sqrt(self.var)
        # 随机正态分布*标准差+均值
        y = normal_y * std + self.mean
        # --------------------------------解码网络--------------------------------
        self.output = self.decoderNet.forward(y)
        #随机传个正态分布进去
    def decode(self):
        normal_x=tf.random_normal(shape=[1,128])
        return self.decoderNet.forward(normal_x)
    def backward(self):
        #样本-输出 再平方 再 求均值
 
        # 输出loss
  
        #kl loss
        # 总loss
        self.loss =kl_loss+output_loss
        # 优化器
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)

训练

if __name__ == '__main__':
    net =Net()
    test_output = net.decode()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # 初始化所有
        sess.run(init)
        for epoch in range(1000000):
            xs, _ = mnist.train.next_batch(100)
            _loss,_=sess.run([net.loss,net.opt],feed_dict={net.x:xs})
            if epoch%100==0:
                print(_loss)
                test_img_data = sess.run([test_output])
                test_img = np.reshape(test_img_data,[28,28])
                plt.imshow(test_img)
                plt.pause(0.1)

 提供了部分代码及思路,根据不同的项目可以对编码解码网络处进行替换,比如图片用Cnn网络等灵活运用效果更佳

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值