import matplotlib.pyplot as plt
import torch
import torchvision
from d2l import torch as d2l
from IPython import display

d2l.use_svg_display()

# Disabled exploration code (dataset download, label lookup, image-grid
# preview).  It is kept as a bare string literal so that it never executes;
# its text is preserved as written.
# NOTE(review): if this is ever re-enabled, `show_dataset_images` calls
# `plt.imshow` rather than `ax.imshow` (the `ax.imshow` calls are commented
# out), and the first label is spelled 't-shrit' -- confirm both are intended.
"""
# 1, 读取数据集
# 1.1,通过ToTensor实例将图像数据从PIL类型变换为32位浮点数格式。
# 并除以255使得所有像素数值均在0~1之间
trans = torchvision.transforms.ToTensor()
mnist_train_dataset = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test_dataset = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
print(mnist_train_dataset[0][0].shape)

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shrit', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 1.2, 创建一个函数来可视化这些样本
def show_dataset_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    figsize = (num_cols*scale, num_rows*scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            #ax.imshow(img.numpy())
            plt.imshow(img.numpy())
        else:
            # PIL图片
            #ax.imshow(img)
            plt.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

X, y = next(iter(torch.utils.data.DataLoader(mnist_train_dataset, batch_size=18)))
show_dataset_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
batch_size1 = 256
"""


# 2. Reading mini-batches
def get_dataloader_workers() -> int:
    """Return the number of worker processes used to read the data (4)."""
    return 4


# 2.1 To make reading the training and test sets easier, use the built-in
# data iterator.
# (The line below is disabled; it refers to `mnist_train_dataset` and
# `batch_size1`, which only exist inside the disabled string above.)
# train_iter = torch.utils.data.DataLoader(mnist_train_dataset, batch_size1, shuffle=True,
#                                          num_workers=get_dataloader_workers())
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Build the Fashion-MNIST training and test data loaders.

    Both datasets are created with ``root="../data"`` and ``download=True``.

    Args:
        batch_size: Batch size passed to both ``DataLoader`` instances.
        resize: Optional size handed to ``torchvision.transforms.Resize``.
            When falsy (the default), no resize step is added.

    Returns:
        A ``(train_loader, test_loader)`` tuple.  The training loader is
        created with ``shuffle=True`` and the test loader with
        ``shuffle=False``; both use ``get_dataloader_workers()`` workers.
    """
    steps = [torchvision.transforms.ToTensor()]
    if resize:
        # Inserted at index 0 so the resize runs before the tensor conversion.
        steps.insert(0, torchvision.transforms.Resize(resize))
    trans = torchvision.transforms.Compose(steps)

    mnist_train_dataset = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test_dataset = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)

    train_loader = torch.utils.data.DataLoader(
        mnist_train_dataset, batch_size, shuffle=True,
        num_workers=get_dataloader_workers())
    test_loader = torch.utils.data.DataLoader(
        mnist_test_dataset, batch_size, shuffle=False,
        num_workers=get_dataloader_workers())
    return train_loader, test_loader


# 2.2 Test loading samples with batch_size=32.
# The loaders use worker processes; under the "spawn" start method those
# workers re-import this module, so the code that iterates a loader must sit
# behind the main guard.  No wrapper function is used, so the names below
# remain module-level globals when the file is run as a script or notebook.
if __name__ == "__main__":
    train_iter1, test_iter1 = load_data_fashion_mnist(32, resize=64)
    for X1, Y1 in train_iter1:
        # NOTE(review): the source text is cut off right after
        # `print(X1.shape,`.  Only the visible argument is reproduced here;
        # any further arguments or following statements (e.g. a `break`)
        # must be restored from the original.
        print(X1.shape)
# Building a Softmax regression model with the Fashion-MNIST dataset
# First published on 2024-01-31 16:55:52