数据概况
Fashion-mnist
经典的MNIST数据集包含了大量的手写数字。十几年来,来自机器学习、机器视觉、人工智能、深度学习领域的研究员们把这个数据集作为衡量算法的基准之一。你会在很多的会议,期刊的论文中发现这个数据集的身影。实际上,MNIST数据集已经成为算法作者的必测的数据集之一。
类别标注
在Fashion-mnist数据集中,每个训练样本都按照以下类别进行了标注:

数据处理
对输入进行归一化
归一化时需要统一进行 x = (x - mean) / std
train_trans = transforms.Compose([
transforms.RandomCrop(28, padding=2),#数据增强
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
test_trans = transforms.Compose([
transforms.ToTensor(),
normalize
])
mnist_train = torchvision.datasets.FashionMNIST(root='../data',train=True,download=True,transform=train_trans)
mnist_test = torchvision.datasets.FashionMNIST(root='../data',train=False,download=True,transform=test_trans)
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
# 求整个数据集的均值
temp_sum = 0
cnt = 0
for X, y in train_iter:
if y.shape[0] != batch_size:
break # 最后一个batch不足batch_size,这里就忽略了
channel_mean = torch.mean(X, dim=(0,2,3)) # 按channel求均值(不过这里只有1个channel)
cnt += 1 # cnt记录的是batch的个数,不是图像
temp_sum += channel_mean[0].item()
dataset_global_mean = temp_sum / cnt
print('整个数据集的像素均值:{}'.format(dataset_global_mean))
# 求整个数据集的标准差
cnt = 0
temp_sum = 0
for X, y in train_iter:
if y.shape[0] != batch_size:
break # 最后一个batch不足batch_size,这里就忽略了
residual = (X - dataset_global_mean) ** 2
channel_var_mean = torch.mean(residual, dim=(0,2,3))
cnt += 1 # cnt记录的是batch的个数,不是图像
temp_sum += math.sqrt(channel_var_mean[0].item())
dataset_global_std = temp_sum / cnt
print('整个数据集的像素标准差:{}'.format(dataset_global_std))
整个数据集的像素均值:0.2860366729433025
整个数据集的像素标准差:0.35288708155778725
数据增强
加入随机裁剪和翻转
============================ step 1/6 数据 ============================
batch_size = 64
normalize = transforms.Normalize(mean=[0.286], std=[0.352])#对像素值归一化
train_trans = transforms.Compose([
transforms.RandomCrop(28, padding=2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
test_trans = transforms.Compose([
transforms.ToTensor(),
normalize
])
mnist_train = torchvision.datasets.FashionMNIST(root='../data',train=True,download=True,transform=train_trans)
mnist_test = torchvision.datasets.FashionMNIST(root='../data',train=False,download=True,transform=test_trans)
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

本文详细介绍了使用PyTorch框架对Fashion MNIST数据集进行图像分类的任务,包括数据预处理、ResNet网络定义、训练及测试过程。通过数据增强、归一化等技术提高模型性能。
最低0.47元/天 解锁文章
985

被折叠的 条评论
为什么被折叠?



