**
一 tf.where( )函数
**
tf.where(tensor)
当只有一个输入时,输入为布尔型,其返回的是值为True的位置.
In [3]: a = tf.random.normal([3,3])#生成3行3列的随机数组
In [4]: a
Out[4]:
<tf.Tensor: id=5, shape=(3, 3), dtype=float32, numpy=
array([[ 0.9277362 , -0.47748154, -0.01599854],
[-0.61649364, -0.01608791, 1.2601289 ],
[ 0.5817263 , -0.7267582 , -0.17820324]], dtype=float32)>
In [6]: mask = a>0 #返回>0的布尔值
In [7]: mask
Out[7]:
<tf.Tensor: id=8, shape=(3, 3), dtype=bool, numpy=
array([[ True, False, False],
[False, False, True],
[ True, False, False]])>
In [8]: idx = tf.where(mask)#返回值为True的位置索引
In [9]: idx
Out[9]:
<tf.Tensor: id=10, shape=(3, 2), dtype=int64, numpy=
array([[0, 0],
[1, 2],
[2, 0]])>
In [10]: tf.gather_nd(a,idx)#根据索引从数组a中提取出数据
Out[10]: <tf.Tensor: id=12, shape=(3,), dtype=float32, numpy=array([0.9277362, 1.2601289, 0.5817263], dtype=float32)>
In [11]: tf.boolean_mask(a,mask)
Out[11]: <tf.Tensor: id=40, shape=(3,), dtype=float32, numpy=array([0.9277362, 1.2601289, 0.5817263], dtype=float32)>
tf.where(cond,A,B)输入为三个参数时,当cond中值为True时,取a对应值,否则取b中对应值
In [12]: mask
Out[12]:
<tf.Tensor: id=8, shape=(3, 3), dtype=bool, numpy=
array([[ True, False, False],
[False, False, True],
[ True, False, False]])>
In [13]: a = tf.ones([3,3])
In [14]: a
Out[14]:
<tf.Tensor: id=45, shape=(3, 3), dtype=float32, numpy=
array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=float32)>
In [15]: b = tf.zeros([3,3])
In [16]: b
Out[16]:
<tf.Tensor: id=49, shape=(3, 3), dtype=float32, numpy=
array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)>
In [17]: c =tf.where(mask,a,b)
In [18]: c
Out[18]:
<tf.Tensor: id=51, shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.]], dtype=float32)>
**
二 tf.scatter_nd( )函数
**
tf.scatter_nd(indices,updates,shape,name=None)
只能在数值全部为0的底板上面指定位置更新数据.
shape为底板的shape,类型为tensor
indices为更新位置的索引
updates为更新的新值
In [21]: shape = tf.constant([8])#指定底板的shape为(8,)
In [22]: shape
Out[22]: <tf.Tensor: id=55, shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>
In [23]: indices = tf.constant([[4],[3],[1],[7]])#指定更新位置的索引为4,3,1,7
In [24]: indices
Out[24]:
<tf.Tensor: id=57, shape=(4, 1), dtype=int32, numpy=
array([[4],
[3],
[1],
[7]], dtype=int32)>
In [26]: updates = tf.constant([9,10,11,12])#指定更新的新值为[9,10,11,12]
In [27]: updates
Out[27]: <tf.Tensor: id=60, shape=(4,), dtype=int32, numpy=array([ 9, 10, 11, 12], dtype=int32)>
In [28]: tf.scatter_nd(indices,updates,shape)
Out[28]: <tf.Tensor: id=62, shape=(8,), dtype=int32, numpy=array([ 0, 11, 0, 10, 9, 0, 0, 12], dtype=int32)>
利用sactter_nd更新现有的tensor
将tensor A中需要更新位置的数据取出并更新到底板->得到A’
A = A-A’清零需要更新位置的数据
将需要更新的新数据更新到底板->得到A’’
A = A+A’'将数据更新到原tensor中.
二 meshgrid
In [39]: x = tf.linspace(-2.,2.,5)
In [40]: y = tf.linspace(-2.,2.,5)
In [41]: point_x,point_y = tf.meshgrid(x,y)
In [42]: point_x.shape
Out[42]: TensorShape([5, 5])
In [43]: point_y.shape
Out[43]: TensorShape([5, 5])
In [45]: points = tf.stack([point_x,point_y],axis=2)
In [46]: points
Out[46]:
<tf.Tensor: id=71, shape=(5, 5, 2), dtype=float32, numpy=
array([[[-2., -2.],
[-1., -2.],
[ 0., -2.],
[ 1., -2.],
[ 2., -2.]],
[[-2., -1.],
[-1., -1.],
[ 0., -1.],
[ 1., -1.],
[ 2., -1.]],
[[-2., 0.],
[-1., 0.],
[ 0., 0.],
[ 1., 0.],
[ 2., 0.]],
[[-2., 1.],
[-1., 1.],
[ 0., 1.],
[ 1., 1.],
[ 2., 1.]],
[[-2., 2.],
[-1., 2.],
[ 0., 2.],
[ 1., 2.],
[ 2., 2.]]], dtype=float32)>
In [47]: points.shape
Out[47]: TensorShape([5, 5, 2])