【pytorch_6】Logistic回归的实现

本文介绍了如何在PyTorch中实现Logistic回归,包括Logistic函数的特性,定义模型,设定损失函数(二分类交叉熵)和优化器(SGD),以及训练和更新模型的步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. Logistic 函数

函数可以将输出映射到[0,1]之间,函数会逐渐趋于平稳,即饱和,对应着导数会越来越小(至少会>0)。如图:

 

2. 定义模型

 

3. 定义Loss损失和优化

Loss函数:

首先线性回归的损失函数使用的是:

而对于二分类的交叉熵损失函数,即:

如果使用Mini-Batch损失函数处理则变成了此时Logistic回归模型所用的损失函数。

 

优化:SGD

 

4. 训练以及更新

 

 


实现过程:

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])    #准备数据

class LogisticRegressionModel(torch.nn.Module):  #继承
    def __init__(self):                          #构造函数,初始化对象
        super(LogisticRegressionModel,self).__init__()      #调用父类
        self.linear = torch.nn.Linear(1,1)       #(权重,偏置)

    def forward(self,x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()   #实例化

criterion = torch.nn.BCELoss(size_average=False)         #交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)  #学习率0。01

#训练
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


#画图
x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view((200,1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
0 3.0579307079315186
1 2.996345043182373
2 2.937220335006714
3 2.8805525302886963
4 2.826328754425049
5 2.7745296955108643
6 2.7251267433166504
7 2.678084373474121
8 2.6333603858947754
9 2.590904712677002
10 2.550662040710449
...
993 1.0034360885620117
994 1.0030064582824707
995 1.002577304840088
996 1.0021485090255737
997 1.0017201900482178
998 1.0012922286987305
999 1.000864863395691

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值