tf.linalg.band_part
- 新版本:tf.matrix_band_part变成tf.linalg.band_par
函数原型:
tf.linalg.band_part(
input,
num_lower,
num_upper,
name=None
)
参数:
- 作用:主要功能是以对角线为中心,取它的副对角线部分,其他部分用0填充。
- input:输入的张量.
- num_lower:下三角矩阵保留的副对角线数量,从主对角线开始计算,相当于下三角的带宽。取值为负数时,则全部保留。
- num_upper:上三角矩阵保留的副对角线数量,从主对角线开始计算,相当于上三角的带宽。取值为负数时,则全部保留。
例子:
import tensorflow as tf
tf.enable_eager_execution()
a=tf.constant( [[ 1, 1, 2, 3],[-1, 2, 1, 2],[-2, -1, 3, 1],
[-3, -2, -1, 5]],dtype=tf.float32)
b=tf.linalg.band_part(a,2,0)
c=tf.linalg.band_part(a,1,1)
d=tf.linalg.band_part(a,-1,1)
print(a)
print(b)
print(c)
print(d)
输出:
tf.Tensor(
[[ 1. 1. 2. 3.]
[-1. 2. 1. 2.]
[-2. -1. 3. 1.]
[-3. -2. -1. 5.]], shape=(4, 4), dtype=float32)
=============================================================
tf.Tensor(
[[ 1. 0. 0. 0.]
[-1. 2. 0. 0.]
[-2. -1. 3. 0.]
[ 0. -2. -1. 5.]], shape=(4, 4), dtype=float32)
=============================================================
tf.Tensor(
[[ 1. 1. 0. 0.]
[-1. 2. 1. 0.]
[ 0. -1. 3. 1.]
[ 0. 0. -1. 5.]], shape=(4, 4), dtype=float32)
=============================================================
tf.Tensor(
[[ 1. 1. 0. 0.]
[-1. 2. 1. 0.]
[-2. -1. 3. 1.]
[-3. -2. -1. 5.]], shape=(4, 4), dtype=float32)