'''
Build the model
'''
class Net(nn.Module):
    """Fully connected network with two hidden layers.

    Each hidden layer is ``Linear -> BatchNorm1d`` followed by a ReLU that is
    applied in ``forward``. The output layer is a plain ``Linear`` and returns
    unnormalized scores (no softmax is applied here).

    Args:
        in_dim: Number of input features per sample.
        n_hidden_1: Width of the first hidden layer.
        n_hidden_2: Width of the second hidden layer.
        out_dim: Number of output scores per sample.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.BatchNorm1d(n_hidden_1),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.BatchNorm1d(n_hidden_2),
        )
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        """Return the output scores for a batch of flattened inputs.

        Args:
            x: Tensor whose last dimension is ``in_dim``; the batch norm
                layers expect a 2-D ``(batch, features)`` input.

        Returns:
            Tensor with ``out_dim`` scores per sample.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the script still runs on machines without a usable CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Input size 28*28 with 10 outputs -- presumably flattened 28x28 images and
# 10 classes; confirm against the data loaders defined elsewhere.
model = Net(28 * 28, 300, 100, 10)
model.to(device)
# CrossEntropyLoss takes the raw scores returned by the model together with
# integer class labels.
criterion = nn.CrossEntropyLoss()
# `lr` and `momentum` are defined elsewhere in the file.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
'''
Train the model
'''
# Per-epoch history; every entry is the mean over that epoch's batches.
losses = []        # training loss
acces = []         # training accuracy
eval_losses = []   # test loss
eval_acces = []    # test accuracy

for epoch in range(num_epoches):
    train_loss = 0
    train_acc = 0
    model.train()

    # Step decay: multiply the learning rate by 0.1 every 5 epochs.
    # Epoch 0 is skipped so the configured `lr` is actually used for the first
    # five epochs (0 % 5 == 0 would otherwise shrink it before any step).
    if epoch > 0 and epoch % 5 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        # Flatten each sample to a 1-D feature vector for the linear layers.
        img = img.view(img.size(0), -1)

        out = model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Predicted class = index of the largest score in each row.
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc

    # NOTE: these are means of per-batch values, so a smaller final batch
    # carries the same weight as a full one.
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))
    print('epoch:{},train_loss:{:.6f},acc:{:.6f}'.format(epoch + 1, train_loss / len(train_loader),
                                                         train_acc / len(train_loader)))

    eval_loss = 0
    eval_acc = 0
    model.eval()
    # No parameters are updated during evaluation, so disable autograd to
    # avoid building a graph for every test batch.
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            img = img.view(img.size(0), -1)

            out = model(img)
            loss = criterion(out, label)

            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc

    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('Test Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss / len(test_loader), eval_acc / len(test_loader)))
'''
Visualize the training results
'''
# Plot the recorded per-epoch training loss against the epoch index.
epoch_index = np.arange(len(losses))
plt.title('train_loss')
plt.plot(epoch_index, losses)
plt.legend(['Train Loss'], loc='upper right')
plt.show()