# -*- coding:utf-8 -*-
"""
#-------------------------------------
@Project:tf_example
@version:v1.0
@date:2018/3/8
-------------------------------------
# @Brief:
"""
import logging
import os
import tensorflow as tf
from tensorflow.contrib.layers import conv2d,max_pool2d
from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
import numpy as np
def read_datasets(data_dir='./mnist_data'):
    """Load MNIST with integer (non one-hot) labels and print the image array shapes.

    The previous version also built ``tf.cast`` tensors of the full train/test
    arrays that were never used. Each ``tf.cast`` of a NumPy array embeds the
    whole array as a constant in the default graph, so those lines only added
    graph bloat and have been removed.

    :param data_dir: directory passed to ``read_data_sets``.
    :return: the dataset object returned by ``read_data_sets``.
    """
    mnist = read_data_sets(data_dir)
    # Single-argument print() gives identical output on Python 2 and 3.
    print('训练数据集大小: %s' % (np.shape(mnist.train.images),))
    print('测试数据集大小: %s' % (np.shape(mnist.test.images),))
    return mnist
def mnist_softmax(data_dir='./mnist_data', num_steps=10000, batch_size=200,
                  eval_every=100):
    """Train a single-layer softmax-regression classifier on MNIST.

    Test-set accuracy is printed every ``eval_every`` steps (after that step's
    training update).

    :param data_dir: directory passed to ``read_data_sets``.
    :param num_steps: number of mini-batch training steps.
    :param batch_size: number of examples per training step.
    :param eval_every: print test accuracy every this many steps.
    """
    mnist = read_data_sets(data_dir, one_hot=True)
    xtest = mnist.test.images
    ytest = mnist.test.labels
    print('训练数据集大小: %s' % (np.shape(mnist.train.images),))
    print('测试数据集大小: %s' % (np.shape(mnist.test.images),))

    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    W = tf.Variable(tf.random_normal(shape=[784, 10]))
    b = tf.Variable(tf.zeros(shape=[10]))
    # Unscaled scores. softmax_cross_entropy_with_logits applies softmax
    # internally, so it must receive these logits. Passing softmax(logits), as
    # before, applied softmax twice and flattened the gradients.
    logits = tf.matmul(x, W) + b
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    train_step = tf.train.AdamOptimizer().minimize(loss)

    # Built once, outside the loop. Creating it inside the loop added new ops
    # to the graph on every evaluation. argmax(logits) == argmax(softmax(logits)).
    correct = tf.equal(tf.argmax(y, 1), tf.argmax(logits, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # The context manager guarantees the session is closed, even on error.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(num_steps):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_x, y: batch_y})
            if step % eval_every == 0:
                acc = sess.run(accuracy, feed_dict={x: xtest, y: ytest})
                print('预测准确率: %s' % (acc,))
def mnist_cnn(data_dir='./mnist_data', num_steps=100, batch_size=50,
              keep_prob=0.2, ckpt_dir='./cpkdir', log_dir='./tmp/mnist_logs'):
    """Train a two-conv-layer CNN on MNIST with TensorBoard summaries and checkpoints.

    After every training step:

    * the merged summaries are evaluated on the full test set and written to
      ``log_dir``; they are the ``accury`` scalar and the ``w1``/``w2``
      histograms, and ``log_dir`` is wiped at start-up;
    * a checkpoint including its meta graph is written to
      ``<ckpt_dir>/model.ckpt-<step>``, keeping the 3 most recent.

    Test accuracy is printed every 5 steps. The logits (``y_``) and
    ``train_step`` are registered in graph collections, and the placeholders
    are named ``x``, ``y`` and ``keep_drop``, so ``model_restore`` can find
    them after importing the meta graph.

    :param data_dir: directory passed to ``read_data_sets``.
    :param num_steps: number of mini-batch training steps.
    :param batch_size: number of examples per training step.
    :param keep_prob: dropout keep probability fed during training (1.0 is
        used for evaluation). NOTE(review): 0.2 keeps only 20% of the
        fully-connected units. The placeholder name ``keep_drop`` hints 0.2
        may have been meant as a drop rate; confirm.
    :param ckpt_dir: checkpoint directory (created if missing).
    :param log_dir: TensorBoard log directory.
    """
    def gen_w(shape, stddev=0.1):
        # stddev 0.1 as in the TensorFlow Deep-MNIST tutorial. The previous
        # implicit stddev of 1.0 yields very large pre-activations in the wide
        # 3136->512 layer.
        return tf.Variable(tf.truncated_normal(shape=shape, stddev=stddev))

    def gen_b(shape):
        # Constant 0.1 bias initialisation.
        return tf.Variable(tf.constant(0.1, shape=shape))

    def conv_same(inp, w):
        # Stride 1 with SAME padding keeps height and width unchanged.
        return tf.nn.conv2d(inp, filter=w, strides=[1, 1, 1, 1], padding='SAME')

    def pool_2x2(inp):
        # 2x2 window, stride 2: halves height and width (28 -> 14 -> 7).
        return tf.nn.max_pool(inp, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                              padding='SAME')

    mnist = read_data_sets(data_dir, one_hot=True)
    xtest = mnist.test.images
    ytest = mnist.test.labels
    print('训练数据集大小: %s' % (np.shape(mnist.train.images),))
    print('测试数据集大小: %s' % (np.shape(mnist.test.images),))

    # These names are what model_restore looks up ('x:0', 'y:0', 'keep_drop:0').
    x = tf.placeholder(shape=[None, 784], dtype=tf.float32, name='x')
    y = tf.placeholder(shape=[None, 10], dtype=tf.float32, name='y')
    keep_drop = tf.placeholder(dtype=tf.float32, name='keep_drop')
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    w1 = gen_w([5, 5, 1, 32])
    b1 = gen_b([32])
    pool1 = pool_2x2(tf.nn.relu(conv_same(x_image, w1) + b1))   # [N, 14, 14, 32]
    w2 = gen_w([5, 5, 32, 64])
    b2 = gen_b([64])
    pool2 = pool_2x2(tf.nn.relu(conv_same(pool1, w2) + b2))     # [N, 7, 7, 64]
    flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
    w3 = gen_w([7 * 7 * 64, 512])
    b3 = gen_b([512])
    hfc1 = tf.nn.relu(tf.matmul(flat, w3) + b3)
    hfc1_drop = tf.nn.dropout(hfc1, keep_prob=keep_drop)
    w4 = gen_w([512, 10])
    b4 = gen_b([10])
    y_ = tf.matmul(hfc1_drop, w4) + b4  # logits; no softmax applied here
    # Registered so the op can be fetched by name after import_meta_graph.
    tf.add_to_collection('y_', y_)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_))
    train_step = tf.train.RMSPropOptimizer(learning_rate=0.01).minimize(loss)
    tf.add_to_collection('train_step', train_step)
    accury = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), dtype=tf.float32))

    # TensorBoard summaries.
    tf.summary.scalar('accury', accury)
    tf.summary.histogram('w1', w1)
    tf.summary.histogram('w2', w2)
    merge_summary_op = tf.summary.merge_all()
    # Created after minimize() so the optimizer's slot variables are saved too.
    saver = tf.train.Saver(max_to_keep=3)

    if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
    # Saver.save raises if the checkpoint's parent directory does not exist.
    if not tf.gfile.Exists(ckpt_dir):
        tf.gfile.MakeDirs(ckpt_dir)
    ckpt_prefix = os.path.join(ckpt_dir, 'model.ckpt')
    test_feed = {x: xtest, y: ytest, keep_drop: 1.0}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        try:
            for step in range(num_steps):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: batch_x, y: batch_y,
                                                keep_drop: keep_prob})
                # One pass over the test set yields both summaries and accuracy.
                summary_str, acc = sess.run([merge_summary_op, accury],
                                            feed_dict=test_feed)
                summary_writer.add_summary(summary_str, step)
                # Single save per step, meta graph included. model_restore
                # imports the .meta of the latest checkpoint, and the Saver
                # removes files of checkpoints rotated out by max_to_keep.
                saver.save(sess, ckpt_prefix, global_step=step)
                if step % 5 == 0:
                    print(acc)
        finally:
            # Flush pending events; the writer was never closed before.
            summary_writer.close()
def model_restore(ckpt_dir='./cpkdir', extra_steps=20, batch_size=50,
                  keep_prob=0.2, data_dir='./mnist_data'):
    """Restore network structure and weights from the latest checkpoint and resume training.

    The meta graph is imported from ``<latest checkpoint>.meta``, so no graph
    code is re-run. Tensors are found through the ``y_``/``train_step``
    collections and the ``x``/``y``/``keep_drop`` placeholder names. Training
    then continues for ``extra_steps`` steps. Numbering resumes after the
    checkpoint's global step, or at 100 if the step cannot be parsed. Test
    accuracy is printed every 5 steps, and a checkpoint is saved every step.

    :param ckpt_dir: directory holding the checkpoint state file.
    :param extra_steps: number of additional training steps.
    :param batch_size: number of examples per training step.
    :param keep_prob: dropout keep probability fed during training.
    :param data_dir: directory passed to ``read_data_sets``.
    :raises IOError: if no checkpoint is found in ``ckpt_dir``.
    """
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Previously this fell through to an IndexError / NameError further down.
        raise IOError('No checkpoint found in %s' % ckpt_dir)
    ckpt_path = ckpt.model_checkpoint_path
    print(ckpt_path)

    # Resume numbering right after the saved global step ("model.ckpt-99" -> 100).
    try:
        start_step = int(ckpt_path.rsplit('-', 1)[1]) + 1
    except (IndexError, ValueError):
        start_step = 100

    # Import into a fresh graph, so that nodes already in the default graph
    # (e.g. from an earlier mnist_cnn() call) cannot rename the imported ones
    # and make 'x:0' resolve to the wrong tensor.
    graph = tf.Graph()
    with graph.as_default():
        # The .meta sits next to the latest checkpoint. The old hard-coded
        # "model.ckpt-105.meta" is never written by mnist_cnn (steps 0..99).
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        train_step = graph.get_collection('train_step')[0]
        y_ = graph.get_collection('y_')[0]
        y = graph.get_tensor_by_name('y:0')
        x = graph.get_tensor_by_name('x:0')
        keep_drop = graph.get_tensor_by_name('keep_drop:0')
        # Built once here instead of on every evaluation inside the loop.
        accury = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), dtype=tf.float32))

    mnist = read_data_sets(data_dir, one_hot=True)
    test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_drop: 1.0}
    ckpt_prefix = os.path.join(ckpt_dir, 'model.ckpt')

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, ckpt_path)
        print('加载成功')
        for step in range(start_step, start_step + extra_steps):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_x, y: batch_y,
                                            keep_drop: keep_prob})
            if step % 5 == 0:
                print(sess.run(accury, feed_dict=test_feed))
            saver.save(sess, ckpt_prefix, global_step=step)
if __name__ == "__main__":
    # Script entry point: run the CNN training demo.
    mnist_cnn()
TensorFlow模型的存储、加载以及TensorBoard的使用
MNIST手写数字识别
最新推荐文章于 2023-07-25 00:10:04 发布
本文介绍使用TensorFlow实现两种不同的MNIST手写数字识别模型:一种基于Softmax回归,另一种采用卷积神经网络(CNN)。通过训练和评估这些模型,展示了如何提高识别精度;演示了如何保存检查点,并同时恢复网络结构与权重以继续训练;还介绍了如何利用TensorBoard进行可视化。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
TensorFlow-v2.15
TensorFlow
TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型
849

被折叠的 条评论
为什么被折叠?



