1. 导入库并加载数据
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
- 导入所需的库:TensorFlow、Matplotlib 和 Numpy。
- 加载 Fashion-MNIST 数据集,并对数据进行标准化处理,使图像数据在 [0, 1] 之间。
2. 定义标签到服饰名称的映射,并展示训练图片
class_names = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]
plt.figure(figsize=(6, 6))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.imshow(x_train[i], cmap=plt.cm.binary)
plt.title(f" {class_names[y_train[i]]}")
plt.axis('off')
plt.tight_layout()
plt.show()
- 定义数字标签到服饰名称的映射。
- 使用 Matplotlib 创建 3x3 的网格,展示训练集中前 9 张图像,并显示每张图像对应的服饰类别名称。

3. 构建并输出模型概况
model = models.Sequential([
layers.Reshape((28, 28, 1), input_shape=(28, 28)),
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax')
])
model.summary()
- 使用 Keras 构建卷积神经网络(CNN),包括三个卷积层、两个池化层、展平层和两个全连接层,最后的输出层包含 10 个神经元(对应 10 种服饰类别)。
- 输出模型的结构概况,展示每一层的形状、参数数量等信息。

4. 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
- 编译模型,使用
adam
优化器和 sparse_categorical_crossentropy
损失函数,监控准确率。
5. 训练模型
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
- 训练模型,设置训练轮数为 5,并使用测试集进行验证。

6. 评估模型并展示预测结果
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

plt.figure(figsize=(6, 6))
for i in range(9):
test_img = np.expand_dims(x_test[i], axis=0)
prediction = model.predict(test_img)
predicted_label = np.argmax(prediction)
plt.subplot(3, 3, i+1)
plt.imshow(x_test[i], cmap=plt.cm.binary)
plt.title(f" {class_names[predicted_label]}")
plt.axis('off')
plt.tight_layout()
plt.show()
- 评估模型在测试集上的性能,并输出测试准确率。
- 创建 3x3 网格展示测试集的前 9 张图像,并显示每张图像的预测服饰类别。
