这个函数本身的注释看起来不太清晰, 这里举个小例子, 介绍一下最基本的用法:
首先参数列表是
tf.strided_slice(input, begin, end, stride, ...)
tf.fill( dims, value, name=None)
dims 是1-D的向量表示shape, value 是一个实数, 表示整个矩阵填充的值.
实例: 定义 decoder_input, 其实就是在 decoder_target开始处添加一个 <go>, 并删除结尾处的<end>. decoder_input shape (batch, max-step) 对应到下面的例子就是 (3,4) 把所有最后一位9去掉, 在第1位添加一个0(表示go符号)
# coding:utf-8
import tensorflow as tf
t = tf.constant([[1, 1, 1, 9], [2, 2, 2, 9], [7, 7, 7, 9]])
# 只对第0维进行操作, 取出[0,3), 也就是相当与 t[0:3]
z0 = tf.strided_slice(t, [0], [3], [1])
# 在第0维操作的基础上进行第1维度的操作, 其中第0维度的操作如z0所示. 整体操作等于 t[:,0:-1]
z1 = tf.strided_slice(t, [0, 0], [3, -1], [1, 1])
# 首先生成一个填充为0的(3,1)的矩阵, 然后与上面取出的进行拼接
z2 = tf.concat([tf.fill((3, 1), 0), z1], axis=1)
with tf.Session() as sess:
print(sess.run(z0))
print(sess.run(z1))
print(sess.run(z2))
输出:
[[1 1 1 9]
[2 2 2 9]
[7 7 7 9]]
[[1 1 1]
[2 2 2]
[7 7 7]]
[[0 1 1 1]
[0 2 2 2]
[0 7 7 7]]
tf.strided_slice() 函数
最新推荐文章于 2022-03-18 12:28:33 发布