TensorFlow基础:tf.concat()

本文详细介绍了TensorFlow中tf.concat函数的使用方法,包括参数解释、常见错误及解决办法,通过一维、二维和三维数组示例,阐述了如何正确地在不同维度上连接数组。
部署运行你感兴趣的模型镜像

tf.concat函数: 主要用于连接两个数组
参数:
values:需要连接的数组
axis:从哪个维度来连接数组

一维数组
import tensorflow as tf

if __name__ == "__main__":
    a = [1,2,3]
    b = [4,5,6]
    c = tf.concat([a,b],0)
    sess = tf.InteractiveSession()
    print(sess.run(c)) #[1 2 3 4 5 6]

注意:axis参数不能超过数组的维度。如果超过数组的维度,如下:

c = tf.concat([a,b],1)

则会报,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat',意思是数组至少是二维,axis才能为1

二维数组

 a = [[1,1],[2,2],[3,3]]
    b = [[4,4],[5,5],[6,6]]
    c = tf.concat([a,b],0)
    print(sess.run(c))
    """
    [[1 1]
     [2 2]
     [3 3]
     [4 4]
     [5 5]
     [6 6]]
    """
    c = tf.concat([a,b],1) #等价于tf.concat([a,b],-1)
    print(sess.run(c))
    """
    [[1 1 4 4]
     [2 2 5 5]
     [3 3 6 6]]
    """

三维数组

  a = [[[1,1],[2,2]],[[3,3],[4,4]]]
    b = [[[5,5]],[[6,6]]]

    c = tf.concat([a,b],1)
    print(sess.run(c))
    """
    [[[1 1]
      [2 2]
      [5 5]]

     [[3 3]
      [4 4]
      [6 6]]]
    """

注意:在使用tf.concat函数连接两个数组的时候,数组该维度必须是一致的,否则会报错,如下:

c = tf.concat([a,b],0)

错误提示ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1,意思是a在第1个维度上shape是2,而b在第一个维度上shape是1。

总结:如何来判断数组是否在该个维度上的shape是相同的呢?其实很简单,我们根据tf.concat的axis参数来去数组的[],0表示去掉最外面的一层,1去掉两层,以此类推,下面举例说明一下。
如:最后一个例子中的c = tf.concat([a,b],1),我们先将a去掉最外面两层[],变成了[1,1],[2,2]和[3,3],[4,4]],然后再将b去掉最外面两层[],变成了[5,5]和[6,6],此时再进行concat,可以发现此时的shape是相等的。

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值