在学习tensorflow时,mnist数据集就相当于是程序界的helloworld,所以先跑一下mnist数据集,我采用rnn(循环神经网络)的方式处理。第一步就是要进行自动下载和安装 MNIST 到 TensorFlow 的 python 源码input_data。
inptu_data.py如下:
#!/usr/bin/Python # -*- coding: utf-8 -*- # Copyright 2015 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Functions for downloading and reading MNIST data.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import gzip import os import tensorflow.python.platform import numpy from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.examples.tutorials.mnist import mnist SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' def maybe_download(filename, work_directory): """Download the data from Yann's website, unless it's already here.""" if not os.path.exists(work_directory): os.mkdir(work_directory) filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') return