tensorflow保存部分变量

本文详细介绍了在TensorFlow中如何使用模型保存功能来选择性地保存和加载部分变量,这对于模型维护和优化至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensorflow模型保存函数为:

saver = tf.train.Saver()

查看该函数初始化函数,输入参数为:

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False):
var_list参数为我们需要保存的变量数组,如果不输入var_list,则默认保存所有的变量.例如下面代码:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess,"checkpoint/model_test",global_step=1)


当我们需要保存部分变量时,我们可以定义一个需要保存的变量数组,例如我们只想保存变量v2,代码如下:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")

saver = tf.train.Saver( [v2])
# saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess,"checkpoint/model_test",global_step=1)

需要注意的是,模型保存需要在某一目录下,例如上面代码保存在checkpoint目录下,模型名为model_test,如果保存在当前目录下,如下代码:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")

saver = tf.train.Saver( [v2])
# saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess,"model_test",global_step=1)
则会报错:


Traceback (most recent call last):
  File "/home/qinghua/pythonWork/pix2pix-tensorflow/demo_model_part.py", line 11, in <module>
    saver.save(sess,"model_test",global_step=1)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1488, in save
    raise exc
tensorflow.python.framework.errors_impl.FailedPreconditionError: checkpoint.tmp2059e44e2c59453c824e5784c5bbef8e


Process finished with exit code 1





评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值