TensorFlow 模型保存：保存模型，实质上是把会话（sess）中各变量的当前值连同计算图结构一起写入检查点文件。
# Key lines (excerpt; `sess` and `model_path` come from the surrounding code):
# create a Saver over the graph's variables, then write their current values
# held in `sess` to a checkpoint at `model_path`.
saver=tf.train.Saver()
saver.save(sess, model_path)
# Full example in context:
import tensorflow as tf
def Cnn(x_image):
    """Build the CNN on ``x_image`` and return its unnormalised logits.

    The layer definitions are elided in this excerpt (the ``...`` lines);
    they are expected to define ``logit``. The caller feeds the result to
    ``softmax_cross_entropy_with_logits`` together with the ``[None, 2]``
    label placeholder, so ``logit`` presumably has shape ``[batch, 2]``.
    Confirm this against the full network.

    Besides returning the logits, this adds a softmax op named
    ``"final_tensor"`` to the graph. The inference script uses that name to
    fetch the class probabilities as ``"final_tensor:0"`` after restoring
    the model.
    """
    ...  # layers elided in this excerpt
    ...
    # ``result`` is deliberately unused here: the op exists only so that it
    # is saved in the graph under the name "final_tensor".
    result = tf.nn.softmax(logit, name="final_tensor")
    # Return raw logits rather than ``result``: the loss expects unscaled logits.
    return logit
# --- Graph construction ------------------------------------------------------
# Input placeholders. Their names ("x", "y") are stored in the saved graph;
# the inference code looks the input up again as "x:0".
y = tf.placeholder(dtype=tf.float32, shape=[None, 2], name="y")
x = tf.placeholder(dtype=tf.float32, shape=[None, 160, 160], name="x")

# conv2d expects NHWC input, so append a single-channel axis.
x_image = tf.reshape(x, shape=[-1, 160, 160, 1])

# Forward pass. Cnn returns raw logits (it also creates the softmax op itself).
y_predict = Cnn(x_image)

# Loss (mean softmax cross-entropy over the batch) and the Adam training op.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_predict)
)
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cross_entropy)

# Accuracy: fraction of samples whose arg-max class matches the label's.
correct_prediction = tf.equal(
    tf.argmax(y_predict, axis=1),
    tf.argmax(y, axis=1),
)
# (Name kept as spelled because the training loop refers to `accuray`.)
accuray = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))

# Checkpoint location. The Saver is created last, after minimize(), so it
# covers every variable built above, including any the optimiser adds.
model_path = "./model/model.ckpt"
saver = tf.train.Saver()
import os

with tf.Session() as sess:
    # Train the model inside a single session, then checkpoint it.
    sess.run(tf.global_variables_initializer())
    # The test set is fetched once and reused for every epoch's evaluation.
    test_datas, test_label = get_batch(0, 'test')
    for i in range(n_item):
        for batch in range(n_betch):
            train_datas, train_lable = get_batch(batch, 'train')
            sess.run(train_step, feed_dict={x: train_datas, y: train_lable})
        # Per-epoch report: metrics on the last training batch and on the test set.
        train_acc, loss_val_t = sess.run(
            [accuray, cross_entropy], feed_dict={x: train_datas, y: train_lable})
        test_acc, loss_val_e = sess.run(
            [accuray, cross_entropy], feed_dict={x: test_datas, y: test_label})
        print("step %d, training accuracy %g test accuracy %g lossTrain %g lossTest %g" % (i, train_acc, test_acc, loss_val_t, loss_val_e))
    # Saver.save refuses to write when the checkpoint's parent directory is
    # missing, so create it first (no-op if it already exists).
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    saver.save(sess, model_path)
TensorFlow 模型加载：加载模型，实质上是先从 .meta 文件重建计算图，再把检查点中保存的变量值恢复到会话（sess）中。
# Key lines (excerpt; `sess` must already exist): rebuild the graph from the
# saved .meta file, then load the most recent checkpoint in ./model/ into `sess`.
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint('./model/'))
# Full example in context:
import tensorflow as tf
import numpy as np
import cv2
def getImage(path=r"D:\yun\FolderMark_1349\19.jpg", size=(160, 160)):
    """Load one image and turn it into a single-sample batch for the model.

    The image is resized, converted to grayscale and scaled to [0, 1]. A
    leading batch axis is then added so the result matches the ``x``
    placeholder (``[None, 160, 160]`` in the training graph).

    Args:
        path: Image file to read. Defaults to the originally hard-coded file.
            A raw string keeps the Windows backslashes literal instead of
            relying on invalid escape sequences such as ``\\y``.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A float array of shape ``(1, height, width)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` rather than raising, which would otherwise
            surface as an obscure error inside ``cv2.resize``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError("cannot read image: %s" % path)
    img = cv2.resize(img, size)
    gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # NOTE(review): assumes training data was scaled the same way; confirm in get_batch.
    im = gray_image / 255.0
    # Add the batch dimension: (H, W) -> (1, H, W).
    return im[np.newaxis, :]
data = getImage()
with tf.Session() as sess:
    # Resolve the checkpoint once and derive the graph file from it, so the
    # graph structure and the restored weights come from the same checkpoint.
    ckpt = tf.train.latest_checkpoint('./model/')
    if ckpt is None:
        # latest_checkpoint returns None when ./model/ has no checkpoint;
        # fail clearly instead of inside import_meta_graph/restore.
        raise FileNotFoundError("no checkpoint found in './model/'")
    saver = tf.train.import_meta_graph(ckpt + '.meta')
    saver.restore(sess, ckpt)
    graph = tf.get_default_graph()
    # Tensors are looked up by the names given when the graph was built.
    x = graph.get_tensor_by_name("x:0")
    feed_dict = {x: data}
    # NOTE: "final_tensor" is the softmax op, so despite the variable name
    # this yields per-class probabilities, not raw logits.
    logits = graph.get_tensor_by_name("final_tensor:0")
    classification_result = sess.run(logits, feed_dict)
    # Print the prediction matrix.
    print(classification_result)