Keras及tf一些操作小记

本文总结了使用 Keras 构建模型时的一些关键操作,包括生成 mask、K.tile 复制维度、自定义层的 compute_mask 方法、不同类型的乘法操作、Lambda 层的应用、Input 层指定 tensor、多 GPU 训练问题以及安装库的注意事项。此外,还提到了注意力得分的计算方法和分布式训练的初步介绍。这些笔记对于理解和复现 Keras 模型非常有帮助。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在使用Keras写模型的时候总是会忘记一些操作,这里做一些小记,方便自己用到的时候查阅。

(1)mask的生成及计算loss时的使用

mask = Lambda(lambda x: K.cast(K.greater(K.expand_dims(x, 2), 0), 'float32'))(tokens)
#通过Lambda层创建mask,就不需要再输入mask了
sub_heads_loss = K.binary_crossentropy(gold_sub_heads, pred_sub_heads)
sub_heads_loss = K.sum(sub_heads_loss * mask) / K.sum(mask) # * 是对应位置相乘,最后得到一个数值

(2)K.tile 复制某个维度

query_ = K.expand_dims(query_, 2)
query_ = K.tile(query_, [K.shape(tokens_feature)[0], 1, K.shape(tokens_feature)[1], 1])

第二个参数中1表示不复制

(3)关于compute_mask
创建自己的层的时候:

    def compute_mask(self, inputs, mask=None):#由于传入数据有Embedding操作,默认传入上一步Embedding的mask的结果,这一步不用就需要return None,不然会报错
        return None

详细介绍参考:https://blog.youkuaiyun.com/znevegiveup1/article/details/114552725

(4)各种乘法 * K.dot K.batch_dot计算
星号 乘 是对应位置相乘 [1,2,3] * [4,5,6]=[4,10,18]
[[1,2,3,4],[5,6,7,8]]*[1,1,0,0]=[[1,2,0,0],[5,6,0,0]] 计算带mask损失函数时会用到
K.dot(A,B) A.shape(-1)和B.shape(-2)必须一样,其余没有限制
矩阵乘法,A.shape=(2,5) B.shape=(5,2) K.dot(A,B)后shape为(2,2)

a.shape:(1, 2, 4) b.shape:(8, 7, 4, 5) 待补充

K.batch_dot(A,B)
A.shape(-1)和B.shape(-2)必须一样,A.shape(-2)和B.shape(-1)可以不一样,shape(-2)以前的必须完全一样
a = K.ones((9, 8, 7, 4, 2))
b = K.ones((9, 8, 7, 2, 5))
c = K.batch_dot(a, b)
c.shape (9, 8, 7, 4, 5)

详细的计算过程待补充

(5)Lambda使用自己的函数

def my_gather(x):
    data, index = x
    index = K.cast(index, 'int32')
    return tf.gather(data,index)

type_rep = Lambda(my_gather)([type_emb,type_id])

keras的model里面必须都是层,一些操作可以借助Lambda封装为层

(6)Input(tensor=K.variable(np.arange(0,type_num)))情况

有时候需要这种,Input层有tensor,就不再创建占位符,也就不需要传入数据了(data_generator不再需要传入这个Input)

(7)bug:Dimensions must be equal, but are 11 and 23 for xxxxx with input shapes: [?,11,?,768], [?,23,?,768].
明明计算好的shape,不会出现这种错误,以为是训练数据维度错误,没有!没有训练的时候就报错了,原来是multi_gpu_model的原因,去掉就好了!

(8)安装allennlp0.9.0的时候,要sudo apt-get g++ ,还要安装jsonnet0.17.0,0.18.0会出错!坑

(9) 插曲torch错误:
from torch._C import *
ImportError: dlopen: cannot load any more object with static TLS
解决方案:把import torch放在最上面。

(10)计算attention_socres的做法,与scores = scores.masked_fill_(mask, -1e10)做法等同
在这里插入图片描述

(11) 关于单机多卡等分布式训练
https://zhuanlan.zhihu.com/p/73580663

未完待续。。。
有了这些再也不用怕复现模型了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值