从图像中区分鸟类和飞机:图像学习指南
1. 均方误差(MSE)在分类任务中的局限性
在图像分类任务里,预测结果可能会偏离目标。当预测处于低损失区域时,正确类别被赋予的预测概率能达到 99.97%。然而,我们一开始就排除掉的均方误差(MSE),在很早的时候就会饱和,而且对于非常错误的预测也是如此。其根本原因在于,MSE 的斜率过低,无法弥补 softmax 函数在错误预测时的平缓特性。所以,基于概率的 MSE 并不适合分类工作。
2. 训练分类器
2.1 训练循环的调整
我们重新使用之前写的训练循环,并对其进行了一些小调整。以下是具体的代码:
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(3072, 512),
nn.Tanh(),
nn.Linear(512, 2),
nn.LogSoftmax(dim=1)
)
learning_rate = 1e-2
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.NLLLoss()
n_epochs = 100
for epoch in range(n_epochs):
for img, label in cifar2:
out = model(img.view(-1).unsqueeze(0))
loss = loss_fn(out, torch.tensor([labe
超级会员免费看
订阅专栏 解锁全文
929

被折叠的 条评论
为什么被折叠?



