完整文件:https://github.com/JintuZheng/Blog-/blob/master/Demo_LogicRegression_MNIST.py
包导入准备
import torchvision.datasets
import torchvision.transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torch.nn
import torch.optim
from debug import ptf_tensor
设置超参数
# Hyperparameters超参数
BATCH_SIZE=100
NUM_EPOCHS=5
DEVICE='cuda:0'
数据集下载
########################## 训练集的准备 ##############################################
train_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#root:下载数据存放到哪里,train:下载训练集还是测试集,transfrom:数据转化的形式
test_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=False, transform=torchvision.transforms.ToTensor(),download=True)
【1】设置dataloader,分批读取数据,因为我们没办法一次训练过多数据
#由于数据集里面有上万条数

本文详细介绍了一种基于PyTorch实现的手写数字识别模型,使用逻辑回归算法对MNIST数据集进行训练与测试,实现了从数据加载、预处理到模型训练及评估的全过程。
最低0.47元/天 解锁文章
6558

被折叠的 条评论
为什么被折叠?



