如需转载,请注明出处,欢迎加入深度学习群 255568483
mnist数据集基本上可以算做是数据科学里的hello world程序。
tensorflow官方文档有一个例子可以对mnist数据集做分析,对于刚入门者,可以通过简单的模型达到92%的准确率。
以下代码分析是对应所做的分析,请看对应的中文注释
#!/usr/bin/env python
# encoding: utf-8
# 文件说明:
# 原始参考URL:
# 功能说明:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
FLAGS = None
def main(_):
# 获取数据集
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# 定义输入和输出的占位符
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# 定义通用函数
def weight_variable(shape):
# 截断正态分布 标准方差为0.1
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(init