【Tensorflow 大马哈鱼】保存模型、再次加载模型等操作

本文详细介绍使用TensorFlow进行模型的保存与加载方法,包括如何利用saver对象保存整个模型的状态,以及如何在后续训练中加载这些状态继续训练。同时介绍了在实际应用中需要注意的关键点。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

由于经常要使用tensorflow进行网络训练,但是在用的时候每次都要把模型重新跑一遍,这样就比较麻烦;另外由于某些原因程序意外中断,也会导致训练结果拿不到,而保存中间训练过程的模型可以以便下次训练时继续使用。

所以练习了tensorflow的save model和load model。

参考于http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/,这篇教程简单易懂!!


1、保存模型

# 首先定义saver类
#max_to_keep=4表示保存4个模型,默认=5
saver = tf.train.Saver(max_to_keep=4)

# 定义会话
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(200):
        # 训练
        sess.run(train_step)
        if epoch % 10 == 0:
            print "------------------------------------------------------"
            # 保存模型
            saver.save(sess, "model/my-model", global_step=epoch)
            print("save the model")

        

注意点:

  1. 创建saver时,可以指定需要存储的tensor,如果没有指定,则全部保存。

  2. 创建saver时,可以指定保存的模型个数,利用max_to_keep=4,则最终会保存最新的4个模型(下图中我保存了160、170、180、190step共4个模型)。默认为5个。   如果没有global_step=epoch这句话,则保存最终一个。

  3. saver.save()函数里面可以设定global_step,说明是哪一步保存的模型,如果为=epoch,则保存最后4个epoch的模型,如果为50,则保存第50个epoch的模型。

  4. 程序结束后,会生成四个文件:存储网络结构.meta、存储训练好的参数.data和.index、记录最新的模型checkpoint。

  5. 如果saver.save(sess, './params', write_meta_graph=False),则不保存网路结构.meta,默认是True

如:

这里写图片描述

2、加载模型

def load_model():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('model/my-model-290.meta')
        saver.restore(sess, tf.train.latest_checkpoint("model/"))

注意点:

  1. 首先导入模型结构import_meta_graph,这里填的名字meta文件的名字。

  2. 然后restore时,是检查checkpoint,所以只填到checkpoint所在的路径下即可,不需要填checkpoint,不然会报错“ValueError: Can’t load save_path when it is None.”。

  3. 后面根据具体例子,介绍如何利用加载后的模型得到训练的结果,并进行预测。

  4. checkpoint文件的内容:

         

3、不保存模型结构

tf_x = tf.placeholder(tf.float32, x.shape)  # input x
tf_y = tf.placeholder(tf.float32, y.shape)  # input y
l = tf.layers.dense(tf_x, 10, tf.nn.relu)          # hidden layer
o = tf.layers.dense(l, 1)                     # output layer
loss = tf.losses.mean_squared_error(tf_y, o)   # compute cost
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

sess = tf.Session()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值