import torch
import torch.nn as nn
import numpy as np
print(torch.__version__)
1.1.0
Logistic Regression
Logistic regression is a kind of generalized linear model, and it has much in common with multiple linear regression. The two share essentially the same model form wx + b, where w and b are the parameters to be estimated. The difference lies in the dependent variable: multiple linear regression takes wx + b directly as the dependent variable, i.e. y = wx + b, while logistic regression maps wx + b to a hidden state p through a function L, p = L(wx + b), and then decides the value of the dependent variable by comparing p with 1 - p.
When L is the logistic function this is logistic regression; when L is a polynomial function it is polynomial regression.
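In symbols (writing σ for the logistic function; stated here only for reference):

$$p = \sigma(wx + b), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}$$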
Logistic regression is mainly used for binary classification. The sigmoid function is the most common logistic function: since its output is a probability between 0 and 1, the prediction is 1 when the probability is greater than 0.5 and 0 when it is smaller.
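As a quick illustration of the 0.5 threshold (a minimal sketch; the input values are made up):

import torch
z = torch.tensor([-2.0, 0.0, 3.0])  # example raw scores wx + b
p = torch.sigmoid(z)                # probabilities in (0, 1)
pred = (p > 0.5).long()             # 1 where p > 0.5, otherwise 0
print(p)     # tensor([0.1192, 0.5000, 0.9526])
print(pred)  # tensor([0, 0, 1])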
# Load the data; each row is 24 numeric features followed by a label (1 or 2)
data = np.loadtxt('./data/german.data-numeric')
n, l = data.shape  # n samples, l columns (the last column is the label)
# Standardize each feature column to zero mean and unit variance
# (z-score normalization); the label column is left untouched
for j in range(l - 1):
    meanVal = np.mean(data[:, j])
    stdVal = np.std(data[:, j])
    data[:, j] = (data[:, j] - meanVal) / stdVal
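The same standardization can also be written without the loop (an equivalent vectorized sketch, not part of the original pipeline):

# vectorized z-score over all feature columns at once
data[:, :l - 1] = (data[:, :l - 1] - data[:, :l - 1].mean(axis=0)) / data[:, :l - 1].std(axis=0)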
# Shuffle the rows so the train/test split is random
np.random.shuffle(data)
# Split into a training set (first 900 rows) and a test set (remaining 100)
train_data = data[:900, :l - 1]
train_tag = data[:900, l - 1] - 1  # shift labels from 1/2 to 0/1
test_data = data[900:, :l - 1]
test_tag = data[900:, l - 1] - 1
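The labels in german.data-numeric are coded 1/2, which the subtraction above maps to 0/1; a quick sanity check (a sketch):

print(np.unique(train_tag))  # expected: [0. 1.]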
# Define the network: a single linear layer (24 features -> 2 class scores)
# followed by a sigmoid that squashes each score into (0, 1)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(24, 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc(x)
        out = self.sigmoid(out)
        return out
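A quick shape check of the network (a sketch; the name net_check and the random batch are illustrative only):

net_check = Net()
dummy = torch.randn(4, 24)  # a batch of 4 samples with 24 features each
print(net_check(dummy).shape)  # torch.Size([4, 2]): one score per class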
def test(pred, lab):
    # pred.max(-1)[1] is the index of the larger score, i.e. the predicted class;
    # the accuracy is the fraction of predictions that match the labels
    t = pred.max(-1)[1] == lab
    return torch.mean(t.float())
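For example, with made-up scores and labels (a sketch):

pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]])  # scores for two samples
lab = torch.tensor([1, 1])                     # true labels
print(test(pred, lab))  # tensor(0.5000): one of the two predictions is correct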
net = Net()
criterion = nn.CrossEntropyLoss()  # cross-entropy loss (applies log_softmax internally)
optimizer = torch.optim.Adam(net.parameters())  # Adam optimizer with default settings
epochs = 1000
for epoch in range(epochs):
    net.train()  # set the model to training mode
    x = torch.from_numpy(train_data).float()
    y = torch.from_numpy(train_tag).long()
    y_pred = net(x)
    loss = criterion(y_pred, y)  # compute the loss
    optimizer.zero_grad()  # clear the accumulated gradients
    loss.backward()  # backpropagation
    optimizer.step()  # update the parameters
    if (epoch + 1) % 100 == 0:
        net.eval()  # switch to evaluation mode
        test_in = torch.from_numpy(test_data).float()
        test_t = torch.from_numpy(test_tag).long()
        test_out = net(test_in)
        # compute the accuracy on the test set with the helper above
        accu = test(test_out, test_t)
        print('epoch: {}, loss: {}, accuracy:{}'.format(
            epoch + 1, loss.item(), accu))
epoch: 100, loss: 0.6658604145050049, accuracy:0.699999988079071
epoch: 200, loss: 0.6306068897247314, accuracy:0.8199999928474426
epoch: 300, loss: 0.6095178723335266, accuracy:0.8100000023841858
epoch: 400, loss: 0.5955496430397034, accuracy:0.8100000023841858
epoch: 500, loss: 0.5853410363197327, accuracy:0.800000011920929
epoch: 600, loss: 0.5774123072624207, accuracy:0.8199999928474426
epoch: 700, loss: 0.5710282921791077, accuracy:0.8199999928474426
epoch: 800, loss: 0.5657661557197571, accuracy:0.8199999928474426
epoch: 900, loss: 0.5613517165184021, accuracy:0.8199999928474426
epoch: 1000, loss: 0.5575944185256958, accuracy:0.8199999928474426