今天在编写自己的tensorflow-AutoEnocder网络是出现错误:
百度一下有大佬说是维度合并不当。我仔细查看了一下代码发现自己在tf.concat时sample的维度不对:
x_hat = tf.sigmoid(logits)
x_hat = tf.reshape(x_hat,[-1,28,28])
x_concat = tf.concat([sample,x_hat],axis=0)
x_concat = x_concat.numpy() * 255.0
x_concat =x_concat.astype(np.uint8)
print打印sample.shape发现其为[512,784],而x_hat经过网络输出和reshape后变为[512,28,28]两个维度不对无法合并,然后我将x_concat修改为
x_concat = tf.concat([tf.reshape(sample,[-1,28,28]),x_hat],axis=0)
然后网络就可以跑了。
可以参考大佬总结:https://blog.youkuaiyun.com/weixin_43543210/article/details/108230823