tf.train.Saver()

本文详细介绍TensorFlow中Saver类的使用方法,包括模型的保存与加载过程,以及如何利用Saver对象调整保存的模型数量。同时,文章提供了两个实例,演示如何重新加载模型进行测试或继续训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我们通常希望将辛辛苦苦训练好的模型保存起来,以便模型重用。tensorflow的Saver类可以提供此功能。

建Saver类时可传的参数如下所示:

def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):

‘var_list’指明我们需要保存或重新加载的变量,可以是一个列表或一个字典,比如:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# 将参数作为字典传到:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}),

其默认值是模型中所有可训练的模型。另一个经常用的参数是max_to_keep,它用于设置保存的模型的个数,默认为5.如果你想每训练一个epoch就保存一次模型,可以将max_to_keep设置为None或0。其他参数说明可参考tensorflow官网说明。

下面用一个实例介绍如何用Saver保存和加载模型:

下面的代码的功能是求解方程y=a*x+b,给定样本(x,y),求解a和b。

#coding:utf-8
# @Time     :2018/10/24 12:39
# @Author   :YY
# @Des      :
import tensorflow as tf
import numpy as np
x = tf.placeholder( dtype=tf.float32)
y = tf.placeholder(dtype=tf.float32)

a = tf.Variable(1.0,name='a')#参数初始化
b = tf.Variable(0.0,name='b')

pre = a*x+b #要求解的方程
model_loss = tf.nn.l2_loss(pre-y)
optimizer=tf.train.AdamOptimizer(learning_rate=0.1).minimize(model_loss)
X=np.asarray([1,2,3,4,5,6,7,8,9,10])
print(X)
Y=2*X+3
print(Y)
saver=tf.train.Saver()#创建Saver的对象saver
with tf.Session() as sess:
    epoches=200
    sess.run(tf.global_variables_initializer())
    for i in range(epoches):
        loss=0
        for i in range(len(X)):
            sess.run(optimizer,{x:X[i],y:Y[i]})
            loss+=sess.run(model_loss,{x:X[i],y:Y[i]})
        loss/=len(X)
        print('loss',loss)
    saver.save(sess,'model/my_model')#保存模型,指明保存的路径
    print(sess.run(a))

saver是创建的Saver的一个对象,Saver的save方法的功能是保存模型,其参数如下:

def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False)

其中sess是用于保存模型的session,而save_path指明模型保存的路径。执行save方法后,在保存模型的路径中会自动生成一个新的文件夹‘model’,其中的文件有:

文件说明:

checkpoint:存放最新保存的checkpoint文件的记录

my_model.meta:这是个协议缓冲区,它保存了完整的Tensorflow图形,即所有变量、操作等

my_model.data:包含训练变量的文件,加载时主要加载这里面的变量

my_model.index:

这里值得注意的是:Tensorflow变量仅存在会话中,所以必须在一个会话中保存模型。

重新加载训练好的模型

如果你想重新加载训练好的模型进行测试或继续训练,有两种方式可以实现重用模型的目的:

第一种:你可以重新构建一个模型,通过以下方式将之前保存的模型参数加载进来,记住,如果新建的网络的参数名称与要加载的模型的参数名称相同,则会自动用要加载的模型的参数覆盖你新建的模型的参数

# _*_coding:utf-8 _*_
#@Time :2019-01-14 15:48
#@Author :YY
#Describe :重载模型后,原来变量的值会被覆盖
import tensorflow as tf

import numpy as np
x = tf.placeholder( dtype=tf.float32)
y = tf.placeholder(dtype=tf.float32)

a = tf.Variable(1.0,name='a')#参数初始化
b = tf.Variable(0.0,name='b')

saver=tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'model/my_model')#加载模型
    a=tf.get_default_graph().get_tensor_by_name("a:0")
    print(sess.run(a))

如上述代码所示,我们初始化参数a的值为1.0,而当我们加载之前训练好的模型后,a的值变成2.0,即我们已经训练好的值,这说明同名的变量其值会被覆盖。我们还可以直接加载模型

第二种:通过tf.train.import()方法来直接重建模型。我们已经在my_model.meta文件中保存了构建的网络,可以通过以下方式重新创建这个网络(相当于复制粘贴)

saver = tf.train.import_meta_graph('model/my_model.meta')

但import_meta_graph只是将.meta文件中定义的网络附加到当前图中,此时还未将参数的值加载进来。可以通过以下方式将参数加载进来:

# _*_coding:utf-8 _*_
#@Time :2019-01-14 15:48
#@Author :YY 
#Describe :直接加载模型,不用再手动创建
import tensorflow as tf
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model/my_model.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./model'))
    print(sess.run('a:0'))

我们可以看到,变量a已经被加载进来。我们还可以使用graph.get_tensor_by_name()方法来操作这个保存的模型:

 a=graph.get_tensor_by_name('a:0')
 print(sess.run(a))

至此,Saver类的基本用法介绍完毕,以后会继续补充。

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值