在编码阶段需要在输入端添加标签对应的特征,在解码阶段同样也需要将标签加入输入,这样,再解码的结果向原始的输入样本不断逼近,最终得到的模型将会把输入的标签特征当成MNIST数据的一部分,从而实现通过标签生成MNIST数据。
在输入端添加标签时,一般是通过一个全连接层的变换将得到的结果使用contact函数连接到原始输入的地方,在解码阶段也将标签作为样本数据,与高斯分布的随机值一并运算,生成模拟样本。
实例描述
使用条件变分自编码模型,通过指定标签输入生成对应类型的MNIST模拟数据
1.添加标签占位符
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_labels])
2.添加输入全连接权重
添加全连接层的权重‘wlabl’和‘blabl’,作为输入标签和特征转换。这里也需要将输入的标签也转换成256个维度的输出。因为最终也需要连接到原始的图片全连接输出中,所以到第二层全连接时,输入需要编程256×2,因此也需要将mean_w1和log_sigma_w1的输入修改成n_hidden_1×2.
weights = {
'w1': tf.Variable(tf.truncated_normal([n_input, n_hidden_1],
stddev=0.001)),
'b1': tf.Variable(tf.zeros([n_hidden_1])),
'wlab1': tf.Variable(tf.truncated_normal([n_labels, n_hidden_1],
stddev=0.001)),
'blab1': tf.Variable(tf.zeros([n_hidden_1])),
'mean_w1': tf.Variable(tf.truncated_normal([n_hidden_1*2, n_hidden_2],
stddev=0.001)),
'log_sigma_w1': tf.Variable(tf.truncated_normal([n_hidden_1*2, n_hidden_2],
stddev=0.001)),
'w2': tf.Variable(tf.truncated_normal([n_hidden_2+n_labels, n_hidden_1],
stddev=0.001)),
'b2': tf.Variable(tf.zeros([n_hidden_1])),
'w3': tf.Variable(tf.truncated_normal([n_hidden_1, n_input],
stddev=0.001)),
'b3': tf.Variable(tf.zeros([n_input])),
'mean_b1': tf.Variable(tf.zeros([n_hidden_2])),
'log_sigma_b1': tf.Variable(tf.zeros([n_hidden_2]))
}
同样在解码器的生成时,要讲z和label连接起来输入解码器中,所以w2的输入维度需要改成n_hidden_2+n_labels
3.修改模型,将标签输出介入编码
#将h1输出和标签输出concat一起
h1=tf.nn.relu(tf.add(tf.matmul(x, weights['w1']), weights['b1']))
hlab1=tf.nn.relu(tf.add(tf.matmul(y, weights['wlab1']), weights['blab1']))
hall1= tf.concat([h1,hlab1],1)#256*2
#接着生成对应的mean和log_sigma
z_mean = tf.add(tf.matmul(hall1, weights['mean_w1']), weights['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(hall1, weights['log_sigma_w1']), weights['log_sigma_b1'])
4.修改模型将标签接入解码
在解码过程中需要注意的是,需要将z和y(标签)一起concat作为输入来进行解码,修改reconstruction和reconstructionout节点。
zall=tf.concat([z,y],1)
h2=tf.nn.relu( tf.matmul(zall, weights['w2'])+ weights['b2'])
reconstruction = tf.matmul(h2, weights['w3'])+ weights['b3']
zinputall = tf.concat([zinput,y],1)
h2out=tf.nn.relu( tf.matmul(zinputall, weights['w2'])+ weights['b2'])
reconstructionout = tf.matmul(h2out, weights['w3'])+ weights['b3']
5. 修改session中的feed部分
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
# 遍历全部数据集
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据
# Fit training using batch data
_,c = sess.run([optimizer,cost], feed_dict={x: batch_xs,y:batch_ys})
#c = autoencoder.partial_fit(batch_xs)
# 显示训练中的详细信息
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c))
6.运行模型生成模拟数据
# 根据图片模拟生成图片
show_num = 10
pred = sess.run(
reconstruction, feed_dict={x: mnist.test.images[:show_num],y: mnist.test.labels[:show_num]})
f, a = plt.subplots(2, 10, figsize=(10, 2))
for i in range(show_num):
a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
a[1][i].imshow(np.reshape(pred[i], (28, 28)))
plt.draw()
根据原始图片生成自编码数据。
# 根据label模拟生产图片可视化结果
show_num = 10
z_sample = np.random.randn(10,2)
pred = sess.run(
reconstructionout, feed_dict={zinput:z_sample,y: mnist.test.labels[:show_num]})
f, a = plt.subplots(2, 10, figsize=(10, 2))
for i in range(show_num):
a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
a[1][i].imshow(np.reshape(pred[i], (28, 28)))
plt.draw()
上图则是根据label生成的自编码数据。
比较两幅图片可以看书,使用原图生成的自编码数据还有一些原来的样子,而以标签生成的解码数据,则彻底学习数据的分布,并生成截然不同的数据