Tensorflow1.x加载模型的方法

博客围绕Tensorflow加载模型展开,先指出有同学加载模型结果不对的问题。接着介绍加载模型需网络参数和图结构,加载图有重新搭建网络和用.meta文件两种方式,并给出测试结果。最后分析错误原因是图加载两次,导致结构混乱。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码地址:查看完整代码

一个错误的使用

之前有同学问过我这个问题,TF加载模型,跑出来的结果不对,代码见incurrect_usage.py,正确率和猜的一样,怀疑是模型加载那里出问题了。

#****************** incurrent usage.py*********************
x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")
pred = forward(x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.import_meta_graph('./model/mnist_model-4000.meta')
    saver.restore(sess, './model/mnist_model-4000')
    correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict={x: test_image, y_: test_label})

跑出来的正确率都在0.1左右,训练正确率都在0.9以上,再差也不会这样,所以加载模型哪里出错了。

Tensorflow加载模型的方法

本例使用tf.train.Saver()保存模型的方法,执行saver.save(sess, model_name)后,会得到3个名为model_name的文件,.data-00000-of-00001中保存了网络训练的参数,.meta保存了网络的图结构。

Tensorflow在加载模型的时候就需要上述的两个东西,网络参数和图结构,而加载图有两种方式,重新搭建网络直接用.meta文件。

重新搭建网络

顾名思义,在测试代码中重新把训练时forward的流程再搭一遍,这样就能得到由训练好的参数得到forward的结果。

#****************** test with network.py*********************
x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")
pred = forward(x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, './model/mnist_model-4000')
    correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict={x: test_image, y_: test_label})

因为forward流程和训练时一样,所以直接在训练代码里拿来用,已经重新搭建了图,就不要加载.meta文件了,所以直接restore参数文件就可以了。

拿测试集中前5000个样本做测试,测试结果:

test with network: 
INFO:tensorflow:Restoring parameters from ./model/mnist_model-4000
accuracy is:  0.9784

网络结构:

在这里插入图片描述

使用.meta文件构建图

使用.meta文件需要注意,在训练时最好为输入和输出取一个名字,因为需要直接从.meta保存的图结构中取输入和输出,有名字的时候会更明确一些。

像这样:

x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")

加载.meta代码如下:

#****************** test with meta.py*********************
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./model/mnist_model-4000.meta')
    saver.restore(sess, tf.train.latest_checkpoint("./model/"))
    
    graph = tf.get_default_graph()
    input_x = graph.get_operation_by_name("input").outputs[0]
    feed_dict = {"input:0":test_image, "label:0":test_label}
    pred = graph.get_tensor_by_name("output:0")
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict=feed_dict)

使用.meta文件,直接根据名字找到对应的输出和输出,获取默认图结构,不需要重新初始化参数。

拿测试集中前5000个样本做测试,测试结果:

test with .meta:
INFO:tensorflow:Restoring parameters from ./model/mnist_model-4000
accuracy is:  0.9784

测试结果和重新构建网络是一样的。

网络结构:

在这里插入图片描述

使用.meta测试时,网络输出那里出现了两个分支,猜测是.meta保存了训练时测试accuracy那部分图,我在测试的代码里又写了一个测试accuracy的部分,所以两部分都被保存了,但不影响测试的结果。

错误的原因

很容易猜到,图加载了两次,已经重建网络了,然后又加载了.meta,导致图的结构乱了,看图:

在这里插入图片描述

网络的结构已经变了,所以加载训练好的模型时,要么重建图,要么加载.meta,混合起来就容易出错。

TODO:使用滑动平均如何加载模型

参考

Mnist网络backbone:点击前往

TF加载模型方法: 点击前往

### 如何在 TensorFlow 1.x 中正确保存 Keras 模型.h5 格式 在 TensorFlow 1.x 版本中,Keras 提供了多种方式来保存模型及其权重。为了将模型保存为 `.h5` 文件格式,可以利用 `model.save()` 方法或者 `model.save_weights()` 方法[^3]。 #### 使用 `model.save()` 如果希望保存整个模型(包括架构和权重),可以直接调用 `model.save(filepath)` 方法。此方法会自动检测文件扩展名并决定存储格式。当指定的路径以 `.h5` 结尾时,模型将以 HDF5 格式保存: ```python model.save('my_model.h5') ``` 上述代码片段中的 `'my_model.h5'` 将作为最终保存的目标文件名称。 #### 使用 `model.save_weights()` 另一种情况是只保存模型的权重而不保存其结构,在这种情况下可使用 `model.save_weights(filepath)` 方法。同样地,若路径带有 `.h5` 扩展名,则权重会被序列化到 HDF5 文件中: ```python model.save_weights('my_model_weights.h5') ``` 需要注意的是,这种方式不会保留模型配置信息,因此后续加载时需先重建相同的模型实例再加载权重。 以下是基于这两种方法的一个完整示例: ```python import tensorflow as tf from tensorflow import keras # 假设已经定义了一个简单的 Sequential 模型 model = keras.Sequential([ keras.layers.Dense(64, activation='relu', input_shape=(32,)), keras.layers.Dense(10, activation='softmax') ]) # 编译模型 (此处省略具体参数设置) model.compile(optimizer=..., loss=...) # 训练模型 (此处省略) # 方案一:保存整个模型至 H5 文件 model.save('full_model.h5') # 方案二:仅保存模型权重至 H5 文件 model.save_weights('weights_only.h5') ``` 以上两种方案均适用于 TensorFlow 1.x 的环境,并能有效生成标准的 `.h5` 文件用于进一步处理或部署。 ### 注意事项 尽管 `.h5` 文件便于共享和迁移学习场景下的应用,但在某些特定需求下(如导出为 Frozen Graph 或 SavedModel 格式以便于生产环境中高效推理),还需要额外考虑其他形式的支持][^[^24]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值