【tensorflow】tf.shape() 和 x.get_shape().as_list()

本文介绍在TensorFlow中如何使用tf.shape()和get_shape().as_list()两种方法获取张量的形状。前者适用于所有类型的数据,后者仅适用于Tensor,并能返回更直观的列表形式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、tf.shape()获取张量的大小,比较直接

import tensorflow as tf
import numpy as np

a_array = np.array([[1, 2, 3], [4, 6, 5]])
b_list = [[11, 22, 33], [44, 55, 66]]
c_tensor = tf.constant([[111, 222, 333], [444, 555, 666]], name = 'c_tensor')

with tf.Session() as sess:
    print(sess.run(tf.shape(a_array)))
    print(sess.run(tf.shape(b_list)))
    print(sess.run(tf.shape(c_tensor)))



输出:
[2 3]
[2 3]
[2 3]

2、get_shape().as_list()

x.get_shape()只有tensor才可以使用这种方法,返回的是TensorShape对象,再调用as_list()才能转换成列表

import tensorflow as tf
import numpy as np

a_array = np.array([[1, 2, 3], [4, 6, 5]])
b_list = [[11, 22, 33], [44, 55, 66]]
c_tensor = tf.constant([[111, 222, 333], [444, 555, 666]], name = 'c_tensor')

print(c_tensor.get_shape())
print(c_tensor.get_shape().as_list())



输出:
(2, 3)
[2, 3]

只能用tensor来返回shape,但是是TensorShape对象,需要通过as_list()的操作转换成list

 

需要注意的是:

第一点:tensor.get_shape()返回的是元组,不能放到sess.run()里面,这个里面只能放operation和tensor

第二点:tf.shape()返回的是一个tensor,必须通过sess.run()运行

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值