import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras
# 以上：导入相关的包
# 进行数据读取和处理
def preprocess(x, y):
    """Normalise one (image, label) pair for the tf.data pipeline.

    :param x: image pixels; presumably uint8 values in [0, 255] as returned
        by the Keras dataset loader (TODO confirm if reused elsewhere).
    :param y: label tensor (one-hot encoded before the dataset is built in
        this script).
    :return: ``(x, y)`` with ``x`` as float32 scaled to [-1, 1] and ``y``
        cast to int32.
    """
    # Scale [0, 255] -> [-1, 1] so inputs are zero-centred; dividing by 255
    # and subtracting 1 alone would only reach [-1, 0].
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x, y
batchsz = 128
(x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
# Drop any size-1 axes from the label arrays. Fashion-MNIST labels are
# already 1-D ([60k] train, [10k] test), so this is a defensive no-op.
y = tf.squeeze(y)
y_val = tf.squeeze(y_val)
# One-hot encode the class indices: [60k] -> [60k, 10], [10k] -> [10k, 10].
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
# Raw pixel range is printed here, before preprocess() normalises it.
print("datasets:", x.shape, y.shape, x.min(), x.max())

# Training pipeline: normalise each sample, shuffle with a buffer larger
# than the dataset (a full shuffle), then group into batches of `batchsz`.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(100000).batch(batchsz)
# Evaluation pipeline: normalise and batch only; order is irrelevant here.
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsz)

# Sanity-check the shapes of one training batch.
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)
# 开始构建网络：自定义全连接层与模型（梯度计算由 fit 自动完成）
# 网络的前向计算将 [b, 28*28] 映射为 [b, 10] 的 logits
class MyDense(layers.Layer):
    """Fully connected layer computing ``inputs @ kernel`` (plus optional bias).

    :param inp_dim: size of the last axis of the input.
    :param outp_dim: number of output units.
    :param use_bias: if True, also learn a bias vector of length
        ``outp_dim`` (zero-initialised). Defaults to False, which keeps the
        original bias-free behaviour.
    """

    def __init__(self, inp_dim, outp_dim, use_bias=False):
        super().__init__()
        # add_weight is the supported API; add_variable is a deprecated
        # alias whose positional signature differs in newer Keras versions.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        if use_bias:
            self.bias = self.add_weight(
                name='b', shape=[outp_dim], initializer='zeros')
        else:
            self.bias = None

    def call(self, inputs, training=None):
        """Apply the affine map; ``training`` is unused but kept for the Keras API."""
        x = inputs @ self.kernel
        if self.bias is not None:
            x = x + self.bias
        return x
class MyNetwork(keras.Model):
    """Five-layer fully connected classifier built from ``MyDense`` layers.

    Layer widths are 784 -> 256 -> 128 -> 64 -> 32 -> 10, with ReLU after
    every layer except the last, which returns raw (unnormalised) logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Run the forward pass.

        :param inputs: image batch whose per-sample size is 28*28, e.g.
            ``[b, 28, 28]``; it is flattened to ``[b, 784]``.
        :param training: unused; accepted for Keras API compatibility.
        :return: logits of shape ``[b, 10]``.
        """
        x = tf.reshape(inputs, [-1, 28 * 28])
        # Hidden layers: dense followed by ReLU.
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = tf.nn.relu(hidden(x))
        # Output layer: no activation, so the result stays as logits.
        return self.fc5(x)
# 开始训练
# Build, train and evaluate the model.
network = MyNetwork()
network.compile(
    # `learning_rate` replaces the deprecated `lr` keyword, which newer
    # Keras optimizers reject.
    optimizer=optimizers.Adam(learning_rate=1e-3),
    # The network outputs raw logits, hence from_logits=True.
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
# 小结：本文介绍了如何使用 TensorFlow 构建一个深度学习网络，包括数据预处理、
# 自定义层（MyDense）、网络模型（MyNetwork）的创建，以及训练和评估过程。
# 该网络应用于 Fashion-MNIST 数据集，对图像进行分类。
# （以下为博客页面残留内容）
# 3090
#
# 被折叠的 条评论
# 为什么被折叠?



