tf.concat()
作用:沿着某一维度合并张量
举例:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
# 当axis = 0时
print(tf.concat([t1, t2], 0))
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
# 当axis = 1时
print(tf.concat([t1, t2], 1))
# [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
在举个例子:
# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0))
# [4, 3]
tf.shape(tf.concat([t3, t4], 1))
# [2, 6]
可以看出来,按照axis = 0 时得到的张量shape就是将第一个位置上的两个2相加,后者不变。axis = 1 时得到的张量的shape就是将第二个位置上的两个3相加,前者不变。
更加普遍的有:
Values[i].shape=[D0,D1,...Daxis(i),...Dn]Values[i].shape = [D_0, D_1, ... D_{axis(i)}, ...D_n]Values[i].shape=[D0,D1,...Daxis(i),...Dn]
得到的结果的
shape=[D0,D1,...Raxis,...Dn]shape = [D_0, D_1, ... R_{axis}, ...D_n]shape=[D0,D1,...Raxis,...Dn]
其中
Raxis=sum(Daxis(i))R_{axis} = sum(D_{axis(i)})Raxis=sum(Daxis(i))