在使用Keras写模型的时候总是会忘记一些操作,这里做一些小记,方便自己用到的时候查阅。
(1)mask的生成及计算loss时的使用
mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(tokens)
#通过Lambda层创建mask,就不需要再输入mask了
sub_heads_loss = K.binary_crossentropy(gold_sub_heads, pred_sub_heads)
sub_heads_loss = K.sum(sub_heads_loss * mask) / K.sum(mask) # * 是对应位置相乘,最后得到一个数值
(2)K.tile 复制某个维度
query_ = K.expand_dims(query_, 2)
query_ = K.tile(query_, [K.shape(tokens_feature)[0], 1, K.shape(tokens_feature)[1], 1])
第二个参数中1表示不复制
(3)关于compute_mask
创建自己的层的时候:
def compute_mask(self, inputs, mask=None):#由于传入数据有Embedding操作,默认传入上一步Embedding的mask的结果,这一步不用就需要return None,不然会报错
return None
详细介绍参考:https://blog.youkuaiyun.com/znevegiveup1/article/details/114552725
(4)各种乘法 * K.dot K.batch_dot计算
星号 乘 是对应位置相乘 [1,2,3] * [4,5,6]=[4,10,18]
[[1,2,3,4],[5,6,7,8]]*[1,1,0,0]=[[1,2,0,0],[5,6,0,0]] 计算带mask损失函数时会用到
K.dot(A,B) A.shape(-1)和B.shape(-2)必须一样,其余没有限制
矩阵乘法,A.shape=(2,5) B.shape=(5,2) K.dot(A,B)后shape为(2,2)
a.shape:(1, 2, 4) b.shape:(8, 7, 4, 5) 待补充
K.batch_dot(A,B)
A.shape(-1)和B.shape(-2)必须一样,A.shape(-2)和B.shape(-1)可以不一样,shape(-2)以前的必须完全一样
a = K.ones((9, 8, 7, 4, 2))
b = K.ones((9, 8, 7, 2, 5))
c = K.batch_dot(a, b)
c.shape (9, 8, 7, 4, 5)
详细的计算过程待补充
(5)Lambda使用自己的函数
def my_gather(x):
data, index = x
index = K.cast(index, 'int32')
return tf.gather(data,index)
type_rep = Lambda(my_gather)([type_emb,type_id])
keras的model里面必须都是层,一些操作可以借助Lambda封装为层
(6)Input(tensor=K.variable(np.arange(0,type_num)))情况
有时候需要这种,Input层有tensor,就不再创建占位符,也就不需要传入数据了(data_generator不再需要传入这个Input)
(7)bug:Dimensions must be equal, but are 11 and 23 for xxxxx with input shapes: [?,11,?,768], [?,23,?,768].
明明计算好的shape,不会出现这种错误,以为是训练数据维度错误,没有!没有训练的时候就报错了,原来是multi_gpu_model的原因,去掉就好了!
(8)安装allennlp0.9.0的时候,要sudo apt-get g++ ,还要安装jsonnet0.17.0,0.18.0会出错!坑
(9) 插曲torch错误:
from torch._C import *
ImportError: dlopen: cannot load any more object with static TLS
解决方案:把import torch放在最上面。
(10)计算attention_socres的做法,与scores = scores.masked_fill_(mask, -1e10)做法等同
(11) 关于单机多卡等分布式训练
https://zhuanlan.zhihu.com/p/73580663
未完待续。。。
有了这些再也不用怕复现模型了