import matplotlib.pyplot as plt
import torch
import torchvision
from d2l import torch as d2l
from IPython import display
d2l.use_svg_display()
"""
# 1, 读取数据集
# 1.1,通过ToTensor实例将图像数据从PIL类型变换为32位浮点数格式。
# 并除以255使得所有像素数值均在0~1之间
trans = torchvision.transforms.ToTensor()
mnist_train_dataset = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test_dataset = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
print(mnist_train_dataset[0][0].shape)
def get_fashion_mnist_labels(labels):
text_labels = ['t-shrit', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 1.2, 创建一个函数来可视化这些样本
def show_dataset_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
figsize = (num_cols*scale, num_rows*scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
#ax.imshow(img.numpy())
plt.imshow(img.numpy())
else:
# PIL图片
#ax.imshow(img)
plt.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
X, y = next(iter(torch.utils.data.DataLoader(mnist_train_dataset, batch_size=18)))
show_dataset_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
batch_size1 = 256
"""
# 2, 读取小批量
def get_dataloader_workers():
"""使用4个进程来读取数据"""
return 4
# 2.1, 为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,
# train_iter = torch.utils.data.DataLoader(mnist_train_dataset, batch_size1, shuffle=True, num_workers=get_dataloader_workers())
def load_data_fashion_mnist(batch_size, resize=None): #@save
trans = [torchvision.transforms.ToTensor()]
if resize:
trans.insert(0, torchvision.transforms.Resize(resize))
trans = torchvision.transforms.Compose(trans)
mnist_train_dataset = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test_dataset = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
return (torch.utils.data.DataLoader(mnist_train_dataset, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
torch.utils.data.DataLoader(mnist_test_dataset, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
# 2.2 测试加载batch_size=32的样本。
train_iter1, test_iter1 = load_data_fashion_mnist(32, resize=64)
for X1, Y1 in train_iter1:
print(X1.shape, X1.dtype, Y1.shape, Y1.dtype)
break
# 3, Softmax回归从零开始实现
# 3.1 使用Fashion-MNIST数据集,并设置数据迭代器的批量大小为256.
batch_size2 = 256
train_dataset_iterator, test_dataset_iterator = d2l.load_data_fashion_mnist(batch_size2, resize=None)
# 4, 初始化模型参数
# 4.1 数据集中每个样本是28x28的图像,将该图像展平为一维向量为784长度,
# 由于有10个类别,那么权重将是784x10的矩阵,偏置将构成一个1x10的行向量,
# 我们仍然使用正态分布初始化权重Weight,偏置bais初始化为0.
num_inputs = 784
num_outputs = 10
weight = torch.normal(0, 0.01, size=(num_in
使用Fashion-MNIST数据集完成Softmax回归模型搭建
于 2024-01-31 16:55:52 首次发布
本文介绍了如何使用PyTorch库处理FashionMNIST数据集,包括数据预处理、图像可视化、Softmax回归模型的实现、损失函数计算以及使用小批量随机梯度下降进行训练。同时展示了训练过程中的损失和准确率变化以及模型预测的结果。

最低0.47元/天 解锁文章
722

被折叠的 条评论
为什么被折叠?



