TensorFlow中张量连接操作 tf.concat 与张量堆叠操作 tf.stack 的用法

本文详细解析了TensorFlow中张量的连接(concat)与堆叠(stack)操作,通过实例展示了不同轴上操作的结果,帮助理解张量维度的变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

二、官方说明

  1. 张量连接:https://www.tensorflow.org/api_docs/python/tf/concat
    对输入的张量进行内容的连接,不涉及维度的改变
    使用说明:https://blog.youkuaiyun.com/sdnuwjw/article/details/84960377

  2. 张量堆叠:https://www.tensorflow.org/api_docs/python/tf/stack
    对输入的张量进行堆叠,维度增加 1 维
    使用说明:https://blog.youkuaiyun.com/sdnuwjw/article/details/85049603

简单来说,两者的区别是:
tf.concat() 内容连接,维度不变
tf.stack() 张量堆叠,维度加一

三、实例

t1 = [[1,2,3],[4,5,6]]
t2 = [[7,8,9],[10,11,12]]
t3 = [[13,14,15],[16,17,18]]

concat0 = tf.concat(values=[t1,t2,t3],axis=0)
concat1 = tf.concat(values=[t1,t2,t3],axis=1)
concat2 = tf.concat(values=[t1,t2,t3],axis=2)

stack0 = tf.stack(values=[t1,t2,t3],axis=0)
stack1 = tf.stack(values=[t1,t2,t3],axis=1)
stack2 = tf.stack(values=[t1,t2,t3],axis=2)

with tf.Session() as sess:
	print(sess.run(tf.shape(concat0)))
	print(sess.run(concat0))
	print("---------------------")
	print(sess.run(tf.shape(concat1)))
	print(sess.run(concat1))
	print("---------------------")
	print(sess.run(tf.shape(stack0)))
	print(sess.run(stack0))
	print("---------------------")
	print(sess.run(tf.shape(stack1)))
	print(sess.run(stack1))
	print("---------------------")
	print(sess.run(tf.shape(stack2)))
	print(sess.run(stack2))

输出结果:
########################
concat0,shape: [6 3] (3个两行三列的按照第0维进行内容连接,得到6行3列,维度总数不变,还是二维)
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]
[13 14 15]
[16 17 18]]
########################
concat1,shape: [2 9] (3个两行三列的按照第1维进行内容连接,得到1行9列,维度总数不变,还是二维)
[[ 1 2 3 7 8 9 13 14 15]
[ 4 5 6 10 11 12 16 17 18]]
########################
stack0,shape: [3 2 3] (3个两行三列的按照第0维进行堆叠,维度总数变为3)
[[[ 1 2 3]
[ 4 5 6]]

[[ 7 8 9]
[10 11 12]]

[[13 14 15]
[16 17 18]]]
########################
stack1,shape: [2 3 3] (3个两行三列的按照第1维进行堆叠,维度总数变为3)1
[[[ 1 2 3]
[ 7 8 9]
[13 14 15]]

[[ 4 5 6]
[10 11 12]
[16 17 18]]]
########################
stack2,shape: [2 3 3] (3个两行三列的按照第2维进行堆叠,维度总数变为3)
[[[ 1 7 13]
[ 2 8 14]
[ 3 9 15]]

[[ 4 10 16]
[ 5 11 17]
[ 6 12 18]]]
########################

细心的读者可以发现,这里没有 concat2,这是因为 concat 的连接维度不能等于或大于原张量的阶。如果超过,则会报错:

Traceback (most recent call last):
  File "<stdin>", line 6, in <module>
NameError: name 'concat2' is not defined
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值