关于 tf.matrix_band_part

在看gpt2源码时,有这样一段:


def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:,None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)

 

之前没见过这个函数 tf.matrix_band_part,搜到的基本都是源码的注释,这个注释也没看懂:
 

def matrix_band_part(input, num_lower, num_upper, name=None):
  r"""Copy a tensor setting everything outside a central band in each innermost matrix

  to zero.

  The `band` part is computed as follows:
  Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
  tensor with the same shape where

  `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.

  The indicator function

  `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
                   (num_upper < 0 || (n-m) <= num_upper)`.

  For example:

  ```
  # if 'input' is [[ 0,  1,  2, 3]
                   [-1,  0,  1, 2]
                   [-2, -1,  0, 1]
                   [-3, -2, -1, 0]],

  tf.matrix_band_part(input, 1, -1) ==> [[ 0,  1,  2, 3]
                                         [-1,  0,  1, 2]
                                         [ 0, -1,  0, 1]
                                         [ 0,  0, -1, 0]],

  tf.matrix_band_part(input, 2, 1) ==> [[ 0,  1,  0, 0]
                                        [-1,  0,  1, 0]
                                        [-2, -1,  0, 1]
                                        [ 0, -2, -1, 0]]
  ```

  Useful special cases:

  ```
   tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
   tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
   tf.matrix_band_part(input, 0, 0) ==> Diagonal.
  ```

  Args:
    input: A `Tensor`. Rank `k` tensor.
    num_lower: A `Tensor` of type `int64`.
      0-D tensor. Number of subdiagonals to keep. If negative, keep entire
      lower triangle.
    num_upper: A `Tensor` of type `int64`.
      0-D tensor. Number of superdiagonals to keep. If negative, keep
      entire upper triangle.
    name: A name for the operation (optional).

  Returns:BandPart", input=input,
                                num_lower=num_lower, num_upper=num_upper,
                                name=name)
  return result
    A `Tensor`. Has the same type as `input`.
    Rank `k` tensor of the same shape as input. The extracted banded tensor.
  """
  result = _op_def_lib.apply_op("MatrixBandPart", input=input,
                                num_lower=num_lower, num_upper=num_upper,
                                name=name)
  return result

 

找到一个切入点:subdiagonal / superdiagonal
super-diagonal 超对角元  是对角线上方的那条线。
sub-diagonal    次对角元  是对角线下方的那条线。

"对于A(i,j)元素
如果i=j就叫对角元(diagonal)
如果i≠j就叫非对角元(off-diagonal)
如果i=j+1就叫次对角元(subdiagonal)
如果i=j-1就叫超对角元(superdiagonal)"

不清楚这个概念前,都没明确注意到例子中  对角线平行方向上方的数据依次是 1 2 3 ,对角线平行方向下方的数据依次是-1 -2 -3
 
这个函数会copy得到一个新tensor,把带外(outside a central band)的数据置0。第二和第三个参数:
num_lower  是对角线平行方向的下方 需要保留的对角元数量
num_upper 是对角线平行方向的上方 需要保留的对角元数量

num_lower为0时,对角线平行方向的下方 不保留任何对角元,则都会被置0;如果为2,对角线平行方向的下方 保留两组对象元信息,第三组及以上被置0;如果为负数,则对角线平行方向的下方都不会被置0,都会保持原来的值。
num_upper类似。

所以判断是否在band内的方式是:

  `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
                   (num_upper < 0 || (n-m) <= num_upper)`.

效果相当于:              
对于矩阵的左下部分来说: num_lower为负 或者 行标-列标 在num_lower范围内
对于矩阵的右上部分来说: num_upper为负 或者 列标-行标 在num_upper范围内

显然地:
   tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.   得到上对角矩阵
   tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.   得到下对角矩阵
   tf.matrix_band_part(input, 0, 0) ==> Diagonal.  得到对角矩阵

 

来看粟子:

import numpy as np

import tensorflow as tf

tf.enable_eager_execution()

A = tf.reshape(tf.range(1, 26), [5, 5])

A
Out[5]: 
<tf.Tensor: id=5, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10],
       [11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25]])>

tf.matrix_band_part(A, 0, 0)
Out[6]: 
<tf.Tensor: id=9, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  0,  0,  0,  0],
       [ 0,  7,  0,  0,  0],
       [ 0,  0, 13,  0,  0],
       [ 0,  0,  0, 19,  0],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, 0, 1)
Out[7]: 
<tf.Tensor: id=13, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 0,  7,  8,  0,  0],
       [ 0,  0, 13, 14,  0],
       [ 0,  0,  0, 19, 20],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, 0, 2)
Out[8]: 
<tf.Tensor: id=17, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  3,  0,  0],
       [ 0,  7,  8,  9,  0],
       [ 0,  0, 13, 14, 15],
       [ 0,  0,  0, 19, 20],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, -2, 1)
Out[9]: 
<tf.Tensor: id=21, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 6,  7,  8,  0,  0],
       [11, 12, 13, 14,  0],
       [16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25]])>

tf.matrix_band_part(A, 2, 1)
Out[10]: 
<tf.Tensor: id=25, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 6,  7,  8,  0,  0],
       [11, 12, 13, 14,  0],
       [ 0, 17, 18, 19, 20],
       [ 0,  0, 23, 24, 25]])>

B = A[0:3]

tf.matrix_band_part(B, 0, 0)
Out[12]: 
<tf.Tensor: id=33, shape=(3, 5), dtype=int32, numpy=
array([[ 1,  0,  0,  0,  0],
       [ 0,  7,  0,  0,  0],
       [ 0,  0, 13,  0,  0]])>

tf.matrix_band_part(B, 0, 1)
Out[13]: 
<tf.Tensor: id=37, shape=(3, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 0,  7,  8,  0,  0],
       [ 0,  0, 13, 14,  0]])>

 

另外 attention_mask函数中,为什么不使用tf.matrix_band_part, 为什么这个在TPU上会产生garbage?

猜测是否与 tf.matrix_band_part 最终调用的C代码来实现。而用当前的实现,是在graph中 实现的张量运算。

不过TPU上好像是不能运行C++程序,而不是注释中的“产生garbage”。。 不清楚,欢迎知道的大侠指导一下。

 

 

“TPU上无法运行CPU上跑的Java或C++程序,也无法运行GPU上的CUDA程序。虽然尚未有公开信息,但它的编程方式非常可能是这样:TensorFlow把神经网络用一种中间格式表示出来,然后这种中间格式被编译器转换为TPU上独特的程序。这种中间格式被称为TensorFlow XLA,它也将是TensorFlow支持其它线性代数加速器的工具。”

 

 

 

 

 

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值