#关于CNN学习后 对pytorch的理解
#导入相应模块
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
# Hyper Parameters 超参数
EPOCH = 1 # 训练整批数据多少次, 为了节约时间, 我们只训练一次
BATCH_SIZE = 50
LR = 0.001 # 学习率
DOWNLOAD_MNIST = True # 已经下载好的话,会自动跳过的
# Mnist 手写数字集的设置
train_data = torchvision.datasets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=torchvision.transforms.ToTensor(), # 转换 PIL.Image or numpy.ndarray 成
# torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
download=DOWNLOAD_MNIST, # 没下载就下载, 下载了就不用再下了
)
#关于train_data 里面的transform=torchvision.transforms.ToTensor(),是将图片改成C*H*W在这里为1*28*28
#而train_data.train data 为60000*28*28
test_data = torchvision.datasets.MNIST(
root='./mnist/',
train=False,
)
#print(test_data.test_data) 为10000*28*28
# 训练集丢BATCH_SIZE个, 图片大小为28*28
train_loader = Data.DataLoader( #数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器
dataset=train_data, #数据加载器的作用是将数据存进去,
经典CNN识别MINIST数据集代码的理解
最新推荐文章于 2025-05-12 15:40:32 发布