TensorFlow模型的存储、加载以及TensorBoard的使用

MNIST手写数字识别
本文介绍使用TensorFlow实现两种不同的MNIST手写数字识别模型:一种基于Softmax回归,另一种采用卷积神经网络(CNN)。通过训练和评估这些模型,展示了如何提高识别精度,并利用TensorBoard进行可视化。
部署运行你感兴趣的模型镜像
# -*- coding:utf-8 -*-

"""
#-------------------------------------
@Project:tf_example
@version:v1.0
@date:2018/3/8
-------------------------------------
# @Brief:

"""
import logging
import os
import tensorflow as tf
from tensorflow.contrib.layers import conv2d,max_pool2d
from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
import numpy as np


def read_datasets():
    mnist=read_data_sets('./mnist_data')
    print '训练数据集大小:',np.shape(mnist.train.images)
    print '测试数据集大小:',np.shape(mnist.test.images)
    xtrain=tf.cast(mnist.train.images,tf.float32)
    ytrain=tf.cast(mnist.train.labels,tf.float32)
    xtest=tf.cast(mnist.test.images,tf.float32)
    ytest=tf.cast(mnist.test.labels,tf.float32)
    return mnist


def mnist_softmax():
    mnist=read_data_sets('./mnist_data',one_hot=True)
    xtest=mnist.test.images
    ytest=mnist.test.labels
    print '训练数据集大小:',np.shape(mnist.train.images)
    print '测试数据集大小:',np.shape(mnist.test.images)
    W=tf.Variable(tf.random_normal(shape=[784,10]))
    b=tf.Variable(tf.zeros(shape=[10]))
    x=tf.placeholder(dtype=tf.float32,shape=[None,784])
    y=tf.placeholder(dtype=tf.float32,shape=[None,10])
    y_=tf.nn.softmax(tf.matmul(x,W)+b)
    loss=tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_)
    train=tf.train.AdamOptimizer()
    train_step=train.minimize(loss)
    step=0
    sess=tf.InteractiveSession()
    init=tf.global_variables_initializer()
    init.run()
    while step<10000:
        batch = mnist.train.next_batch(200)
        train_step.run(feed_dict={x: batch[0], y: batch[1]})
        if step%100==0:
            accuray=tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)),tf.float32))
            print '预测准确率:',accuray.eval(feed_dict={x:xtest,y:ytest})
        step+=1

def mnist_cnn():
    def gen_w(shape):
        init=tf.truncated_normal(shape=shape)
        return tf.Variable(init)
    def gen_b(shape):
        initial=tf.constant(0.1,shape=shape)
        return tf.Variable(initial)
    def conv2d(x,w):
        return tf.nn.conv2d(x,filter=w,strides=[1,1,1,1],padding='SAME')
    def pool2d(x):
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    mnist=read_data_sets('./mnist_data',one_hot=True)
    xtest=mnist.test.images
    ytest=mnist.test.labels
    print '训练数据集大小:',np.shape(mnist.train.images)
    print '测试数据集大小:',np.shape(mnist.test.images)

    x=tf.placeholder(shape=[None,784],dtype=tf.float32,name='x')
    x_=tf.reshape(x,[-1,28,28,1])
    y=tf.placeholder(shape=[None,10],dtype=tf.float32,name='y')

    keep_drop= tf.placeholder(dtype=tf.float32,name='keep_drop')

    w1=gen_w([5,5,1,32])
    b1=gen_b([32])
    conv1=tf.nn.relu(conv2d(x_,w1)+b1)
    pool1=pool2d(conv1)

    w2=gen_w([5,5,32,64])
    b2=gen_b([64])
    conv2=tf.nn.relu(conv2d(pool1,w2)+b2)
    pool2=pool2d(conv2)
    flat=tf.reshape(pool2,[-1,7*7*64])

    w3=gen_w([7*7*64,512])
    b3=gen_b([512])
    hfc1=tf.nn.relu(tf.matmul(flat,w3)+b3)
    hfc1_drop=tf.nn.dropout(hfc1,keep_prob=keep_drop)

    w4 = gen_w([512, 10])
    b4 = gen_b([10])

    y_ =tf.matmul(hfc1_drop, w4) + b4
    #为了加载模型时,根据名称加载操作
    tf.add_to_collection('y_', y_)

    loss=tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_)
    train=tf.train.RMSPropOptimizer(learning_rate=0.01)
    train_step=train.minimize(loss)
    tf.add_to_collection('train_step',train_step)

    accury = tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1)), dtype=tf.float32))

    sess=tf.InteractiveSession()
    saver=tf.train.Saver(max_to_keep=3)

    init=tf.global_variables_initializer()
    init.run()
    step=0

    #TensorBoard可视化,定义summary
    tf.summary.scalar('accury',accury)
    tf.summary.histogram('w1',w1)
    tf.summary.histogram('w2', w2)
    merge_summary_op = tf.summary.merge_all()
    if tf.gfile.Exists("./tmp/mnist_logs"):
        tf.gfile.DeleteRecursively("./tmp/mnist_logs")
    summary_writer = tf.summary.FileWriter('./tmp/mnist_logs', sess.graph)


    while step<100:
        batch=mnist.train.next_batch(50)
        train_step.run(feed_dict={x:batch[0],y:batch[1],keep_drop:0.2})
        #summary写入logs文件
        summary_str=sess.run(merge_summary_op,feed_dict={x: xtest, y: ytest, keep_drop: 1})
        summary_writer.add_summary(summary_str,step)
        #第一次存储网络结构
        saver.save(sess, './cpkdir/model.ckpt', global_step=step, write_meta_graph=True)
        if step%5==0:
            print accury.eval(feed_dict={x:xtest,y:ytest,keep_drop:1})
            #保存训练的checkpoint文件,不再存储网络结构
            saver.save(sess,'./cpkdir/model.ckpt',global_step=step,write_meta_graph=False)
        step+=1
    sess.close()

    # 根据检查点文件加载模型,仅仅根据model_checkpoint_path文件只能恢复权重参数,不能恢复网络结构
    # ckpt=tf.train.get_checkpoint_state('./cpkdir')
    # if ckpt and ckpt.model_checkpoint_path:
    #     print ckpt.model_checkpoint_path
    #     saver.restore(sess,ckpt.model_checkpoint_path)
    #
    # step=100
    # while step<120:
    #     batch=mnist.train.next_batch(50)
    #     train_step.run(feed_dict={x:batch[0],y:batch[1],keep_drop:0.2})
    #     if step%5==0:
    #         accury=tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)),dtype=tf.float32))
    #         print accury.eval(feed_dict={x:xtest,y:ytest,keep_drop:1})
    #         saver.save(sess,'./cpkdir/model.ckpt',global_step=step)
    #     step+=1

def model_restore():
    """
    同时恢复网络结构和权重参数
    :return: 
    """
    sess = tf.InteractiveSession()
    ckpt=tf.train.get_checkpoint_state('./cpkdir')
    if ckpt and ckpt.model_checkpoint_path:
        print ckpt.model_checkpoint_path
        saver = tf.train.import_meta_graph('./cpkdir/model.ckpt-105.meta')
        saver.restore(sess,ckpt.model_checkpoint_path)
        print '加载成功'
        graph = tf.get_default_graph()
        step=100
        train_step=graph.get_collection('train_step')[0]
        y_ = graph.get_collection('y_')[0]
        y=graph.get_tensor_by_name('y:0')
        x=graph.get_tensor_by_name('x:0')
        keep_drop = graph.get_tensor_by_name('keep_drop:0')
        mnist = read_data_sets('./mnist_data', one_hot=True)
        xtest = mnist.test.images
        ytest = mnist.test.labels

        while step<120:
            batch=mnist.train.next_batch(50)
            train_step.run(feed_dict={x:batch[0],y:batch[1],keep_drop:0.2})
            if step%5==0:
                accury=tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)),dtype=tf.float32))
                print accury.eval(feed_dict={x:xtest,y:ytest,keep_drop:1})
                saver.save(sess,'./cpkdir/model.ckpt',global_step=step)
            step+=1
        sess.close()


if __name__=="__main__":
    mnist_cnn()

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值