【Pytorch】加载本地FashionMNIST数据文件

本文介绍了如何解决在使用沐神的深度学习教程时,Fashion-MNIST数据集下载速度慢的问题。通过修改d2l包中torch.py文件的相应代码,将数据集根目录改为本地路径并禁止下载,可以避免从AWS下载,提高数据加载速度。同时提供了FashionMNIST数据集的百度网盘链接供读者下载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

MNIST解决网上下载FashionMNIST数据集较慢的问题

文末包含FashionMNIST数据集资源

最近在学习沐神的深度学习,在Softmax的实现中,需要导入Fashion-MNIST数据集, 并设置数据迭代器的批量大小为256。

代码如下:

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

其中d2l.load_data_fashion_mnist要求从d2l包中加载数据集,而这里的数据集默认是亚马逊AWS下载,所以下载速度较慢,如下图:

于是根据终止程序以后的提示找到d2l文件中torch.py文件位置 

提示如下:

打开以后找到第234行(下方代码块第10、12行)。

def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.

    Defined in :numref:`sec_fashion_mnist`"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

root="../data"改为root="./data"  # 使用项目文件同级地址下的已有的数据集

downloda=True改为download=False  # 不下载

如下:

``` 
mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=False)
mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=False)
```

保存之后重启Jupyter内核再次运行代码之后,没有提示下载。


 这里附上FashionMNIST数据集链接

链接:https://pan.baidu.com/s/1GyaSjSG-_NNEj1vJelJQJg 
提取码:PyFM 


<think>好的,用户现在想了解如何获取和加载MNIST数据集文件用于神经网络训练。首先,我需要回顾之前的对话和提供的引用内容,确保回答准确。 用户之前询问了如何在PyCharm中编写MNIST神经网络代码,我给出了使用TensorFlow的示例。现在用户的问题转向数据集本身的获取和加载,可能需要更详细的步骤,特别是不同框架下的方法,比如TensorFlow、PyTorch,以及手动下载的方式。 根据引用[1],MNIST数据集是基础数据集,适合初学者。引用[2]和[4]提到了PyTorch中使用torchvision加载MNIST和FashionMNIST的方法,包括数据预处理步骤。引用[3]讨论了使用BP神经网络训练MNIST,需要注意图像尺寸问题。这些信息需要整合到回答中。 用户可能需要多种加载方式,比如自动下载和手动下载。自动下载方面,TensorFlow和PyTorch都有内置函数,如tf.keras.datasets.mnist.load_data()和torchvision.datasets.MNIST。手动下载部分,需要提供官方来源,如Yann LeCun的网站,并指导如何加载本地文件。 此外,用户可能关心数据预处理,比如归一化、转换张量等,这在引用[2]和[4]中都有提到。需要强调预处理的重要性,比如归一化到0-1之间,使用ToTensor()等。 还要注意用户可能遇到的常见问题,例如下载慢或失败,这时候手动下载并指定路径是一个解决方案。同时,需要提到不同框架下的数据加载器,如TensorFlow的Dataset和PyTorch的DataLoader,以及如何设置批量大小和洗牌数据。 最后,确保回答结构清晰,分点说明自动下载和手动下载的方法,并涵盖TensorFlow和PyTorch两种主流框架,满足不同用户的需求。同时,加入注意事项和常见问题处理,帮助用户避免常见陷阱。</think>### MNIST数据集获取与加载全指南 #### 一、官方获取途径 1. **原始数据集来源** - 官网地址:http://yann.lecun.com/exdb/mnist/(包含四个.gz压缩文件) - 文件组成: - `train-images-idx3-ubyte.gz`:训练集图片(60,000样本) - `train-labels-idx1-ubyte.gz`:训练集标签 - `t10k-images-idx3-ubyte.gz`:测试集图片(10,000样本) - `t10k-labels-idx1-ubyte.gz`:测试集标签 #### 二、框架自动加载方式 ##### 1. TensorFlow/Keras实现 ```python import tensorflow as tf # 自动下载到 ~/.keras/datasets/ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() # 数据预处理 train_images = train_images.reshape((60000, 28, 28, 1)) # 添加通道维度 train_images = train_images.astype('float32') / 255 # 归一化[^1] test_images = test_images.reshape((10000, 28, 28, 1)) test_images = test_images.astype('float32') / 255 ``` ##### 2. PyTorch实现 ```python import torchvision from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准化参数[^2] ]) # 自动下载到 ./data 目录 train_set = torchvision.datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_set = torchvision.datasets.MNIST( root='./data', train=False, download=True, transform=transform ) # 创建数据加载器 train_loader = torch.utils.data.DataLoader( train_set, batch_size=64, shuffle=True) ``` #### 三、手动加载本地文件 ```python # 以TensorFlow为例 path_to_train_images = 'mnist/train-images-idx3-ubyte' path_to_train_labels = 'mnist/train-labels-idx1-ubyte' def load_mnist_images(file_path): with open(file_path, 'rb') as f: data = np.frombuffer(f.read(), np.uint8, offset=16) return data.reshape(-1, 28, 28) train_images = load_mnist_images(path_to_train_images) / 255.0 train_labels = np.fromfile(path_to_train_labels, dtype=np.uint8, offset=8) ``` #### 四、数据集特点说明 | 特征 | 说明 | |-------------|-----------------------------| | 图像尺寸 | 28×28像素灰度图 | | 数据分布 | 训练集60,000,测试集10,000 | | 标签范围 | 0-9共10个类别 | | 文件格式 | IDX3-UBYTE特殊二进制格式 | #### 五、常见问题解决 1. **下载失败处理**: - 使用国内镜像源:`pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple` - 手动下载后放入`~/.keras/datasets/`目录 2. **数据可视化验证**: ```python import matplotlib.pyplot as plt plt.imshow(train_images[0].squeeze(), cmap='gray') plt.title(f"Label: {train_labels[0]}") plt.show() ``` 3. **多框架适配建议**: - TensorFlow用户优先使用`tf.data.Dataset`API ```python train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) train_dataset = train_dataset.shuffle(1000).batch(32) ``` - PyTorch用户推荐使用`DataLoader`的num_workers参数加速 ```python train_loader = DataLoader(dataset, batch_size=64, num_workers=4) ``` --相关问题--: 1. MNIST数据集能否直接用于卷积神经网络训练? 2. 如何处理自定义手写数字图片与MNIST格式的兼容问题? 3. MNIST数据集在图像分类任务中有哪些局限性? 4. 如何将MNIST数据转换为CSV格式进行存储? [^1]: MNIST数据集文件采用IDX格式存储,包含文件头信息用于描述数据维度 [^2]: 标准化参数来自MNIST数据集的全局统计值,能提升模型训练稳定性
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Snychng

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值