- 在【Pytorch实战(三)】线性回归中介绍了线性回归。线性判决与线性回归是类似的,区别在于线性回归的“标签”通常是较为连续的取值,而线性判决的“标签”往往只有几个可能的取值,例如【0,1】。
- 在线性回归中,经常使用MSE损失
torch.nn.MSELoss()
、L1损失torch.nn.L1Loss
和SmoothL1损失torch.nn.SmoothL1Loss
等。相应地,在线性判决中也需使用合适的损失。对于二元判决常用BCE(Binary Cross Entropy)损失;对于多元则常用Cross Entropy损失。 - 以二元判决为例,计算BCE损失需要获得一个预测概率,这需要使用逻辑回归算法(对应于
torch.sigmoid()
函数);若拓展至多元则对应于softmax函数。具体公式不再赘述,可参照结合实例理解pytorch中交叉熵损失函数进行理解。 - 实际实现中,二元的
torch.nn.BCELoss()
和torch.nn.BCEWithLogitsLoss()
又有所区别,前者需传入预测概率和真实标签,而后者则需传入W与X线性运算的结果Z和真实标签。多元的则使用torch.nn.CrossEntropyLoss
即可。 - 以下为例程:
import torch
import torch.nn
import torch.optim
x = torch.tensor([[1., 2., 1.], [2., 4., 1.], [3., 5., 1.], [4., 2., 1.], [4., 4., 1.]])
y = torch.tensor([0., 1., 1., 0., 1.])
w = torch.zeros(3, requires_grad=True)
L = torch.nn.BCEWithLogitsLoss()
optimiter = torch.optim.Adam([w])
for step in range(100001):
if step:
optimiter.zero_grad()
loss.backward()
optimiter.step()
pred = torch.mv(x, w)
loss = L(pred, y)
if step % 10000 == 0:
print('step{}: loss={}, w={}'.format(step, loss, w.tolist()))
z = torch.mv(x, w)
y_ = torch.sigmoid(z)
print(y_)
以上为个人理解,如有不正之处还请指出,欢迎交流!