一、两种存储和加载模型的方法
tensorflow的API提供了两种方式来存储和加载模型
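1、生成检查点文件(checkpoint file):通过在 tf.train.Saver 对象上调用 Saver.save() 生成,保存路径前缀一般以 .ckpt 结尾。在 TF 1.x 默认的 V2 格式下,每次保存实际会生成 `xxx.ckpt-N.index` 和 `xxx.ckpt-N.data-00000-of-00001` 等数据文件,其中保存权重以及程序中定义的其它变量的值,但不包含图结构。同一目录下还会有一个名为 `checkpoint` 的文本文件,用来记录最新的检查点。另外,Saver.save() 默认还会写出一个 `.meta` 文件(MetaGraphDef),其中包含图结构,可以用 tf.train.import_meta_graph() 导入。如果不使用 .meta 文件,那么在另一个程序中使用这些权重时,需要先用代码重新构建相同的图结构,再用 Saver.restore() 把权重加载到对应的变量中。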
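2、生成图协议文件(graph proto file):用 tf.train.write_graph() 保存图的 GraphDef。当 as_text=False 时写成二进制文件,扩展名一般是 .pb;当 as_text=True 时写成文本文件,扩展名一般是 .pbtxt,这也是该函数的默认值。它只包含图结构,不包含变量的权重,除非事先把变量“冻结”为常量(例如用 tf.graph_util.convert_variables_to_constants)。加载时先把文件内容解析成 tf.GraphDef,再用 tf.import_graph_def() 把它导入到当前图中。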
二、代码
import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt
### step1: eager execution is intentionally left OFF (tf.enable_eager_execution()
### is not called) because the rest of the script builds a TF1 graph with
### tf.placeholder and runs it inside a tf.Session.
### step2: the MNIST archive can be downloaded from
### https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
### step3: load MNIST from the local .npz archive. Images are reshaped to
### (N, 28, 28, 1) and labels are flattened to 1-D vectors.
with np.load("./estimator/0808_estimator/MNIST_data/mnist.npz") as f:
    x_train = f['x_train'].reshape([-1, 28, 28, 1])
    y_train = f['y_train'].reshape([-1])
    x_test = f['x_test'].reshape([-1, 28, 28, 1])
    y_test = f['y_test'].reshape([-1])
### step4: graph inputs. Image batches are fed as float32 (N, 28, 28, 1) arrays
### and labels as int64 class ids. (An earlier draft tried to concatenate images
### and labels into one array for shuffling; that cannot work because the two
### arrays have different ranks.)
X = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.int64, [None])

### step5: build the model: two 3x3 ReLU conv layers, global average pooling,
### then a 10-unit dense layer that produces unnormalised logits.
mnist_model = tf.keras.Sequential()
mnist_model.add(tf.keras.layers.Conv2D(16, [3, 3], activation='relu'))
mnist_model.add(tf.keras.layers.Conv2D(16, [3, 3], activation='relu'))
mnist_model.add(tf.keras.layers.GlobalAveragePooling2D())
mnist_model.add(tf.keras.layers.Dense(10))

### step6: forward pass. None of the layers above behave differently in
### training mode, so this single logits tensor is also used for evaluation.
logits = mnist_model(X, training=True)
y_pred = tf.argmax(logits, axis=-1)  # predicted class id per example
def accuracy(y_pred, y):
    """Return the fraction of positions where ``y_pred`` equals ``y``.

    Both arguments are array-likes of class ids. The result is the NumPy mean
    of the element-wise equality mask, i.e. a value in [0, 1].
    """
    hits = np.equal(y_pred, y)
    return np.mean(hits)
### step6 (continued): loss, optimizer, step counter, checkpointing, training loop
BATCH_SIZE = 128        # mini-batch size used for training
EVAL_BATCH_SIZE = 1000  # chunk size for whole-dataset evaluation (bounds memory)

cost = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
train_op = tf.train.AdamOptimizer(0.001).minimize(cost)

# global_step is a bookkeeping counter, not a model parameter, so trainable=False.
global_step = tf.Variable(0, name="global_step", trainable=False)
# Build the step-assignment op ONCE. Calling global_step.assign(i) inside the
# loop would add a new op to the graph on every iteration, so the graph (and
# every .meta file written by saver.save) would keep growing.
step_value = tf.placeholder(tf.int32, [], name="step_value")
set_global_step = global_step.assign(step_value)

# tf.train.Saver() must be created after every variable (including the Adam
# slot variables and global_step) exists; otherwise those variables are not
# saved. max_to_keep (default 5) is how many of the newest checkpoints to keep.
saver = tf.train.Saver()


def _predict_in_batches(sess, images, batch_size=EVAL_BATCH_SIZE):
    """Return y_pred for every row of ``images``, running the graph chunk by chunk.

    Feeding the whole 60000-image training set in one sess.run materialises the
    conv activations for all of it at once. Chunking yields the same
    predictions with bounded memory.
    """
    chunks = [
        sess.run(y_pred, feed_dict={X: images[k:k + batch_size]})
        for k in range(0, images.shape[0], batch_size)
    ]
    return np.concatenate(chunks)


loss_all = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("start from: ", sess.run(global_step))
    Length = y_train.shape[0]
    # Ceiling division. The previous int(Length / 128) + 1 produced an empty
    # final batch whenever Length was an exact multiple of the batch size.
    batch_times = (Length + BATCH_SIZE - 1) // BATCH_SIZE
    # Saver.save fails if the checkpoint directory does not exist yet.
    tf.gfile.MakeDirs("./model")
    i = 0
    for _ in range(2):
        # NOTE: shuffling is currently disabled, so this seed has no effect on training.
        np.random.seed(520)
        for batch in range(batch_times):
            start = batch * BATCH_SIZE
            end = min(Length, start + BATCH_SIZE)
            image = x_train[start:end, ...]
            label = y_train[start:end]
            # Fetch several tensors in one sess.run by passing them as a list.
            # feed_dict values must be NumPy arrays, not tf.Tensor objects.
            _, loss, y_ = sess.run([train_op, cost, y_pred],
                                   feed_dict={X: image, y: label})
            # An assign op only takes effect when it is run. Running it also
            # returns the new value, so global_step need not be re-read.
            step = sess.run(set_global_step, feed_dict={step_value: i})
            if step % 50 == 0:
                print(step, " : ", loss)
                print(accuracy(y_, label))
                saver.save(sess, save_path="./model/mode.ckpt", global_step=global_step)
            if step % 500 == 0:
                print(accuracy(_predict_in_batches(sess, x_train), y_train))
                print(accuracy(_predict_in_batches(sess, x_test), y_test))
            i += 1
            loss_all.append(loss)
### step7: restore the most recent checkpoint and evaluate it on the test set
saver = tf.train.Saver()
with tf.Session() as sess:
    # Initialising first is harmless but not strictly required: restore()
    # overwrites every variable this Saver covers (all global variables).
    sess.run(tf.global_variables_initializer())
    # latest_checkpoint returns None when ./model has no usable checkpoint.
    ckpt_path = tf.train.latest_checkpoint("./model")
    if ckpt_path:
        print(ckpt_path)
        # Loading the weights really is just this one call.
        saver.restore(sess, ckpt_path)
        loss, y_ = sess.run([cost, y_pred], feed_dict={X: x_test, y: y_test})
        print(accuracy(y_, y_test))
        print(loss)
    else:
        # Without a checkpoint the variables are freshly initialised, so
        # evaluating them would report meaningless random-weight numbers.
        print("no checkpoint found in ./model; nothing to evaluate")