Tensorflow2 加载别人模型权重失败时的一种解决方法:每一层分别设置值

在导入别人模型参数时,因版本等问题可能导致导入失败。本文以TensorFlow2为例,介绍手动逐层加载模型权重的方法。作者加载模型权重报错后,尝试手动逐层加载可行。还说明了可通过set_weights方式设置权重,并给出将别人已加载好模型参数设置到自己模型中的方法。
部署运行你感兴趣的模型镜像

有时自己建立模型导入别人模型参数时会因为版本等各个问题导致导入失败,这时可以一层一层的给网络赋值权重。

 本人在加载某个模型的权重通过load_weights方式加载时报错,但是我拥有别人已加载完权重的model。所以试了下手动一层一层的加载权重,结果还是可行的,就是比较麻烦。

tensorflow2的模型权重可以以set_weights方式进行设置权重。比如要改以下自己定义的模型每层加载权重:

from tensorflow.keras import Sequential, layers, Model
import numpy as np

class SetWeights(Model):
    def __init__(self):
        super(SetWeights, self).__init__()
        self.c = layers.Conv2D(filters = 32,kernel_size=(3,3),strides=(1,1),padding='same')
        self.bn = layers.BatchNormalization()
        self.ac = layers.LeakyReLU(0.1)

    def call(self, inputs, training=None, mask=None):
        x = self.ac(self.bn(self.c(inputs)))
        return x

首先初始化模型,查看可训练参数结构,如下:

model = SetWeights()
model.build((None,32,32,3))

# 查看所有可训练的参数
# 以列表的方式保存,长度4,包含卷积层的w,b,BN的gamma,beta
model.trainable_variables

通过以下方法设置权重

# model.属性.set_weights方法
# 需要注意的是传入的参数是一个list
model.c.set_weights()

比如我们实例化两个model (model1相当于自己建立的模型,model2相当于拥有的别人已加载好的模型,因为结构顺序或者版本等原因,不能save_weights再load_weights加载别人的权重)

将model2的参数设置到model1中方法:  

model1 = SetWeights()
model1.build((None,32,32,3))

model2 = SetWeights()
model2.build((None,32,32,3))

# 卷积层的权重设置(w、b)
# 传入的list顺序是kernel、bias
model1.c.set_weights([model2.trainable_variables[0].numpy(),model2.trainable_variables[1].numpy()])

# BN层参数设置
# 传入的list顺序是gamma、beta、moving_mean、moving_variance

# 获得BN层每个参数长度
bn_len = model2.trainable_variables[2].shape[0]
model1.bn.set_weights([model2.trainable_variables[2].numpy(),model2.trainable_variables[3].numpy(),np.zeros((bn_len),dtype=np.float32),np.ones((bn_len),dtype=np.float32)])

# 这样model1里的权重已经变成model2中的了

 

 

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.9

TensorFlow-v2.9

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值