import numpy as np
import tensorflow as tf
复制代码
/anaconda3/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5
return f(*args, **kwds)
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
复制代码
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST train/validation/test splits from the "MNIST_data/"
# directory, with every label converted to a one-hot vector.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)
复制代码
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
复制代码
# Sanity check: the training split should have one label per image.
tuple(len(split) for split in (mnist.train.images, mnist.train.labels))
复制代码
(55000, 55000)
复制代码
# Sanity check: the test split should have one label per image.
(mnist.test.images.shape[0], mnist.test.labels.shape[0])
复制代码
(10000, 10000)
复制代码
# Peek at the first training image as a flat vector of pixel values.
mnist.train.images[0, :]
复制代码
array([0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.3803922 , 0.37647063, 0.3019608 ,
0.46274513, 0.2392157 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.3529412 , 0.5411765 , 0.9215687 ,
0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 ,
0.9843138 , 0.9843138 , 0.9725491 , 0.9960785 , 0.9607844 ,
0.9215687 , 0.74509805, 0.08235294, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.54901963,
0.9843138 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.7411765 , 0.09019608, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.8862746 , 0.9960785 , 0.81568635,
0.7803922 , 0.7803922 , 0.7803922 , 0.7803922 , 0.54509807,
0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 ,
0.5019608 , 0.8705883 , 0.9960785 , 0.9960785 , 0.7411765 ,
0.08235294, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.14901961, 0.32156864, 0.0509804 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.13333334,
0.8352942 , 0.9960785 , 0.9960785 , 0.45098042, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.32941177, 0.9960785 ,
0.9960785 , 0.9176471 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.32941177, 0.9960785 , 0.9960785 , 0.9176471 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.4156863 , 0.6156863 ,
0.9960785 , 0.9960785 , 0.95294124, 0.20000002, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.09803922, 0.45882356, 0.8941177 , 0.8941177 ,
0.8941177 , 0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.94117653, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.26666668, 0.4666667 , 0.86274517,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.5568628 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.14509805, 0.73333335,
0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 , 0.8745099 ,
0.8078432 , 0.8078432 , 0.29411766, 0.26666668, 0.8431373 ,
0.9960785 , 0.9960785 , 0.45882356, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.4431373 , 0.8588236 , 0.9960785 , 0.9490197 , 0.89019614,
0.45098042, 0.34901962, 0.12156864, 0. , 0. ,
0. , 0. , 0.7843138 , 0.9960785 , 0.9450981 ,
0.16078432, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.6627451 , 0.9960785 ,
0.6901961 , 0.24313727, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.18823531,
0.9058824 , 0.9960785 , 0.9176471 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.07058824, 0.48627454, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.32941177, 0.9960785 , 0.9960785 ,
0.6509804 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.54509807, 0.9960785 , 0.9333334 , 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.8235295 , 0.9803922 , 0.9960785 ,
0.65882355, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.9490197 , 0.9960785 , 0.93725497, 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.34901962, 0.9843138 , 0.9450981 ,
0.3372549 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01960784,
0.8078432 , 0.96470594, 0.6156863 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01568628, 0.45882356, 0.27058825,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. ], dtype=float32)
复制代码
# Number of pixel values in one flattened image.
mnist.train.images[0].shape[0]
复制代码
784
复制代码
import matplotlib.pyplot as plt
# IPython/Jupyter magic (not plain Python syntax): render figures inline in
# the notebook output.
%matplotlib inline
复制代码
# Display training image #1 as a 28x28 picture (784 = 28 * 28 pixels).
# The values are single-channel intensities, so use a grayscale colormap
# rather than matplotlib's default colormap, which would false-colour them.
plt.imshow(mnist.train.images[1].reshape(28, 28), cmap="gray")
复制代码
<matplotlib.image.AxesImage at 0x1c27a1a550>
复制代码
# One-hot label vector for training example #1.
mnist.train.labels[1, :]
复制代码
array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
复制代码
# Graph inputs: a batch of flattened 784-pixel images and their 10-way
# one-hot labels. The leading None leaves the batch size open.
x = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, 10])
复制代码
# Trainable parameters of the linear layer, initialised from a truncated
# normal distribution with TensorFlow's default mean/stddev.
weight = tf.Variable(tf.truncated_normal(shape=[784, 10]))
bias = tf.Variable(tf.truncated_normal(shape=[10]))
复制代码
# Logits: affine map x @ weight + bias, one row of 10 scores per example.
combine_input = tf.add(tf.matmul(x, weight), bias)
复制代码
# Class probabilities obtained by applying softmax to the logits.
pred = tf.nn.softmax(logits=combine_input)
复制代码
# Cross-entropy, summed over the batch, computed directly from the logits.
# The previous form -sum(y * log(softmax(logits))) evaluates log(0) = -inf
# whenever a predicted probability underflows to zero, and y * -inf then
# yields NaN in the loss and gradients. The fused op is numerically stable
# and gives the same value whenever the old expression was finite.
loss = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=combine_input))
复制代码
# One plain gradient-descent update (fixed learning rate 0.01) on `loss`.
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=0.01).minimize(loss)
复制代码
# Open a session and initialise all variables before training.
# NOTE(review): the session is left open for later cells and never closed;
# call sess.close() when finished.
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)
复制代码
# Train for 1100 steps on mini-batches of 50 examples. Every 50 steps, print
# the loss on the current batch, measured after that step's update.
# (The loop body lost its indentation in the paste, which made this cell an
# IndentationError.)
for i in range(1100):
    batch = mnist.train.next_batch(50)  # (images, labels)
    feed = {x: batch[0], y: batch[1]}
    sess.run(train_step, feed_dict=feed)
    if i % 50 == 0:
        print(sess.run(loss, feed_dict=feed))
复制代码
329.698
114.416695
38.314323
20.213478
48.926674
26.53627
28.653086
43.464195
16.75724
39.731388
12.251608
32.379055
24.371075
18.137915
8.972845
24.207663
29.931976
7.5475473
10.576719
28.017235
14.364228
11.022556
复制代码
# Accuracy: fraction of examples whose highest-probability class equals the
# index of the 1 in their one-hot label.
correct_pred = tf.equal(tf.argmax(pred, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))
复制代码
# Evaluate the trained model's accuracy on the whole test split.
acc = accuracy.eval(session=sess,
                    feed_dict={x: mnist.test.images, y: mnist.test.labels})
print(acc)
复制代码
0.8813
复制代码