TensorFlow - Saving and Loading Models (freeze method)


Hardware: NVIDIA GTX 1080

Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0

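I. Background

freeze: a checkpoint is stored as three files: `.meta` holds the graph structure, while `.index` and `.data-*` hold the variable values. Freezing merges them into a single `.pb` file and converts every variable into a constant that holds its saved value. Only the nodes needed to compute the requested outputs are kept, so the result is usually smaller and easier to port to other environments.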
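The three-files-to-one step described above can be sketched as follows, for the case where a model was already saved with `tf.train.Saver`. This is a minimal sketch, not the author's code: the `freeze_checkpoint` helper and the temporary paths are hypothetical, the output node is assumed to be named `output`, and the `tf1` alias is an assumption that resolves to `tf.compat.v1` on TensorFlow 1.13 and later, or to plain `tf` on older 1.x releases such as 1.4.

```python
import os
import tempfile

import tensorflow as tf

# TF >= 1.13 exposes the 1.x API as tf.compat.v1; older 1.x releases use tf itself
tf1 = getattr(getattr(tf, "compat", None), "v1", tf)


def freeze_checkpoint(ckpt_prefix, output_names, pb_path):
    """Hypothetical helper: restore a checkpoint and write a frozen .pb file."""
    graph = tf1.Graph()
    with graph.as_default():
        # rebuild the graph structure from the .meta file
        saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
        with tf1.Session(graph=graph) as sess:
            # load the variable values from the .index/.data files
            saver.restore(sess, ckpt_prefix)
            # replace variables with constants, keeping only nodes needed for the outputs
            frozen = tf1.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), output_names)
    with tf1.gfile.GFile(pb_path, "wb") as f:
        f.write(frozen.SerializeToString())
    return frozen


# Build and checkpoint the same toy graph as the article (output = x * y + b)
workdir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(workdir, "model.ckpt")
with tf1.Graph().as_default():
    x = tf1.placeholder(tf.int32, name="x")
    y = tf1.placeholder(tf.int32, name="y")
    b = tf1.Variable(1, name="b")
    tf1.add(tf1.multiply(x, y), b, name="output")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, ckpt_prefix)

frozen = freeze_checkpoint(ckpt_prefix, ["output"], os.path.join(workdir, "model.pb"))
print(sorted(os.listdir(workdir)))
```

The directory listing shows the checkpoint files next to the single `model.pb`; only the `.pb` is needed at inference time.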

II. Code

1. Saving the model

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util

# inputs
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')

# variable b to be saved, standing in for a weight or bias
b = tf.Variable(1, name='b')

# output
xy = tf.multiply(x, y)
# name='output' is required: it is the node name passed to convert_variables_to_constants below
output = tf.add(xy, b, name='output')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # freeze: replace variables with constants holding their current values,
    # keeping only the nodes needed to compute 'output'
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])

    # write the .pb file
    with tf.gfile.FastGFile('model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # test: 10 * 3 + 1
    print(sess.run(output, feed_dict={x: 10, y: 3}))  # 31
```
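To check that the freeze worked, one can read the `.pb` back and list its node types. After `convert_variables_to_constants`, `b` should be a `Const` node and no variable ops should remain. The sketch below is self-contained, so it rebuilds the same toy graph in memory. It uses `tf.train.write_graph(..., as_text=False)` as an alternative to writing the serialized bytes by hand. The `tf1` alias is an assumption, chosen so that the same code runs on 1.x and on newer releases.

```python
import os
import tempfile

import tensorflow as tf

# TF >= 1.13 exposes the 1.x API as tf.compat.v1; older 1.x releases use tf directly
tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

# Rebuild the article's toy graph (output = x * y + b) and freeze it
with tf1.Graph().as_default():
    x = tf1.placeholder(tf.int32, name="x")
    y = tf1.placeholder(tf.int32, name="y")
    b = tf1.Variable(1, name="b")
    tf1.add(tf1.multiply(x, y), b, name="output")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        frozen = tf1.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), ["output"])

# write_graph with as_text=False produces the binary .pb format and returns its path
workdir = tempfile.mkdtemp()
pb_path = tf1.train.write_graph(frozen, workdir, "model.pb", as_text=False)

# Read the file back and map every node name to its op type
graph_def = tf1.GraphDef()
with tf1.gfile.GFile(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())
ops = {node.name: node.op for node in graph_def.node}
for name in sorted(ops):
    print(name, ops[name])
```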

2. Loading the model

```python
import tensorflow as tf
from tensorflow.python.platform import gfile

# read the frozen graph from the .pb file
with gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import it into a new graph; name='' keeps the original node names
# (the default would prefix every node with 'import/')
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

sess = tf.Session(graph=graph)

# no initializer is needed: after freezing, b is a constant, not a variable

# inputs
input_x = graph.get_tensor_by_name('x:0')
input_y = graph.get_tensor_by_name('y:0')

# check b, now stored as a constant with the value it had when frozen
print(sess.run('b:0'))  # 1

# output
output = graph.get_tensor_by_name('output:0')

# test: 5 * 5 + 1
print(sess.run(output, feed_dict={input_x: 5, input_y: 5}))  # 26
```
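The `name=''` argument above matters because `tf.import_graph_def` otherwise prefixes every imported node with `import/`. In that case, lookups such as `'output:0'` fail unless the prefix is included. The sketch below is self-contained and freezes the toy graph in memory. It imports the graph with the default name, so the tensors must be fetched as `import/...`. The `tf1` alias is an assumption, used so that the code runs on both 1.x and newer releases.

```python
import tensorflow as tf

# TF >= 1.13 exposes the 1.x API as tf.compat.v1; older 1.x releases use tf directly
tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

# Build and freeze the article's toy graph (output = x * y + b, with b = 1) in memory
with tf1.Graph().as_default():
    x = tf1.placeholder(tf.int32, name="x")
    y = tf1.placeholder(tf.int32, name="y")
    b = tf1.Variable(1, name="b")
    tf1.add(tf1.multiply(x, y), b, name="output")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        frozen = tf1.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), ["output"])

# Import WITHOUT name='': every node gets the default "import/" prefix
graph = tf1.Graph()
with graph.as_default():
    tf1.import_graph_def(frozen)

out = graph.get_tensor_by_name("import/output:0")
feed = {graph.get_tensor_by_name("import/x:0"): 5,
        graph.get_tensor_by_name("import/y:0"): 5}
with tf1.Session(graph=graph) as sess:
    result = sess.run(out, feed_dict=feed)
print(out.name, result)  # import/output:0 26
```

Passing `name=''` keeps the names identical to the saving script, which is why the article can fetch `'x:0'` and `'output:0'` directly.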

For any questions, add my QQ (my only account): 2258205918 (name: samylee)

WeChat (my only account): samylee_csdn

```text
WARNING:tensorflow:From F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\ops\init_ops.py:87: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with distribution=normal is deprecated and will be removed in a future version.
Instructions for updating:
`normal` is a deprecated alias for `truncated_normal`
Traceback (most recent call last):
  File "e:/20200611_Disease_System_v5/test.py", line 13, in <module>
    custom_objects={'BatchNormalization': BatchNormalizationFreeze}
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 230, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 310, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 173, in deserialize_keras_object
    list(custom_objects.items())))
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1292, in from_config
    process_layer(layer_data)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1278, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 163, in deserialize_keras_object
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
ValueError: Unknown layer: UpsampleLike

PS E:\20200611_Disease_System_v5> & F:/RJ/envs/Tensorflow-GPU/python.exe e:/20200611_Disease_System_v5/test.py
Traceback (most recent call last):
  File "e:/20200611_Disease_System_v5/test.py", line 12, in <module>
    'retinanet/retinanet_inference_heng.h5'
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 230, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 310, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 173, in deserialize_keras_object
    list(custom_objects.items())))
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1292, in from_config
    process_layer(layer_data)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1278, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\layers\serialization.py", line 64, in deserialize
    printable_module_name='layer')
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 175, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1606, in from_config
    return cls(**config)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 147, in __init__
    name=name, trainable=trainable, **kwargs)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\training\checkpointable\base.py", line 474, in _method_wrapper
    method(self, *args, **kwargs)
  File "F:\RJ\envs\Tensorflow-GPU\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 138, in __init__
    raise TypeError('Keyword argument not understood:', kwarg)
TypeError: ('Keyword argument not understood:', 'freeze')
```
08-15
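The two failures in the log above match how config-based loading resolves classes. The run that registered `BatchNormalization` got past that layer and stopped at the next unregistered class, `UpsampleLike`. The run without it fell back to the built-in `BatchNormalization`, which rejects the saved `freeze` argument. The code below is a toy model in plain Python that reproduces both messages. It is not Keras's actual implementation, and every class and layer name in it is a hypothetical stand-in.

```python
BUILTINS = {}


def builtin(cls):
    """Register a class as a 'built-in' of this toy loader."""
    BUILTINS[cls.__name__] = cls
    return cls


@builtin
class BatchNormalization:
    """Stand-in for the built-in layer: accepts only a few known arguments."""
    KNOWN = {"name", "trainable"}

    def __init__(self, **kwargs):
        for key in kwargs:
            if key not in self.KNOWN:
                # mimics the TypeError message seen in the log
                raise TypeError("Keyword argument not understood:", key)
        self.name = kwargs.get("name")


class BatchNormalizationFreeze(BatchNormalization):
    """Hypothetical replacement that consumes the saved 'freeze' argument."""

    def __init__(self, freeze=False, **kwargs):
        super().__init__(**kwargs)
        self.freeze = freeze


class UpsampleLike:
    """Hypothetical stand-in for the custom UpsampleLike layer."""

    def __init__(self, name=None):
        self.name = name


def deserialize(layer_configs, custom_objects=None):
    """Toy loader: resolve each class_name (custom_objects win), then build it."""
    lookup = dict(BUILTINS)
    lookup.update(custom_objects or {})
    layers = []
    for entry in layer_configs:
        cls = lookup.get(entry["class_name"])
        if cls is None:
            raise ValueError("Unknown layer: " + entry["class_name"])
        layers.append(cls(**entry["config"]))
    return layers


# Hypothetical saved config: a BatchNormalization saved with 'freeze', then an UpsampleLike
SAVED = [
    {"class_name": "BatchNormalization", "config": {"name": "bn1", "freeze": True}},
    {"class_name": "UpsampleLike", "config": {"name": "up1"}},
]


def try_load(custom_objects=None):
    try:
        return deserialize(SAVED, custom_objects)
    except (TypeError, ValueError) as exc:
        return exc


print(repr(try_load()))  # TypeError('Keyword argument not understood:', 'freeze')
print(repr(try_load({"BatchNormalization": BatchNormalizationFreeze})))  # ValueError('Unknown layer: UpsampleLike')
print(try_load({"BatchNormalization": BatchNormalizationFreeze, "UpsampleLike": UpsampleLike}))
```

Only the last call succeeds, because it registers every custom class in a single dictionary. The loader stops at the first name it cannot resolve or construct.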
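When loading a model with TensorFlow/Keras, errors such as `Unknown layer: UpsampleLike` or `TypeError: Keyword argument not understood: 'freeze'` can occur for two reasons. Either the model contains custom layers (such as `UpsampleLike`), or the environment changed after the model was saved (for example, a parameter was deprecated or modified). Both usually come from an incompatibility between how the model was saved and the environment that loads it.

### Fixing `Unknown layer: UpsampleLike`

This error means Keras cannot recognize the custom layer `UpsampleLike` while loading the model. When Keras saves a model, it serializes the architecture as a JSON config. That config records each layer's class name and constructor arguments, but not the code of custom classes, so custom layers must be registered manually at load time.

Fix:

- Register the custom layer through the `custom_objects` argument of `load_model`. Put every custom class the model uses into the same dictionary, because loading stops at the first name it cannot resolve:

```python
from tensorflow.keras.models import load_model

# assume UpsampleLike is imported from some module
from your_module import UpsampleLike

model = load_model('your_model.h5', custom_objects={'UpsampleLike': UpsampleLike})
```

- If the saved name refers to a function rather than a layer (for example, a custom loss or activation), register the function itself in the same way. A name that appears as a layer in the config, like `UpsampleLike`, must map to a `Layer` subclass, because Keras rebuilds it through the class's `from_config`:

```python
def custom_loss(y_true, y_pred):
    # implementation
    ...

model = load_model('your_model.h5', custom_objects={'custom_loss': custom_loss})
```

### Fixing `TypeError: Keyword argument not understood: 'freeze'`

This error occurs when a layer's saved config contains a `freeze` argument that the class used at load time does not accept. In the log above, the saved `BatchNormalization` entry carries `freeze`, and the built-in `BatchNormalization` rejects it. This can also happen when the model was saved with a different Keras version from the one used for loading.

Fix:

- Check the layers or functions in the model that use the `freeze` argument, and make sure their definitions are compatible with the current Keras version.
- If a custom layer uses `freeze`, accept it explicitly in `__init__`. That way it never reaches the base `Layer.__init__`, which rejects unknown keyword arguments. Also return it from `get_config` so it is written out again on the next save:

```python
import tensorflow as tf

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, freeze=False, **kwargs):
        # 'freeze' is consumed by this signature, so it is not forwarded to the base class
        super(CustomLayer, self).__init__(**kwargs)
        self.freeze = freeze

    def call(self, inputs):
        return inputs

    def get_config(self):
        config = super(CustomLayer, self).get_config()
        config['freeze'] = self.freeze
        return config
```

- Register this layer through `custom_objects` when loading as well. The key must match the class name recorded in the saved file (in the log above, `'BatchNormalization'`):

```python
model = load_model('your_model.h5', custom_objects={'CustomLayer': CustomLayer})
```

### Notes

- When saving, prefer saving the whole model (architecture, weights, and optimizer state) with `save_model` / `model.save`.
- If the model uses custom training logic (such as `train_step`), subclass `Model` and override the relevant methods. That subclass must also be registered when loading.
- Keep library versions in the training and inference environments as consistent as possible to reduce compatibility problems.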
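The last note above asks for consistent library versions between training and inference. One way to make that checkable is to record the versions next to the saved model and compare them at load time. The sketch below is minimal and assumes Python 3.8+ for `importlib.metadata`. The `<model>.versions.json` sidecar file is a hypothetical convention, not a Keras feature, and the version numbers in the usage lines are hypothetical placeholders.

```python
import json
import os
import tempfile
from importlib import metadata  # Python 3.8+


def installed_versions(packages):
    """Return {package: version or None} for the current environment."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions


def write_version_sidecar(model_path, versions):
    """Hypothetical convention: store versions in '<model_path>.versions.json'."""
    with open(model_path + ".versions.json", "w") as f:
        json.dump(versions, f, indent=2, sort_keys=True)


def version_mismatches(model_path, current):
    """Return {package: (saved, current)} for every package whose version differs."""
    with open(model_path + ".versions.json") as f:
        saved = json.load(f)
    return {pkg: (saved.get(pkg), current.get(pkg))
            for pkg in sorted(set(saved) | set(current))
            if saved.get(pkg) != current.get(pkg)}


# Usage with hypothetical version numbers standing in for the two environments
model_path = os.path.join(tempfile.mkdtemp(), "your_model.h5")
write_version_sidecar(model_path, {"tensorflow-gpu": "1.4.0", "h5py": "2.7.1"})
print(version_mismatches(model_path, {"tensorflow-gpu": "1.13.1", "h5py": "2.7.1"}))
# {'tensorflow-gpu': ('1.4.0', '1.13.1')}

# In practice, record installed_versions([...]) at save time and again at load time
print(installed_versions(["tensorflow", "h5py"]))
```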