今天看tensorflow时遇到 tf.TensorArray,写一个见到的教程指导使用
import tensorflow as tf
sess = tf.Session()
x = np.arange(20)
input_ta = tf.TensorArray(size=0, dtype=tf.int32, dynamic_size=True)
input_ta = input_ta.unstack(x) #TensorArray可以传入array或者tensor
for time in range(len(x)):
print(sess.run(input_ta.read(time))) #遍历查看元素
output = input_ta.stack() #合成
print(sess.run(output))
for time in range(5):
input_ta = input_ta.write(time+len(x), time) #写入
output = input_ta.stack()
print(sess.run(output))
输出结果为:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 0 1 2 3 4]
注意;这里的unstack()操作能够将输入的张量的维度减少一,如我这里输入的X是(1,20)input_ta中变成20个元素,stack()将这20个元素聚合成一个元素
ta.stack(name=None)
将TensorArray中元素叠起来当做一个Tensor输出
ta.unstack(value, name=None)
可以看做是stack的反操作,输入Tensor,输出一个新的TensorArray对象
ta.write(index, value, name=None)
指定index位置写入Tensor
ta.read(index, name=None)
读取指定index位置的Tensor