A brief introduction to the MNIST dataset:
It is included in torchvision and can be downloaded with the following code:
from torchvision import datasets
train_dataset = datasets.MNIST(root='./dataset/mnist/', train=True, download=True)
# test_dataset = datasets.MNIST(root='./dataset/mnist/', train=False, download=True)
The code above downloads the data to the specified path, so it does not need to be downloaded again later. The dataset consists of four files in total: the images and labels of the training set (60,000 samples) and of the test set (10,000 samples).
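To confirm what was downloaded, the short sketch below (an addition, not part of the original notes) simply walks the dataset root and prints the files it finds; the exact folder layout (for example an MNIST/raw/ subdirectory holding the four idx files) depends on the torchvision version:
import os

root = './dataset/mnist/'  # same root that was passed to datasets.MNIST above
for dirpath, dirnames, filenames in os.walk(root):
    for name in filenames:
        print(os.path.join(dirpath, name))  # expect the train/test image and label files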
Inspecting the dataset:
from torchvision import transforms
from torchvision import datasets
mean, std = 0.1307, 0.3081  # mean and standard deviation of the MNIST training set; transforms.ToTensor() scales pixel values to [0, 1], and Compose chains the transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
traindata = datasets.MNIST(root='./dataset/mnist/', download=False, train=True, transform=transform)
testdata = datasets.MNIST(root='./dataset/mnist/', download=False, train=False, transform=transform)
print(traindata.data.size())
print(traindata.targets.size())
print(testdata.data.size())
print(testdata.targets.size())
The output is:
torch.Size([60000, 28, 28])
torch.Size([60000])
torch.Size([10000, 28, 28])
torch.Size([10000])
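To get a feel for a single sample, the sketch below (an addition to the original notes) shows the first training image with its label and prints the shape and value range of the normalized tensor; note that the transform is applied when the dataset is indexed, while .data still holds the raw uint8 images:
import matplotlib.pyplot as plt

# Raw image (28x28 uint8) and its label.
plt.imshow(traindata.data[0].numpy(), cmap='gray')
plt.title('label: %d' % traindata.targets[0].item())
plt.show()

# Indexing the dataset applies ToTensor + Normalize, giving a 1x28x28 float tensor.
img, label = traindata[0]
print(img.shape, img.min().item(), img.max().item())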
After 16 epochs of training, the model reaches an accuracy of 98.14% on the test set:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
batchsize = 32
lr = 0.01
mean, std = 0.1307, 0.3081  # mean and standard deviation of the MNIST training set; transforms.ToTensor() scales pixel values to [0, 1], and Compose chains the transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
traindata = datasets.MNIST(root='./dataset/mnist/', download=False, train=True, transform=transform)
testdata = datasets.MNIST(root='./dataset/mnist/', download=False, train=False, transform=transform)
trainloard = DataLoader(dataset=traindata, batch_size=batchsize, shuffle=True, num_workers=0, drop_last=True)   # drop_last discards the final incomplete batch
testloard = DataLoader(dataset=testdata, batch_size=batchsize, shuffle=False, num_workers=0, drop_last=True)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Fully connected network: 784 -> 256 -> 64 -> 32 -> 10
        self.l1 = torch.nn.Linear(784, 256)
        self.l2 = torch.nn.Linear(256, 64)
        self.l3 = torch.nn.Linear(64, 32)
        self.l4 = torch.nn.Linear(32, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28 * 28 * 1)  # flatten each 1x28x28 image into a 784-dim vector
        x = self.relu(self.l1(x))
        x = self.relu(self.l2(x))
        x = self.relu(self.l3(x))
        x = self.l4(x)  # no softmax here: CrossEntropyLoss applies log-softmax internally
        return x

model = Model()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)
def train(allepoch):
    lepoch = []  # epoch index, for plotting
    llost = []   # loss curve
    lacc = []    # test-accuracy curve
    for epoch in range(allepoch):
        lost = 0   # accumulated batch losses in this epoch
        count = 0  # number of training samples seen in this epoch
        for num, (x, y) in enumerate(trainloard, 1):
            y_h = model(x)
            loss = criterion(y_h, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lost += loss.item()
            count += batchsize
        print('epoch:', epoch + 1, 'loss:', lost / count, end=' ')
        lepoch.append(epoch + 1)
        llost.append(lost / count)
        acc = test()  # evaluate on the test set after every epoch
        lacc.append(acc)
        print('acc:', acc)
    plt.plot(lepoch, llost, label='loss')
    plt.plot(lepoch, lacc, label='acc')
    plt.legend()
    plt.show()
def test():
    with torch.no_grad():  # no gradients needed during evaluation
        acc = 0    # number of correct predictions
        count = 0  # number of test samples seen
        for num, (x, y) in enumerate(testloard, 1):
            y_h = model(x)
            _, y_h = torch.max(y_h.data, dim=1)  # predicted class = index of the largest logit
            acc += (y_h == y).sum().item()
            count += x.size(0)
        return acc / count
if __name__ == '__main__':
    train(30)
The results are as follows: