"微信公众号"

思考一个问题:
我们搭建好一个神经网络,用大量的数据训练好之后,肯定希望保存神经网络里面的参数,用于下次加载。那我们该怎么做呢?
TensorFlow为我们提供了Saver来保存和加载神经网络的参数。
一、保存
(1)import所需的模块,然后建立神经网络当中的W和b,并初始化变量。
import tensorflow as tf
import numpy as np
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
init = tf.global_variables_initializer()(2)保存时,首先要建立一个tf.train.Saver()用来保存,提取变量。再创建一个名为my_net的文件夹,用这个saver来保存变量到这个目录“my_net/save_net.ckp”。
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
print("Save to path:",save_path)(3)效果图:

(4)给出保存参数的完整代码。
import tensorflow as tf
import numpy as np
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
print("Save to path:",save_path)二、提取
(1)提取时,先建立临时的W和b容器。找到文件目录,并用saver.restore()提取变量。
#conding:utf-8
import tensorflow as tf
import numpy as np
# restore variables
# 先建立W,b的容器
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")
# not need init step
saver = tf.train.Saver()
with tf.Session() as sess:
# 提取变量
saver.restore(sess,"F:/my_net/save_net.ckpt")
print("weights:",sess.run(W))
print("biases:",sess.run(b))观看视频笔记:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-06-save/

本文详细介绍如何使用TensorFlow中的Saver组件来保存和加载神经网络的参数。首先通过实例演示了神经网络参数的保存过程,包括定义变量、初始化及使用Saver进行保存。接着介绍了如何重新定义相同形状和类型的变量并从中恢复之前保存的参数。
8万+

被折叠的 条评论
为什么被折叠?



