Handwritten digit recognition on the MNIST dataset can be regarded as the "Hello World" task of machine learning.
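MNIST is an entry-level computer vision dataset of handwritten digit images and their corresponding labels. Each image is 28×28 pixels and contains only grayscale values, so it can be flattened into a 784-dimensional vector, which is the input size used in the program below.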
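A minimal pure-Python sketch of the two encodings involved: a 28×28 image flattened row by row into 784 values, and a digit label turned into a one-hot vector of length 10. The helper names flatten_image and one_hot are hypothetical. In the program below, read_data_sets(..., one_hot=True) already returns images in this flattened form (with pixel values scaled to [0, 1]) and labels in one-hot form, so this only illustrates what that data looks like.

```python
# Illustrative sketch only (no TensorFlow): hypothetical helpers showing the
# MNIST input/label encodings.

def flatten_image(image):
    """Flatten a 28x28 grid of grayscale values (row-major) into 784 floats."""
    assert len(image) == 28 and all(len(row) == 28 for row in image)
    return [float(pixel) for row in image for pixel in row]

def one_hot(label, num_classes=10):
    """Encode an integer digit 0-9 as a one-hot list of length num_classes."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

# Usage: a blank image with one bright pixel at row 0, column 5.
image = [[0.0] * 28 for _ in range(28)]
image[0][5] = 1.0
x = flatten_image(image)
print(len(x), x.index(1.0))  # 784 5
print(one_hot(3))            # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Row-major order means pixel (row r, column c) lands at index r × 28 + c.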
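For introductory TensorFlow material I mainly relied on Huang Wenjian's book 《TensorFlow实战》 ("TensorFlow in Practice") and the TensorFlow Chinese community, both of which explain it in great detail, so I won't repeat it here.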
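Softmax Regression is a basic multi-class classification model, and we use it here to classify the handwritten digits. Used on its own, Softmax Regression can be viewed as the shallowest possible neural network: one with no hidden layer, mapping the 784 input pixels directly to 10 class probabilities.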
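Written out, the model and the loss used below take the standard form shown here, for one image x (a row vector) and its one-hot label y'. In the program, y' is the placeholder y_, and the loss is additionally averaged over each mini-batch (tf.reduce_mean):

```latex
\begin{aligned}
y &= \operatorname{softmax}(xW + b), \qquad x \in \mathbb{R}^{1\times 784},\; W \in \mathbb{R}^{784\times 10},\; b \in \mathbb{R}^{10} \\
\operatorname{softmax}(z)_i &= \frac{e^{z_i}}{\sum_{j=1}^{10} e^{z_j}}, \qquad i = 1,\dots,10 \\
H(y', y) &= -\sum_{i=1}^{10} y'_i \log y_i
\end{aligned}
```

Since y' is one-hot, the cross-entropy H reduces to -log of the probability the model assigns to the correct digit.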
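Classifying with Softmax Regression in TensorFlow follows these steps:
- Define the algorithm's formula, i.e. the computation in the network's forward pass.
- Define the loss, choose an optimizer, and have the optimizer minimize the loss.
- Train on the data iteratively.
- Evaluate the accuracy on the test set or validation set.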
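The four steps can be illustrated without TensorFlow. The following pure-Python sketch trains a softmax regression on a hypothetical four-sample, two-feature toy dataset (not MNIST) with plain full-batch gradient descent, and for brevity evaluates on the training data itself. It only shows what each step computes, not how TensorFlow implements it; the gradient used, d(loss)/d(logits) = y - y', is the standard result for softmax combined with cross-entropy.

```python
import math

# Hypothetical toy data standing in for MNIST: 2 features, 2 classes.
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
Y = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # one-hot labels
n_features, n_classes = 2, 2

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

# Step 1: forward computation y = softmax(xW + b).
def forward(x, W, b):
    logits = [sum(x[i] * W[i][j] for i in range(n_features)) + b[j]
              for j in range(n_classes)]
    return softmax(logits)

# Step 2a: the loss, mean cross-entropy over the batch.
def cross_entropy(X, Y, W, b):
    total = 0.0
    for x, y_ in zip(X, Y):
        y = forward(x, W, b)
        total -= sum(t * math.log(p) for t, p in zip(y_, y))
    return total / len(X)

# Step 2b: the optimizer, one plain gradient-descent step (updates W, b in place).
def gd_step(X, Y, W, b, lr):
    n = len(X)
    grad_W = [[0.0] * n_classes for _ in range(n_features)]
    grad_b = [0.0] * n_classes
    for x, y_ in zip(X, Y):
        y = forward(x, W, b)
        for j in range(n_classes):
            d = (y[j] - y_[j]) / n  # d(mean loss)/d(logit_j) for this sample
            grad_b[j] += d
            for i in range(n_features):
                grad_W[i][j] += x[i] * d
    for i in range(n_features):
        for j in range(n_classes):
            W[i][j] -= lr * grad_W[i][j]
    for j in range(n_classes):
        b[j] -= lr * grad_b[j]

# Step 4: accuracy = fraction of samples whose argmax prediction matches the label.
def argmax(v):
    return max(range(len(v)), key=v.__getitem__)

def accuracy(X, Y, W, b):
    hits = sum(argmax(forward(x, W, b)) == argmax(y_) for x, y_ in zip(X, Y))
    return hits / len(X)

W = [[0.0] * n_classes for _ in range(n_features)]  # zero init, as in the program
b = [0.0] * n_classes
loss_before = cross_entropy(X, Y, W, b)
# Step 3: iterate (full batch here; the program uses mini-batches of 100).
for _ in range(100):
    gd_step(X, Y, W, b, lr=0.5)
loss_after = cross_entropy(X, Y, W, b)
print(round(loss_before, 4), loss_after < loss_before, accuracy(X, Y, W, b))  # 0.6931 True 1.0
```

With all-zero initial weights every class starts at probability 1/2, so the initial loss is ln 2; the learning rate 0.5 mirrors GradientDescentOptimizer(0.5) in the program.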
Versions used: Python 3.7.3, TensorFlow 1.13.1
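import os
# suppress warnings: hide INFO and WARNING messages from TensorFlow's C++ backend
# (set before importing TensorFlow, as is usually recommended)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# hide Python-side deprecation warnings; the old level is restored at the end
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)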
# read data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
# create model
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=[None, 784]) # input images
y_ = tf.placeholder(tf.float32, shape=[None, 10]) # input labels
W = tf.Variable(tf.zeros([784, 10])) # weight
b = tf.Variable(tf.zeros([10])) # bias
y = tf.nn.softmax(tf.matmul(x, W) + b) # softmax regression
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1])) # loss function
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # optimization algorithm
tf.global_variables_initializer().run()
# train
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)  # mini-batch of 100 samples
    train_step.run(feed_dict={x: batch_xs, y_: batch_ys})
# test
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('Accuracy: ', accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
tf.logging.set_verbosity(old_v)  # restore the original logging level
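One caveat about the loss line in this program: it takes tf.log of the softmax output y, so if a predicted probability underflows to 0 the log becomes -inf and the loss can turn into inf or NaN (0 × -inf is NaN). TensorFlow 1.13 also provides tf.nn.softmax_cross_entropy_with_logits_v2, which takes the raw logits tf.matmul(x, W) + b instead of y. The pure-Python sketch below (function names hypothetical) shows why computing the loss from the logits via the log-sum-exp identity is more robust; it is an illustration of the numerical issue, not TensorFlow's implementation.

```python
import math

def naive_cross_entropy(logits, label):
    """Mirrors the program's approach: softmax first, then -log(prob of true class)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -math.log(probs[label])  # ValueError if probs[label] underflowed to 0.0

def stable_cross_entropy(logits, label):
    """Log-sum-exp form: -log softmax(z)[label] = logsumexp(z) - z[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

logits = [0.0, 1000.0]  # a very confident, wrong prediction when the true class is 0
print(stable_cross_entropy(logits, 0))  # 1000.0
try:
    naive_cross_entropy(logits, 0)
except ValueError as err:
    print("naive version failed:", err)  # math domain error
```

On ordinary inputs both functions agree; they differ only when a probability underflows.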
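The softmax regression program above comes from 《TensorFlow实战》 and runs correctly, reaching an accuracy of about 92%. However, when run without the warning-suppression lines, it prints many warnings, probably because of TensorFlow version differences (the MNIST helpers in tensorflow.contrib are deprecated in 1.13):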
D:\Anaconda3\envs\python37\python.exe E:/PycharmProjects/Handwritten_Number_Recognition/softmax_regression.py
WARNING:tensorflow:From E:/PycharmProjects/Handwritten_Number_Recognition/softmax_regression.py:11: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Please write your own downloading logic.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2019-06-22 22:03:21.339974: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
2019-06-22 22:03:21.340911: I tensorflow/core/common_runtime/process_util.cc:71] Creating new thread pool with default inter op setting: 4. Tune using inter_op_parallelism_threads for best performance.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From D:\Anaconda3\envs\python37\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Accuracy: 0.9181
Process finished with exit code 0
These warnings do not affect the program's result, but if they bother you, a few extra statements will suppress them.
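These two lines hide the CPU-related messages (the INFO-level "I" lines from TensorFlow's C++ backend); they are best placed before TensorFlow is imported: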
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
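For reference, TF_CPP_MIN_LOG_LEVEL filters the C++ backend's messages by severity, and the commonly given advice is to set it before TensorFlow is imported. A minimal sketch of that ordering (the import is left commented so the snippet runs without TensorFlow installed):

```python
import os

# Severity filter for TensorFlow's C++ logging:
#   "0" = show all, "1" = hide INFO, "2" = hide INFO and WARNING, "3" = also hide ERROR
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Import TensorFlow only after the variable is set:
# import tensorflow as tf
print(os.environ["TF_CPP_MIN_LOG_LEVEL"])  # 2
```

Level "2" is enough for the log above, since the CPU instruction-set and thread-pool lines are INFO messages.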
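The following lines suppress the remaining (Python-side deprecation) warnings: the first two save the current logging level and raise the threshold to ERROR, and the last one, placed at the end of the script, restores the saved level: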
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
tf.logging.set_verbosity(old_v)
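The restore line only runs if execution reaches it, so an exception in between would leave the level raised; this matters mainly in a longer session such as a notebook. A hedged alternative is a small context manager built on Python's standard logging module. It assumes, as in TensorFlow 1.x, that tf.logging forwards to the standard logger named "tensorflow" and that tf.logging.ERROR equals logging.ERROR; quiet_logger is a hypothetical helper name, demonstrated here on a stand-in logger so it runs without TensorFlow.

```python
import contextlib
import logging

@contextlib.contextmanager
def quiet_logger(name="tensorflow", level=logging.ERROR):
    """Temporarily raise a logger's threshold; the old level is restored
    in `finally`, so it also happens if the body raises."""
    logger = logging.getLogger(name)
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(old_level)

# Usage on a stand-in logger ("demo" is hypothetical):
demo = logging.getLogger("demo")
demo.setLevel(logging.INFO)
with quiet_logger("demo"):
    inside = demo.getEffectiveLevel()
print(inside == logging.ERROR, demo.level == logging.INFO)  # True True
```

In the program, the data loading and training code would go inside a with quiet_logger(): block instead of being bracketed by the get_verbosity/set_verbosity calls.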