实例83:使用标签指导变分自编码网络生成MNIST数据

本文介绍了如何使用条件变分自编码器(CVAE)结合MNIST数据集,通过指定标签来生成对应类型的模拟手写数字图像。在编码阶段,标签被转换为特征并添加到输入;解码阶段则将标签与随机噪声结合,生成新的图像。经过训练,模型能够学习数据分布,并创造出与标签相符的新图像。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在编码阶段需要在输入端添加标签对应的特征,在解码阶段同样也需要将标签加入输入,这样,再解码的结果向原始的输入样本不断逼近,最终得到的模型将会把输入的标签特征当成MNIST数据的一部分,从而实现通过标签生成MNIST数据。
在输入端添加标签时,一般是通过一个全连接层的变换将得到的结果使用contact函数连接到原始输入的地方,在解码阶段也将标签作为样本数据,与高斯分布的随机值一并运算,生成模拟样本。

实例描述

使用条件变分自编码模型,通过指定标签输入生成对应类型的MNIST模拟数据

1.添加标签占位符

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_labels])

2.添加输入全连接权重

添加全连接层的权重‘wlabl’和‘blabl’,作为输入标签和特征转换。这里也需要将输入的标签也转换成256个维度的输出。因为最终也需要连接到原始的图片全连接输出中,所以到第二层全连接时,输入需要编程256×2,因此也需要将mean_w1和log_sigma_w1的输入修改成n_hidden_1×2.

weights = {

    'w1': tf.Variable(tf.truncated_normal([n_input, n_hidden_1],
                                   stddev=0.001)),
    'b1': tf.Variable(tf.zeros([n_hidden_1])),

    'wlab1': tf.Variable(tf.truncated_normal([n_labels, n_hidden_1],
                                   stddev=0.001)),
    'blab1': tf.Variable(tf.zeros([n_hidden_1])),

    'mean_w1': tf.Variable(tf.truncated_normal([n_hidden_1*2, n_hidden_2],
                                   stddev=0.001)),
    'log_sigma_w1': tf.Variable(tf.truncated_normal([n_hidden_1*2, n_hidden_2],
                                   stddev=0.001)),    
    
    
    'w2': tf.Variable(tf.truncated_normal([n_hidden_2+n_labels, n_hidden_1],
                                   stddev=0.001)),

    'b2': tf.Variable(tf.zeros([n_hidden_1])),
    'w3': tf.Variable(tf.truncated_normal([n_hidden_1, n_input],
                                   stddev=0.001)),

    'b3': tf.Variable(tf.zeros([n_input])),

    'mean_b1': tf.Variable(tf.zeros([n_hidden_2])),

    'log_sigma_b1': tf.Variable(tf.zeros([n_hidden_2]))
}

同样在解码器的生成时,要讲z和label连接起来输入解码器中,所以w2的输入维度需要改成n_hidden_2+n_labels

3.修改模型,将标签输出介入编码

#将h1输出和标签输出concat一起
h1=tf.nn.relu(tf.add(tf.matmul(x, weights['w1']), weights['b1']))

hlab1=tf.nn.relu(tf.add(tf.matmul(y, weights['wlab1']), weights['blab1']))

hall1= tf.concat([h1,hlab1],1)#256*2
#接着生成对应的mean和log_sigma
z_mean = tf.add(tf.matmul(hall1, weights['mean_w1']), weights['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1, weights['log_sigma_w1']), weights['log_sigma_b1'])

4.修改模型将标签接入解码

在解码过程中需要注意的是,需要将z和y(标签)一起concat作为输入来进行解码,修改reconstruction和reconstructionout节点。

zall=tf.concat([z,y],1)
h2=tf.nn.relu( tf.matmul(zall, weights['w2'])+ weights['b2'])
reconstruction = tf.matmul(h2, weights['w3'])+ weights['b3']

zinputall = tf.concat([zinput,y],1)
h2out=tf.nn.relu( tf.matmul(zinputall, weights['w2'])+ weights['b2'])
reconstructionout = tf.matmul(h2out, weights['w3'])+ weights['b3']

5. 修改session中的feed部分


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)

        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据
            
    
            # Fit training using batch data
            _,c = sess.run([optimizer,cost], feed_dict={x: batch_xs,y:batch_ys})
            #c = autoencoder.partial_fit(batch_xs)
        # 显示训练中的详细信息
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c))

6.运行模型生成模拟数据

  # 根据图片模拟生成图片
    show_num = 10
    pred = sess.run(
        reconstruction, feed_dict={x: mnist.test.images[:show_num],y: mnist.test.labels[:show_num]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        a[1][i].imshow(np.reshape(pred[i], (28, 28)))
    plt.draw()

在这里插入图片描述
根据原始图片生成自编码数据。

  # 根据label模拟生产图片可视化结果
    show_num = 10
    z_sample = np.random.randn(10,2)
    
    pred = sess.run(
        reconstructionout, feed_dict={zinput:z_sample,y: mnist.test.labels[:show_num]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        a[1][i].imshow(np.reshape(pred[i], (28, 28)))
    plt.draw()    
    

在这里插入图片描述
上图则是根据label生成的自编码数据。
比较两幅图片可以看书,使用原图生成的自编码数据还有一些原来的样子,而以标签生成的解码数据,则彻底学习数据的分布,并生成截然不同的数据

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值