mnist_loader.py

这是一个用于加载 MNIST 图像数据的库,详细介绍了数据结构及其返回的数据格式。通常调用 `load_data_wrapper` 函数来获取经过预处理的训练、验证及测试数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

"""
mnist_loader
~~~~~~~~~~~~

A library to load the MNIST image data.  For details of the data
structures that are returned, see the doc strings for ``load_data``
and ``load_data_wrapper``.  In practice, ``load_data_wrapper`` is the
function usually called by our neural network code.
"""

#### Libraries
# Standard library
import pickle
import gzip

import numpy as np
def load_data():
    f = gzip.open('mnist.pkl.gz', 'rb')
    training_data, validation_data, test_data = pickle.load(f,encoding = 'bytes')
    f.close()
    return (training_data, validation_data, test_data)

def load_data_wrapper():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)

def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e
优化代码import torch import torchvision.datasets as dsets import torchvision.transforms as transforms #/********** Begin *********/ # 下载MNIST数据集 Mnist_dataset = dsets.MNIST(root=&#39;./data&#39;,train=True,transform=transforms.ToTensor(),download=True) # 创建batch_size=100, shuffle=True的DataLoader类型的变量data_loader data_loader = torch.utils.data.DataLoader(dataset=Mnist_dataset, batch_size=100,shuffle=True) # 输出 data_loader中数据类型 print(type(data_loader.dataset)) #/********** End *********/使其通过import torch import torch.nn as nn from torch.autograd import Variable import os import re import sys fileName = &#39;data.py&#39; #fileName = &#39;answer.py&#39; path = os.path.split(os.path.abspath(os.path.realpath(sys.argv[0])))[0] + os.path.sep cmd = &#39;python3 &#39; + path + fileName flag = 0 answer = os.popen(cmd).read() print(answer) answer = "".join(answer.split()) #print(answer) a = "<class&#39;torchvision.datasets.mnist.MNIST&#39;>" if a == answer: flag += 1 #print(flag) file=open(path + fileName) #默认的mode是&#39;r&#39;,即读模式content = file.read() #读取文件内容 content = file.read() #读取文件内容 content = content.replace(&#39; &#39;, &#39;&#39;) text = ["MNIST","type(data_loader.dataset)","DataLoader","ToTensor"] i = 0 for t in text: if t in content: i += 1 continue else: print("Sorry! Check again please!") break if i == len(text): flag += 1 if flag == 2: print("Congratulation!") file.close()测试并正确输出<class &#39;torchvision.datasets.mnist.MNIST&#39;> Congratulation!
03-11
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值