话不多说,下面是源码,参考了网上找的资料.模型结构:输入->h1->h2->h3->输出,尽管在mnist测试准确率达到95%,但实际使用画图程序手写10个数字进行识别, 最多只能识别出7个,跟你笔画粗细,数字是否居中,数字大小都有关系
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
# 加载mnist数据集
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
# 对数据进行预处理
def preprocess(x, y):
# 将x映射到[0,1]范围内
x = tf.cast(x, dtype=tf.float32) / 255.
y = tf.cast(y, dtype=tf.int32)
return (x, y)
# 构建dataset对象,对数据进行打乱,批处理等操作
train_data = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(10000).batch(512)
train_data = train_data.map(preprocess)
test_data = tf.data.Dataset.from_tensor_slices(
(x_test, y_test)).shuffle(1000).batch(512)
test_data = test_data.map(preprocess)
# 建立权重和偏置,学习效率等训练中用到的变量
w1 = tf.Variable(tf.random.truncated_normal([28*28, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros(256))
w2 = tf.Variable(</