While reading the GPT-2 source code, I came across this snippet:
def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)
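To get a feel for the shape it produces, here is a NumPy re-implementation of the same index arithmetic (a sketch of my own, not part of the GPT-2 code; np_attention_mask is a hypothetical name). For nd=3, ns=5 the ones form a lower triangle anchored at the bottom-right corner, exactly as the docstring says:

import numpy as np

# mask[i, j] = 1 iff i >= j - ns + nd, i.e. query position i may attend
# to key position j (all keys up to and including its own position).
def np_attention_mask(nd, ns):
    i = np.arange(nd)[:, None]
    j = np.arange(ns)
    return (i >= j - ns + nd).astype(np.float32)

print(np_attention_mask(3, 5))
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]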
I had never seen the function tf.matrix_band_part before, and searching mostly turned up its source comment, which I couldn't make sense of either:
def matrix_band_part(input, num_lower, num_upper, name=None):
  r"""Copy a tensor setting everything outside a central band in each innermost matrix to zero.

  The `band` part is computed as follows:
  Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
  tensor with the same shape where

  `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.

  The indicator function

  `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
                   (num_upper < 0 || (n-m) <= num_upper)`.

  For example:

  ```
  # if 'input' is [[ 0,  1,  2, 3]
                   [-1,  0,  1, 2]
                   [-2, -1,  0, 1]
                   [-3, -2, -1, 0]],

  tf.matrix_band_part(input, 1, -1) ==> [[ 0,  1,  2, 3]
                                         [-1,  0,  1, 2]
                                         [ 0, -1,  0, 1]
                                         [ 0,  0, -1, 0]],

  tf.matrix_band_part(input, 2, 1) ==> [[ 0,  1,  0, 0]
                                        [-1,  0,  1, 0]
                                        [-2, -1,  0, 1]
                                        [ 0, -2, -1, 0]]
  ```

  Useful special cases:

  ```
  tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
  tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
  tf.matrix_band_part(input, 0, 0) ==> Diagonal.
  ```

  Args:
    input: A `Tensor`. Rank `k` tensor.
    num_lower: A `Tensor` of type `int64`.
      0-D tensor. Number of subdiagonals to keep. If negative, keep entire
      lower triangle.
    num_upper: A `Tensor` of type `int64`.
      0-D tensor. Number of superdiagonals to keep. If negative, keep
      entire upper triangle.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
    Rank `k` tensor of the same shape as input. The extracted banded tensor.
  """
  result = _op_def_lib.apply_op("MatrixBandPart", input=input,
                                num_lower=num_lower, num_upper=num_upper,
                                name=name)
  return result
An entry point I found: subdiagonal / superdiagonal.
The super-diagonal is the line immediately above the main diagonal; the sub-diagonal is the line immediately below it.
"For an element A(i,j):
if i = j it is a diagonal element;
if i ≠ j it is an off-diagonal element;
if i = j+1 it is a subdiagonal element;
if i = j-1 it is a superdiagonal element."
不清楚这个概念前,都没明确注意到例子中 对角线平行方向上方的数据依次是 1 2 3 ,对角线平行方向下方的数据依次是-1 -2 -3
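To make that concrete, here is a tiny sketch of my own that rebuilds the docstring's example matrix: entry (m, n) is simply n - m, so the k-th superdiagonal holds k and the k-th subdiagonal holds -k.

import numpy as np

# Entry (m, n) = n - m: superdiagonals read 1, 2, 3; subdiagonals read -1, -2, -3.
m = np.arange(4)[:, None]
n = np.arange(4)
print(n - m)
# [[ 0  1  2  3]
#  [-1  0  1  2]
#  [-2 -1  0  1]
#  [-3 -2 -1  0]]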
The function copies the input into a new tensor, zeroing everything outside a central band. The second and third arguments:
num_lower is how many diagonals below (and parallel to) the main diagonal to keep;
num_upper is how many diagonals above it to keep.
If num_lower is 0, no subdiagonals are kept, so everything below the main diagonal is zeroed; if it is 2, the first two subdiagonals keep their values and the third and beyond are zeroed; if it is negative, nothing below the diagonal is zeroed and it all keeps its original values.
num_upper works analogously for the upper side.
So the test for whether entry (m, n) lies inside the band is:

`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
                 (num_upper < 0 || (n-m) <= num_upper)`

which amounts to:
for the lower-left part of the matrix: num_lower is negative, or row index minus column index is at most num_lower;
for the upper-right part of the matrix: num_upper is negative, or column index minus row index is at most num_upper.
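That predicate is easy to re-implement in NumPy for verification. This is a sketch of my own following the docstring's definition, not the actual TF kernel (np_band_part is a hypothetical name):

import numpy as np

def np_band_part(x, num_lower, num_upper):
    # Vectorized in_band(m, n) from the docstring, applied to a 2-D matrix.
    m = np.arange(x.shape[0])[:, None]  # row indices
    n = np.arange(x.shape[1])           # column indices
    in_band = ((num_lower < 0) | (m - n <= num_lower)) & \
              ((num_upper < 0) | (n - m <= num_upper))
    return np.where(in_band, x, 0)

A = np.arange(1, 26).reshape(5, 5)
print(np_band_part(A, 2, 1))  # matches tf.matrix_band_part(A, 2, 1) below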
This makes the special cases obvious:
tf.matrix_band_part(input, 0, -1) ==> Upper triangular part, i.e. the upper triangular matrix.
tf.matrix_band_part(input, -1, 0) ==> Lower triangular part, i.e. the lower triangular matrix.
tf.matrix_band_part(input, 0, 0) ==> Diagonal, i.e. the diagonal matrix.
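In NumPy terms (my own analogy, not from the TF docs), these three special cases line up with np.triu, np.tril, and np.diag:

import numpy as np

A = np.arange(1, 26).reshape(5, 5)
print(np.triu(A))           # like tf.matrix_band_part(A, 0, -1): upper triangle
print(np.tril(A))           # like tf.matrix_band_part(A, -1, 0): lower triangle
print(np.diag(np.diag(A)))  # like tf.matrix_band_part(A, 0, 0):  diagonal only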
Now some examples:
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

A = tf.reshape(tf.range(1, 26), [5, 5])
A
Out[5]:
<tf.Tensor: id=5, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10],
       [11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25]])>

tf.matrix_band_part(A, 0, 0)
Out[6]:
<tf.Tensor: id=9, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  0,  0,  0,  0],
       [ 0,  7,  0,  0,  0],
       [ 0,  0, 13,  0,  0],
       [ 0,  0,  0, 19,  0],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, 0, 1)
Out[7]:
<tf.Tensor: id=13, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 0,  7,  8,  0,  0],
       [ 0,  0, 13, 14,  0],
       [ 0,  0,  0, 19, 20],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, 0, 2)
Out[8]:
<tf.Tensor: id=17, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  3,  0,  0],
       [ 0,  7,  8,  9,  0],
       [ 0,  0, 13, 14, 15],
       [ 0,  0,  0, 19, 20],
       [ 0,  0,  0,  0, 25]])>

tf.matrix_band_part(A, -2, 1)
Out[9]:
<tf.Tensor: id=21, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 6,  7,  8,  0,  0],
       [11, 12, 13, 14,  0],
       [16, 17, 18, 19, 20],
       [21, 22, 23, 24, 25]])>

tf.matrix_band_part(A, 2, 1)
Out[10]:
<tf.Tensor: id=25, shape=(5, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 6,  7,  8,  0,  0],
       [11, 12, 13, 14,  0],
       [ 0, 17, 18, 19, 20],
       [ 0,  0, 23, 24, 25]])>

B = A[0:3]
tf.matrix_band_part(B, 0, 0)
Out[12]:
<tf.Tensor: id=33, shape=(3, 5), dtype=int32, numpy=
array([[ 1,  0,  0,  0,  0],
       [ 0,  7,  0,  0,  0],
       [ 0,  0, 13,  0,  0]])>

tf.matrix_band_part(B, 0, 1)
Out[13]:
<tf.Tensor: id=37, shape=(3, 5), dtype=int32, numpy=
array([[ 1,  2,  0,  0,  0],
       [ 0,  7,  8,  0,  0],
       [ 0,  0, 13, 14,  0]])>
One remaining question about the attention_mask function: why not just use tf.matrix_band_part, and why would it "produce garbage" on TPUs?
My guess is that it has to do with tf.matrix_band_part ultimately dispatching to a C++ kernel, whereas the current implementation expresses the mask as ordinary tensor ops in the graph. Then again, it seems a TPU simply cannot run C++ programs at all, which doesn't quite match "produces garbage" in the comment... I'm not sure; pointers from anyone who knows would be very welcome.
"A TPU cannot run the Java or C++ programs that run on a CPU, nor the CUDA programs that run on a GPU. Although nothing has been disclosed publicly, it is very likely programmed like this: TensorFlow expresses the neural network in an intermediate representation, and a compiler then translates that intermediate representation into a program specific to the TPU. This intermediate representation is called TensorFlow XLA, and it will also be TensorFlow's vehicle for supporting other linear-algebra accelerators."
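Whatever the underlying TPU story is, the equivalence claimed in the docstring does check out numerically. Here is a NumPy sanity check of my own, reusing the two sketches from above:

import numpy as np

def np_attention_mask(nd, ns):
    i = np.arange(nd)[:, None]
    j = np.arange(ns)
    return (i >= j - ns + nd).astype(np.int32)

def np_band_part(x, num_lower, num_upper):
    m = np.arange(x.shape[0])[:, None]
    n = np.arange(x.shape[1])
    in_band = ((num_lower < 0) | (m - n <= num_lower)) & \
              ((num_upper < 0) | (n - m <= num_upper))
    return np.where(in_band, x, 0)

# attention_mask(nd, ns) == matrix_band_part(ones([nd, ns]), -1, ns - nd)
nd, ns = 3, 5
assert np.array_equal(
    np_attention_mask(nd, ns),
    np_band_part(np.ones((nd, ns), np.int32), -1, ns - nd))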