逻辑回归(logistic regression)
github代码地址:
https://github.com/yunjey/pytorch-tutorial
本次代码简介(注释详细)
用的数据集是MNIST数据集,网络结构就一层(适合Pytorch新手学习)
(包括数据的下载,加载,模型定义,模型训练,模型测试,测试集的准确率计算等)
在代码最后还有关于max()函数的参数说明
code
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
#1.设置超参数
input_size=28*28
num_classes=10 #类别是也是相当于output_size
num_epochs=5
batch_size=100
learning_rate=0.001
#2.设置数据(我们使用MNIST数据)
train_dataset=torchvision.datasets.MNIST(root="mnistData",
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset=torchvision.datasets.MNIST(root="mnistData",
train=False,
transform=transforms.ToTensor(),
download=True)
#3.加载数据
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader=torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
#4.设置logistic regression模型
model=nn.Linear(input_size,num_classes)
#loss和optimizer
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),
lr=learning_rate)
#5.训练model
#统计数据个数
total_step=len(train_loader)
#开始
for epoch in range(1,num_epochs+1):
for i,(images,labels) in enumerate(train_loader):
#上句代码可以写成下面(只不过没有i序号)
# for images,labels in train_loader:
#由于使用enumerate使得train_loader每一个元组都有了一个序号i
#数据处理
# images=images.reshape(-1,input_size)
images=images.view(-1,input_size)
#forward
prev=model(images)
#计算损失
loss=criterion(prev,labels)
#backward and optimizer
optimizer.zero_grad()
loss.backward()
optimizer.step()
#输出
if i%100==0:
print('Epoch[{}/{}],Step[{}/{}],loss::{:.4f}'
.format(epoch,num_epochs,i,total_step,loss))
#6.使用测试集测试模型
#对于测试我们不需要计算梯度
with torch.no_grad():
total = 0
correct = 0
for images,labels in test_loader:
#数据处理
images=images.reshape(-1,input_size)
#测试得到输出
outputs=model(images)
#torch.max(tensor,0或者1)--》输出(最大值,索引值)
_,predicted=torch.max(outputs.detach(),1) #1代表每一行的最大值
# _,predicted=torch.max(outputs.data,1) #1代表每一行的最大值
total+=labels.size(0) #每次都是一个batch_size
correct+=(predicted==labels).sum()
# # print("total:",total)
# print("correct:",correct)
# print(predicted)
# print(labels)
print('Accuracy of th model on the 10000 test images : {}%'.format(correct*100/total))
#保存model checkpoint
torch.save(model.state_dict(),'logisticRegerssion.ckpt')
#关于torch.max()的说明:
# output = torch.max(input, dim)
# 输入
#
# input是softmax函数输出的一个tensor
# dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
# 输出
# 函数会返回两个tensor,第一个tensor是每行(列)的最大值;第二个tensor是如果0-》行 最大值的索引。
#