debug_mnist.py是tensorflow 官网上的深度学习例子。吾习之,颇有心得,故分享之;如有错误,敬请提醒。
tensorflow环境:
本人使用的是python2.7和tensorflow1.0.0。在ubuntu16.10中安装virtualenv,并在其环境下安装tensorflow。virtualenv的安装路径是/home/chen/tensorflow,然后按照官网上的步骤会将tensorflow安装在/home/chen/tensorflow/local/lib/python2.7/site-packages/tensorflow。
例子入口:
该例子的main函数在~/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/debug/examples/debug_main.py
代码讲解:
#该代码片主要是用于指定参数
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--max_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
parser.add_argument(
"--train_batch_size",
type=int,
default=100,
help="Batch size used during training.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.025,
help="Initial learning rate.")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/mnist_data",
help="Directory for storing data")
parser.add_argument(
"--ui_type",
type=str,
default="curses",
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--fake_data",
type="bool",
nargs="?",
const=True,
default=False,
help="Use fake MNIST data for unit testing")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
其中有几个接口对于初次接触python的人不太熟悉,故在下面进行解释:
argparse.ArgumentParser()
该接口是生成一个解析命令的对象
FLAGS, unparsed = parser.parse_known_args()
获取命令对象到FLAG和unparsed中,其中FLAG是add过的命令,unparsed是没有添加过的命令
parser.add_argument(
"--max_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
该接口是支持解析--max_step命令。
例子如下:

例子结果:

parser.register("type", "bool", lambda v: v.lower() == "true")
注册一个action,在tensorflow中主要是为了支持下面那句话中类型支持bool型(温馨提示:主要type="bool"中的bool是带双引号的)。
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
另外提供给python初学者一个小技巧,可以通过以下方式进行查询文档:

获取完之后就开始tensorflow的正文喽。进入
tf.app.run(main=main,argv=[sys.argv[0] + unparsed)
这句话就跳转到tensorflow的深度学习的正文喽