代码变动部分:
logits, train_op, loss, maintain_averages_op, accuracy = simplenet(x,y,class_num)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
steps = epochs * len(img_data) // batchsize
for step in range(steps):
batch_inputs = inputs[step*batchsize:(step+1)*batchsize]
batch_labels = true_labels[step*batchsize:(step+1)*batchsize]
ls, acc, _ = sess.run([loss,accuracy,maintain_averages_op],feed_dict={x:batch_inputs,y:batch_labels})
if step%100 == 0:
saver.save(sess,model_dir,global_step=step)
print(' step: ', step, ' loss: ', ls, ' accuracy: ', acc)
模型加载:
tf.train.import_meta_graph('./models/model-0.meta')
for variable_name in tf.global_variables():
print(variable_name)
for tensor_name in tf.contrib.graph_editor.get_tensors(tf.get_default_graph()):
print(tensor_name)
# with tf.Session() as sess:
# for node in sess.graph_def.node:
# print(node)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,'./models/model-100')
#print(sess.run(tf.get_default_graph().get_tensor_by_name('Variable_5:0')))
本文详细介绍了如何使用TensorFlow进行深度学习模型的训练过程,包括定义损失函数、优化器和准确率指标,以及如何利用Saver类保存模型参数到本地文件。同时,也展示了如何重新加载模型进行后续的预测或继续训练。

被折叠的 条评论
为什么被折叠?



