【TensorFlow学习】基本分类:对服装图像进行分类(1)

该博客介绍了如何使用TensorFlow的tf.keras API训练神经网络,对Fashion MNIST数据集中的服装图像进行分类。首先,导入并探索了包含60,000训练图像和10,000测试图像的数据集,然后对数据进行了预处理,将像素值缩放到0-1范围。" 123375974,11908942,Spring Boot工作流实践:Spring-boot-activiti与RuoYi-vue+flowable,"['spring boot', 'java', '后端', '工作流引擎', 'Activiti']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原文来自TensorFlow官网教程,讲的很详细,对于我这等铁five来说,真的很有用,害怕忘掉,赶紧记下来。

原文网址https://tensorflow.google.cn/tutorials/keras/classification

 

本指南训练了一个神经网络模型来对运动鞋和衬衫等服装的图像进行分类。如果您不了解所有细节,也可以;这是完整的TensorFlow程序的快速概述,详细内容随您进行。

本指南使用tf.keras(高级API)在TensorFlow中构建和训练模型。

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
 
2.3.0
 

导入Fashion MNIST数据集

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
 

加载数据集将返回四个NumPy数组:

  • train_imagestrain_labels阵列的训练集 -The数据模型用来学习。
  • 针对测试集,和test_images,对模型进行了测试test_labels

图像是28x28 NumPy数组,像素值范围是0到255。标签是整数数组,范围是0到9。这些对应于图像表示的衣服类别

标签
0T恤/上衣
1裤子
2拉过来
3连衣裙
4涂层
5凉鞋
6衬衫
7运动鞋
8
9脚踝靴

每个图像都映射到一个标签。由于类名不包含在数据集中,因此将它们存储在此处以供以后在绘制图像时使用:

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
 

探索数据

在训练模型之前,让我们探索数据集的格式。下图显示了训练集中有60,000张图像,每张图像表示为28 x 28像素:

train_images.shape
 
(60000, 28, 28)
 

同样,训练集中有60,000个标签:

len(train_labels)
 
60000
 

每个标签都是0到9之间的整数:

train_labels
 
数组([9,0,0,...,3,0,5],dtype = uint8)
 

测试集中有10,000张图像。同样,每个图像都表示为28 x 28像素:

test_images.shape
 
(10000, 28, 28)
 

测试集包含10,000个图像标签:

len(test_labels)
 
10000
 

预处理数据

在训练网络之前,必须对数据进行预处理。如果检查训练集中的第一张图像,您将看到像素值落在0到255的范围内:

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()
 

png

在将它们输入神经网络模型之前,将这些值缩放到0到1的范围。为此,将值除以255。以相同的方式预处理训练集测试集非常重要:

train_images = train_images / 255.0

test_images = test_images / 255.0
 

为了验证数据的格式正确,并准备好构建和训练网络,让我们显示训练集中的前25个图像,并在每个图像下方显示类别名称。

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()
 

png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值