from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
import numpy as np
import cv2
import tensorflow as tf
class Config:
    """Hyper-parameters and filesystem locations for the MNIST experiment."""

    def __init__(self):
        # Run identifier; the checkpoint prefix and the log directory are
        # both derived from it.
        self.name = 'mnist08'
        self.save_path = '../models/{name}/{name}'.format(name=self.name)
        self.logdir = '../logs/{name}'.format(name=self.name)

        # Directory the MNIST data set is read from.
        self.sample_path = '../deeplearning_ai12/p07_mnist/MNIST_data'

        # Optimisation settings: learning rate, number of passes over the
        # training split, and samples per step.
        self.lr = 0.001
        self.epoches = 200
        self.batch_size = 50

        # Not referenced by the code visible in this file.
        self.eps = 1e-10
        self.base_filters = 16  # should be 32 at least
class Tensors:
    """Builds the model graph: placeholders, network, loss, optimizer, metrics.

    All ops are created in whichever graph is the default one while the
    constructor runs.
    """

    def __init__(self, config: Config):
        """
        :param config: supplies the learning rate (``config.lr``)
        """
        self.config = config

        # Input features, 784 values per sample.
        self.x = tf.placeholder(tf.float32, [None, 784], 'x')
        scores = self.get_logits(self.x)  # [-1, 10]
        self.y_predict = tf.argmax(scores, axis=1, output_type=tf.int32)  # [-1]

        # Integer class labels, expanded to one-hot only for the loss.
        self.y = tf.placeholder(tf.int32, [None], 'y')
        one_hot_labels = tf.one_hot(self.y, 10)  # [-1, 10]
        per_sample_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels, logits=scores)
        self.loss = tf.reduce_mean(per_sample_loss)

        optimizer = tf.train.AdamOptimizer(config.lr)
        self.train_op = optimizer.minimize(self.loss)

        # Fraction of samples whose predicted class equals the label.
        hits = tf.equal(self.y, self.y_predict)
        self.precise = tf.reduce_mean(tf.cast(hits, tf.float32))

        # Print every trainable variable with its element count, then the sum.
        total = 0
        for variable in tf.trainable_variables():
            count = _params(variable.shape)
            print(variable.name, variable.shape, count)
            total += count
        print('-' * 200)
        print('Total:', total)

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('precise', self.precise)
        self.summary_op = tf.summary.merge_all()

    def get_logits(self, x):
        """
        Apply a 4000-unit ReLU dense layer followed by a 10-unit linear layer.

        :param x: [-1, 784]
        :return: [-1, 10]
        """
        hidden = tf.layers.dense(x, 4000, tf.nn.relu, name='dense1')
        return tf.layers.dense(hidden, 10, name='dense2')
def _params(shape):
result = 1
for sh in shape:
result *= sh.value
return result
class Samples:
    """Loads the data set and exposes its train / validation / test splits."""

    def __init__(self, config):
        """
        :param config: object whose ``sample_path`` is passed to ``read_data_sets``
        """
        mnist = read_data_sets(config.sample_path)
        # Wrap each split in the same adapter type.
        self.train, self.validation, self.test = (
            SubSamples(split)
            for split in (mnist.train, mnist.validation, mnist.test)
        )
class SubSamples:
    """Thin adapter around a single data split; every call is delegated to it."""

    def __init__(self, data):
        """
        :param data: split object providing ``num_examples`` and ``next_batch``
        """
        self.data = data

    def num_examples(self):
        """Return the wrapped split's ``num_examples`` attribute."""
        total = self.data.num_examples
        return total

    def next_batch(self, batch_size):
        """
        Fetch the next batch from the wrapped split.

        :param batch_size: number of samples requested
        :return: whatever the wrapped ``next_batch`` returns
        """
        # xs: [batch_size, 784], ys: [batch_size]
        batch = self.data.next_batch(batch_size)
        return batch
def show_imgs(xs, ys, cols=20):
    """
    Print the labels and display the images tiled into one window.

    The images are laid out left-to-right, top-to-bottom in a grid. The
    previous implementation hard-coded 20 images per row and raised on any
    image count that was not a multiple of 20; the last row is now padded
    with blank (zero) tiles instead.

    :param xs: image data, reshaped internally to [-1, 28, 28]
    :param ys: labels; only printed
    :param cols: maximum number of images per grid row (default 20, which
        reproduces the original layout for counts that are multiples of 20)
    :raises ValueError: if ``cols`` is smaller than 1
    """
    if cols < 1:
        raise ValueError('cols must be >= 1, got %r' % (cols,))

    print(ys)
    imgs = np.reshape(xs, [-1, 28, 28])
    count = imgs.shape[0]
    if count == 0:
        # Nothing to draw; an empty image cannot be tiled or displayed.
        return

    # Never make the grid wider than the number of images available.
    cols = min(cols, count)

    # Fill up the last row with blank tiles so the reshape below is exact.
    missing = -count % cols
    if missing:
        blanks = np.zeros((missing, 28, 28), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, blanks], axis=0)

    # [rows, cols, 28, 28] -> [rows, 28, cols, 28] -> [rows * 28, cols * 28]
    grid = imgs.reshape(-1, cols, 28, 28)
    grid = grid.transpose(0, 2, 1, 3)
    grid = grid.reshape(-1, 28 * cols)

    cv2.imshow('My digits', grid)
    cv2.waitKey()
class App:
    """Owns the data, the graph and the session; offers training and prediction."""

    def __init__(self, config: Config):
        """
        Build the graph in a fresh ``tf.Graph`` and restore or initialise it.

        :param config: experiment configuration (paths, batch size, epochs)
        """
        self.config = config
        self.samples = Samples(config)
        g = tf.Graph()
        with g.as_default():
            self.tensors = Tensors(config)
            self.session = tf.Session(graph=g)
            self.saver = tf.train.Saver()
            try:
                self.saver.restore(self.session, config.save_path)
            except Exception:
                # Best-effort restore: any failure falls back to freshly
                # initialised variables. ``Exception`` (not a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                print('Fail to restore the model from %s, use a new model instead' % config.save_path)
                self.session.run(tf.global_variables_initializer())
            else:
                print('Restore the model from %s successfully' % config.save_path)

    def close(self):
        """Close the TensorFlow session."""
        self.session.close()

    def train(self):
        """
        Train for ``config.epoches`` epochs, logging summaries and saving
        a checkpoint after every epoch.
        """
        train_samples = self.samples.train
        config = self.config
        ts = self.tensors
        fw = tf.summary.FileWriter(config.logdir, self.session.graph)
        try:
            step = 0
            for epoch in range(config.epoches):
                batches = train_samples.num_examples() // config.batch_size
                for batch in range(batches):
                    xs, ys = train_samples.next_batch(config.batch_size)
                    _, summary = self.session.run([ts.train_op, ts.summary_op], {ts.x: xs, ts.y: ys})
                    fw.add_summary(summary, step)
                    step += 1

                    # NOTE(review): the source had lost its indentation. The
                    # validation check and progress print are placed inside
                    # the batch loop because the message reports the batch
                    # index; the save is placed once per epoch - confirm.
                    xs, ys = self.samples.validation.next_batch(config.batch_size)
                    precise_v = self.session.run(ts.precise, {ts.x: xs, ts.y: ys})
                    print('Epoch: %d, batch %d: precise=%.6f' % (epoch, batch, precise_v))
                self.saver.save(self.session, config.save_path)
                print('Model saved into', config.save_path)
        finally:
            # Flush and release the event file even if training is interrupted.
            fw.close()
        print('Training is finished!')

    def predict(self):
        """Print the labels and the predicted classes for one batch of the test split."""
        xs, ys = self.samples.test.next_batch(self.config.batch_size)
        print(ys)
        ts = self.tensors
        ys_predict = self.session.run(ts.y_predict, {ts.x: xs})
        print('predict:')
        print(ys_predict)
if __name__ == '__main__':
    # Script entry point: build the app, train, and always release the session.
    config = Config()
    app = App(config)
    try:
        app.train()
        # app.predict()
    finally:
        # Close the session even when training raises or is interrupted.
        app.close()
25-mnist08_3layers
最新推荐文章于 2025-01-16 12:36:41 发布
本文介绍了一个基于TensorFlow的手写数字识别系统,使用MNIST数据集进行训练和验证。模型采用全连接神经网络,包含一个隐藏层(4000个ReLU神经元)和一个10类输出层,通过Adam优化器进行参数更新,实现了对手写数字的有效识别。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
TensorFlow-v2.15
TensorFlow
TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型
1250

被折叠的 条评论
为什么被折叠?



