使用tf.tensor_scatter_nd_update对张量进行赋值
tf.tensor_scatter_nd_update()是根据指定索引值(或切片)对张量进行新赋值。函数原型是
scatter_nd_update(
ref,
indices,
updates,
use_locking=True,
name=None
)
主要参数为三个:ref是被赋值的张量,indices是具体的索引位置,是整数类型的张量,updates是要赋值的张量,注意与ref为同样类型。该函数就是将ref[indices]的值替换为updates。
举例:
对于一维张量的坐标赋值
tensor = tf.constant([0, 0, 0, 0, 0, 0, 0, 0],dtype=tf.int32)
ipdb> indices = tf.constant([[3], [1], [4], [7]])
ipdb> updates = tf.constant([5, 6