一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方说明
-
张量连接:https://www.tensorflow.org/api_docs/python/tf/concat
对输入的张量进行内容的连接,不涉及维度的改变
使用说明:https://blog.youkuaiyun.com/sdnuwjw/article/details/84960377 -
张量堆叠:https://www.tensorflow.org/api_docs/python/tf/stack
对输入的张量进行堆叠,维度增加 1 维
使用说明:https://blog.youkuaiyun.com/sdnuwjw/article/details/85049603
简单来说,两者的区别是:
tf.concat() 内容连接,维度不变
tf.stack() 张量堆叠,维度加一
三、实例
t1 = [[1,2,3],[4,5,6]]
t2 = [[7,8,9],[10,11,12]]
t3 = [[13,14,15],[16,17,18]]
concat0 = tf.concat(values=[t1,t2,t3],axis=0)
concat1 = tf.concat(values=[t1,t2,t3],axis=1)
concat2 = tf.concat(values=[t1,t2,t3],axis=2)
stack0 = tf.stack(values=[t1,t2,t3],axis=0)
stack1 = tf.stack(values=[t1,t2,t3],axis=1)
stack2 = tf.stack(values=[t1,t2,t3],axis=2)
with tf.Session() as sess:
print(sess.run(tf.shape(concat0)))
print(sess.run(concat0))
print("---------------------")
print(sess.run(tf.shape(concat1)))
print(sess.run(concat1))
print("---------------------")
print(sess.run(tf.shape(stack0)))
print(sess.run(stack0))
print("---------------------")
print(sess.run(tf.shape(stack1)))
print(sess.run(stack1))
print("---------------------")
print(sess.run(tf.shape(stack2)))
print(sess.run(stack2))
输出结果:
########################
concat0,shape: [6 3] (3个两行三列的按照第0维进行内容连接,得到6行3列,维度总数不变,还是二维)
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]
[13 14 15]
[16 17 18]]
########################
concat1,shape: [2 9] (3个两行三列的按照第1维进行内容连接,得到1行9列,维度总数不变,还是二维)
[[ 1 2 3 7 8 9 13 14 15]
[ 4 5 6 10 11 12 16 17 18]]
########################
stack0,shape: [3 2 3] (3个两行三列的按照第0维进行堆叠,维度总数变为3)
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 17 18]]]
########################
stack1,shape: [2 3 3] (3个两行三列的按照第1维进行堆叠,维度总数变为3)1
[[[ 1 2 3]
[ 7 8 9]
[13 14 15]]
[[ 4 5 6]
[10 11 12]
[16 17 18]]]
########################
stack2,shape: [2 3 3] (3个两行三列的按照第2维进行堆叠,维度总数变为3)
[[[ 1 7 13]
[ 2 8 14]
[ 3 9 15]]
[[ 4 10 16]
[ 5 11 17]
[ 6 12 18]]]
########################
细心的读者可以发现,这里没有 concat2,这是因为 concat 的连接维度不能等于或大于原张量的阶。如果超过,则会报错:
Traceback (most recent call last):
File "<stdin>", line 6, in <module>
NameError: name 'concat2' is not defined