【深度学习】Tensorflow模型保存与恢复

本文介绍了如何使用tf.train.Saver()在TensorFlow中保存和恢复模型。通过示例代码展示了在训练过程中,模型在不同阶段的损失值变化,并详细说明了Saver对象的定义、保存和恢复操作的步骤及注意事项。

tf.train.Saver()的定义与使用

Saver对象:用于在tf中保存,恢复Session
定义

model_path="/tmp/model.ckpt"
saver=tf.train.Saver()

Saver保存操作:saver.save(sess,model_path)

save_path=saver.save(sess,model_path)

Saver恢复操作:saver.restore(sess,save_path)

saver.restore(sess,model_path)

注意事项:
1.tf.train.Saver()定义在Session之前
2.saver.save()和saver.restore()都在Session里进行

tf.train.Saver()使用代码示例

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 22:59:41 2017
@author: ZMJ
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
print "Package Loaded"

np.random.seed(1)
def f(x,weight,bias):
  return x*weight+bias

Wref=0.7
Bref=-0.1
n=20
noise_var=0.05
train_X=np.random.random((n,1))
ref_Y=f(train_X,Wref,Bref)
train_Y=ref_Y+noise_var*np.random.randn(n,1)
model_path="/tmp/linear_model.ckpt"

lr=0.01
epochs=5000
display_step=250
n_samples=train_X.size

plt.subplot(121)
plt.axis("equal")
plt.plot(train_X[:,0],ref_Y[:,0],"ro",label="Original Data")
plt.plot(train_X[:,0],train_Y[:,0],"bo",label="Training Data")
plt.title("Sactter Plot of Data")
plt.legend(loc="lower right")

weight=tf.Variable(np.random.randn(),name="weight")
bias=tf.Variable(np.random.randn(),name="bias")
x=tf.placeholder(tf.float32,shape=[n_samples,1],name="input")
y=tf.placeholder(tf.float32,shape=[n_samples,1],name="output")

"""
Model
"""
pred=x*weight+bias
cost=tf.reduce_mean(tf.pow(pred-y,2))
optimizer=tf.train.GradientDescentOptimizer(lr).minimize(cost)
init=tf.global_variables_initializer()

"""
Saver Defination
"""
saver=tf.train.Saver()

"""
Run Model in First Session
"""
with tf.Session() as sess:
  sess.run(init)
  for epoch in range(500):
    l=sess.run(optimizer,feed_dict={x:train_X,y:train_Y})
    if epoch%display_step==0:
      c=sess.run(cost,feed_dict={x:train_X,y:train_Y})
      print "Epoch %s .Cost=%s"%(epoch,c)
  print "First Session Compelted!"

  save_path=saver.save(sess,model_path)
  print "Save Completed,Save Path = %s"%save_path

"""
Run Model in Second Session
"""
with tf.Session() as sess:
  #sess.run(init)
  saver.restore(sess,model_path)
  print "Model Restored From %s"%model_path

  for epoch in range(epochs-500):
    l=sess.run(optimizer,feed_dict={x:train_X,y:train_Y})
    if epoch%display_step==0:
      c=sess.run(cost,feed_dict={x:train_X,y:train_Y})
      print "Epoch %s .Cost=%s"%(epoch,c)

  print "Second Session Compelted!"
  save_path=saver.save(sess,model_path)
  print "Save Completed,Save Path = %s"%save_path

  Wop=sess.run(weight)
  Bop=sess.run(bias)
  fop=f(train_X,Wop,Bop)      
  plt.subplot(122)
  plt.plot()
  plt.plot(train_X[:,0],ref_Y[:,0],"ro",label="Original Data")
  plt.plot(train_X[:,0],train_Y[:,0],"bo",label="Training Data")
  plt.plot(train_X[:,0],fop[:,0],"k-",label="Predicted Line")
  plt.title("Predicted Line")
  plt.legend(loc="lower right")
  plt.show()

打印的日志:

Epoch 0 .Cost=0.269742
Epoch 250 .Cost=0.0531464
First Session Compelted!
Save Completed,Save Path = /tmp/linear_model.ckpt
Model Restored From /tmp/linear_model.ckpt
Epoch 0 .Cost=0.0323754
Epoch 250 .Cost=0.019944
Epoch 500 .Cost=0.0125031
Epoch 750 .Cost=0.00804937
Epoch 1000 .Cost=0.00538358
Epoch 1250 .Cost=0.00378797
Epoch 1500 .Cost=0.00283292
Epoch 1750 .Cost=0.00226127
Epoch 2000 .Cost=0.00191911
Epoch 2250 .Cost=0.00171431
Epoch 2500 .Cost=0.00159173
Epoch 2750 .Cost=0.00151836
Epoch 3000 .Cost=0.00147444
Epoch 3250 .Cost=0.00144815
Epoch 3500 .Cost=0.00143242
Epoch 3750 .Cost=0.001423
Epoch 4000 .Cost=0.00141736
Epoch 4250 .Cost=0.00141399
Second Session Compelted!
Save Completed,Save Path = /tmp/linear_model.ckpt

这里写图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值