Tensorflow 实现常见mask

一、self-attention中的mask

1.1 attention的mask.

1.1.1 举例

q_mask = [1, 1, 1, 1, 0, 0]  # seq_len. 其中1表示有效, 0表示无效. 

# self-attention的score为 [seq_len, seq_len]

q_mask = tf.expand_dims(q_mask, axis=-1)  # [seq_len, 1]
k_mask = tf.reshape(q_mask, [1, -1])  # [1, seq_len] 
attention_mask = tf.matmul(q_mask, k_mask)   # [seq_len, seq_len]

1.1.2 调用

import tensorflow as tf

q_mask = tf.expand_dims(tf.constant([1, 1, 1, 1, 0, 0]), axis=-1)      # seq_len, 1
k_mask = tf.reshape(q_mask, [1, -1])   # 1, seq_len
output = tf.matmul(q_mask, k_mask)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print_output = sess.run([output])
    print(print_output)

# [[1, 1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0, 0],
#  [0, 0, 0, 0, 0, 0],
#  [0, 0, 0, 0, 0, 0]]

二、序列mask

1.1 tf.sequence_mask

1.1.1 函数参数

tf.sequence_mask(
    lengths, maxlen=None, dtype=tf.dtypes.bool, name=None
)

1.1.2 调用

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

tf.sequence_mask([1, 3, 2], 5, dtype=tf.int64)
								# [[1 0 0 0 0]
								#  [1 1 1 0 0]
                                #  [1 1 0 0 0]]

tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
                                  #   [True, True, True]],
                                  #  [[True, True, False],
                                  #   [False, False, False]]

由此可以得知:padding部分的mask是False, 有效部分的mask是True。
如果指定dtype=int64, 那么输出也便是如此。

三、decoder中的attention mask

1.1 下三角mask

保证一个词只能看到前面的词语,看不到后面的词。

import tensorflow as tf

q_len, k_len = 5, 5
mask = tf.ones(shape=[q_len, k_len])                       
tril = tf.linalg.LinearOperatorLowerTriangular(mask).to_dense()


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print_output = sess.run([tril])
    print(print_output)

[[1., 0., 0., 0., 0.],
 [1., 1., 0., 0., 0.],
 [1., 1., 1., 0., 0.],
 [1., 1., 1., 1., 0.],
 [1., 1., 1., 1., 1.]]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值