本文属于学习 TensorFlow 框架系列的文章,重点在框架的使用,而不是算法~
基于之前博文中的工作,TensorFlow 等环境已经安装配置完成,下面开始学习 TensorFlow 框架的使用。本文参考了以下链接,在此致以敬意:
http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
https://github.com/GoogleCloudPlatform/tensorflow-without-a-phd
(建议先读完整个文章再进行实际操作)
1、直接上参考代码:
(1)执行函数代码
import tensorflow as tf
import mnist_data_process
data_path = r'./mnist_data'


def mnist_v1(iter_num=1000, batch_size=100, learning_rate=0.01, data_dir=None):
    """Train and evaluate a single-layer softmax classifier on MNIST.

    Builds ``out_put = softmax(input_data @ W + b)`` and trains it with plain
    gradient descent on the summed cross-entropy loss. It then prints the
    accuracy on the test split and returns it.

    Args:
        iter_num: number of mini-batch training steps (default 1000).
        batch_size: number of examples per training step (default 100).
        learning_rate: gradient-descent step size (default 0.01).
        data_dir: directory passed to ``mnist_data_process.read_data_sets``;
            defaults to the module-level ``data_path``.

    Returns:
        The test-set accuracy, as returned by ``sess.run``.
    """
    if data_dir is None:
        data_dir = data_path

    # 1. Build the graph: 784 flattened pixels -> 10 class scores.
    input_data = tf.placeholder('float', [None, 784])
    input_labels = tf.placeholder('float', [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(input_data, W) + b
    out_put = tf.nn.softmax(logits)

    # 2. Loss. Cross-entropy is computed from the logits rather than as
    # -sum(y * log(softmax(logits))): the latter becomes NaN as soon as any
    # softmax output underflows to 0, because log(0) = -inf.
    cross_entropy_loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy_loss)

    # 3. Evaluation ops. They are built once, up front, and fed with the test
    # split later, instead of adding new graph nodes after training.
    correct_prediction = tf.equal(tf.argmax(out_put, 1), tf.argmax(input_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    # 4. Data: read the MNIST files and hold them in memory as NumPy arrays
    # (one-hot labels, flat 784-value images).
    mnist_data = mnist_data_process.read_data_sets(data_dir, one_hot=True, reshape=True)

    # `with` guarantees the session is closed and its resources released.
    with tf.Session() as sess:
        # 5. Initialise all variables
        # (tf.initialize_all_variables is deprecated), then train.
        sess.run(tf.global_variables_initializer())
        for _ in range(iter_num):
            batch_xs, batch_ys = mnist_data.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={input_data: batch_xs, input_labels: batch_ys})

        # 6. Measure accuracy on the held-out test split with the trained parameters.
        test_accuracy = sess.run(
            accuracy,
            feed_dict={input_data: mnist_data.test.images,
                       input_labels: mnist_data.test.labels})
    print(test_accuracy)
    return test_accuracy


if __name__ == '__main__':
    mnist_v1()
(2)第二级函数代码:mnist_data_process.py
# encoding: UTF-8
# Copyright 2018 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
from mnist_input_data import load_mnist_data
from mnist_input_data import load_dataset
# This loads entire dataset to an in-memory numpy array.
# This uses tf.data.Dataset to avoid duplicating code.
# Normally, if you already have a tf.data.Dataset, loading
# it to memory is not useful. The goal here is educational:
# teach about neural network basics without having to
# explain tf.data.Dataset now. The concept will be introduced
# later.
# The proper way of using tf.data.Dataset is to call
# features, labels = tf_dataset.make_one_shot_iterator().get_next()
# and then to use "features" and "labels" in your Tensorflow
# model directly. These tensorflow nodes, when executed, will
# automatically trigger the loading of the next batch of data.
# The sample that uses tf.data.Dataset correctly is in mlengine/trainer.
class MnistData(object):
    """In-memory copy of one MNIST split, with a simple mini-batch cursor.

    The whole tf.data.Dataset is evaluated once, at construction time. The
    results are stored as NumPy arrays in ``images`` and ``labels``.
    ``next_batch`` then serves consecutive slices of those arrays.
    """

    def __init__(self, tf_dataset, one_hot, reshape):
        """Materialise ``tf_dataset`` into NumPy arrays.

        Args:
            tf_dataset: a tf.data.Dataset yielding (features, label) pairs.
            one_hot: if True, labels are converted with ``tf.one_hot(labels, 10)``.
            reshape: if False, features are reshaped to [-1, 28, 28, 1].
                If True, they are kept in the shape the dataset yields.
        """
        self.pos = 0
        self.images = None
        self.labels = None
        # Read the dataset in a single pass, in chunks of 10000 examples.
        tf_dataset = tf_dataset.batch(10000)
        tf_dataset = tf_dataset.repeat(1)
        features, labels = tf_dataset.make_one_shot_iterator().get_next()
        if not reshape:
            features = tf.reshape(features, [-1, 28, 28, 1])
        if one_hot:
            labels = tf.one_hot(labels, 10)
        # Collect the chunks and concatenate once at the end. Growing the
        # arrays with np.concatenate per chunk re-copies everything read so
        # far on every iteration.
        image_chunks = []
        label_chunks = []
        with tf.Session() as sess:
            while True:
                try:
                    feats, labs = sess.run([features, labels])
                except tf.errors.OutOfRangeError:
                    break  # iterator exhausted: the whole dataset has been read
                image_chunks.append(feats)
                label_chunks.append(labs)
        # If the dataset yielded nothing, images/labels stay None.
        if image_chunks:
            self.images = np.concatenate(image_chunks)
            self.labels = np.concatenate(label_chunks)

    def next_batch(self, batch_size):
        """Return the next ``(images, labels)`` mini-batch.

        Batches are consecutive slices starting at ``self.pos``. When fewer
        than ``batch_size`` examples remain, the cursor wraps back to 0, so
        the leftover tail of that pass is skipped. The data is not reshuffled
        between passes.
        """
        end = self.pos + batch_size
        if end > len(self.images) or end > len(self.labels):
            self.pos = 0
        res = (self.images[self.pos:self.pos + batch_size],
               self.labels[self.pos:self.pos + batch_size])
        self.pos += batch_size
        return res
class Mnist(object):
    """Bundle of the training and test splits, each wrapped in ``MnistData``."""

    def __init__(self, train_dataset, test_dataset, one_hot, reshape):
        # Both splits share the same label/image formatting options.
        split_options = dict(one_hot=one_hot, reshape=reshape)
        self.train = MnistData(train_dataset, **split_options)
        self.test = MnistData(test_dataset, **split_options)
def read_data_sets(data_dir, one_hot, reshape):
    """Load MNIST from ``data_dir`` into an in-memory ``Mnist`` object.

    The training dataset is passed through ``shuffle(60000)`` before it is
    materialised. The test dataset is used in file order.
    """
    (train_images, train_labels,
     test_images, test_labels) = load_mnist_data(data_dir)
    shuffled_train = load_dataset(train_images, train_labels).shuffle(60000)
    test_split = load_dataset(test_images, test_labels)
    return Mnist(train_dataset=shuffled_train, test_dataset=test_split,
                 one_hot=one_hot, reshape=reshape)
(3)第三级函数代码:mnist_input_data.py
import os
import gzip
import shutil
from six.moves import urllib
from tensorflow.python.platform import gfile
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
# Import-time side effect: set TensorFlow's logging verbosity to INFO and log
# the installed TensorFlow version.
logging.set_verbosity(logging.INFO)
logging.log(logging.INFO, "Tensorflow version " + tf.__version__)
def maybe_download_and_ungzip(filename, work_directory, source_url):
    """Make sure the un-gzipped ``filename`` exists in ``work_directory``.

    Args:
        filename: remote file name, e.g. 'train-images-idx3-ubyte.gz'.
        work_directory: local directory to store the file in; created if missing.
        source_url: full URL to fetch ``filename`` from.

    Returns:
        The path of the un-gzipped file inside ``work_directory``.

    Behaviour:
        - If the un-gzipped file already exists, nothing is done.
        - If only the archive exists locally (for example, it was placed
          there by hand), it is unpacked without a download.
        - Otherwise the archive is downloaded from ``source_url`` first.
    """
    if filename.endswith(".gz"):
        unzipped_filename = filename[:-3]
    else:
        unzipped_filename = filename
    if not gfile.Exists(work_directory):
        gfile.MakeDirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    unzipped_filepath = os.path.join(work_directory, unzipped_filename)
    if not gfile.Exists(unzipped_filepath):
        # Use a real filesystem check. os._exists() only tests whether a name
        # is defined in the os module, so it is always False for a path and
        # would re-download an archive that is already present.
        if not gfile.Exists(filepath):
            urllib.request.urlretrieve(source_url, filepath)
        if not filename == unzipped_filename:
            with gzip.open(filepath, 'rb') as f_in:
                with open(unzipped_filepath, 'wb') as f_out:  # remove .gz
                    shutil.copyfileobj(f_in, f_out)
        with gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded and unzipped', filename, size, 'bytes.')
    return unzipped_filepath
def read_label(tf_bytestring):
    """Decode a raw label record as uint8 and reshape it to a scalar tensor.

    The reshape to [] presumably expects a 1-byte record; confirm with the
    record length used by the caller.
    """
    return tf.reshape(tf.decode_raw(tf_bytestring, tf.uint8), [])
def read_image(tf_bytestring):
    """Decode a raw image record into a float32 tensor of scaled pixel values.

    Each byte 0..255 is divided by 256.0, so values lie in [0, 255/256].
    """
    pixels = tf.decode_raw(tf_bytestring, tf.uint8)
    scaled = tf.cast(pixels, tf.float32) / 256.0
    return scaled
def load_mnist_data(data_dir):
    """Fetch (if needed) and un-gzip the four MNIST files into ``data_dir``.

    Returns:
        A 4-tuple of local paths, in this order: training images, training
        labels, test images, test labels.
    """
    SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
    archive_names = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )
    local_paths = [
        maybe_download_and_ungzip(name, data_dir, SOURCE_URL + name)
        for name in archive_names
    ]
    return tuple(local_paths)
def load_dataset(image_file, label_file):
    """Build a tf.data.Dataset of (image, label) pairs from two MNIST files.

    Image records are 28*28 bytes long, after skipping a 16-byte file header.
    Label records are 1 byte long, after skipping an 8-byte file header.
    The two record streams are zipped together element by element.
    """
    read_buffer = 1024 * 16
    images = tf.data.FixedLengthRecordDataset(
        image_file, 28 * 28, header_bytes=16, buffer_size=read_buffer
    ).map(read_image)
    labels = tf.data.FixedLengthRecordDataset(
        label_file, 1, header_bytes=8, buffer_size=read_buffer
    ).map(read_label)
    return tf.data.Dataset.zip((images, labels))
2、总结
这只是一个简单地使用 softmax 回归对手写数字进行分类的例子:这里直接把原始灰度图像的像素值作为特征,经过一个线性层后再用 softmax 把输出转换为 10 个类别的概率。实验在测试集上的准确率(acc)在 91% 左右。运行这段代码主要是为了理解 TensorFlow 的运行机制。个人认为,从头学习 TensorFlow 时直接拿复杂的深度网络模型来调试分析很容易被搞晕,先用一个简单的例子入门比较好,然后再一步一步地加深。
ps.如果代码里面的路径下载数据集下载不下来就先在http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html这里下载好数据集放在对应的路径下,然后再运行上面的代码。