5.3 用PyTorch实现Logistic回归

一、数据准备

Logistic回归常用于解决二分类问题。

为了便于描述,我们分别从两个多元高斯分布 N₁(μ₁,Σ₁ )、N₂(μ₂,Σ₂)中生成数据 x₁ 和 x₂,这两个多元高斯分布分别表示两个类别,分别设置其标签为 y₁ 和 y₂。

PyTorch 的 torch.distributions 提供了 MultivariateNormal 构建多元高斯分布

下面代码设置两组不同的均值向量和协方差矩阵,μ₁(mul)和 μ₂(mul)是二维均值向量,Σ₁(sigmal)和Σ₂(sigma2)是2*2的协方差矩阵。

前面定义的均值向量和协方差矩阵作为差数传入 MultivariateNormal,就实例化了两个多元高斯分布 m₁和 m₂。

调用 m₁和 m₂ 的sample方法分别生成100个样本。

设置样本对应的标签 y,分别用 0 和 1 表示不同高斯分布的数据,也就是正样本和负样本。

使用 cat 函数将 x₁(m1)和 x₂(m2)组合在一起。

打乱样本和标签的顺序,将数据重新随机排列,这是十分重要的步骤,否则算法的每次迭代只会学习到同一个类别的信息,容易造成模型过拟合。 

将生成的样本用 plt.scatter 绘制出来。

绘制结果如图:

可以明显的看出多元高斯分布生成的样本聚成了两个簇,并且簇的中心分别处于不同的位置(多元高斯分布的均值向量决定了其位置)。

右上角簇的样本分布比较稀疏,而左下角簇的样本分布紧凑(多元高斯分布的协方差矩阵决定了分布形状)。

【可调整

mu1 = -3 * torch.ones(2)

mu2 = 3 * torch.ones(2)

的参数,观察变化!

二、线性方程

Logistic回归用输入变量x的线性函数表示样本为正类的对数概率。nn.Linear 实现了 y = xAᵀ + b,我们可以直接调用它来实现Logistic回归的线性部分。

定义线性模型的输入维度D_in 和输出维度 D_out,因为前面定义的多元高斯分布 m₁(m1)和 m₂(m2)产生的变量是二维的,所以线性模型的输入维度应该定义为D_in = 2 ;而Logistic回归是二分类模型,预测的是变量为正类的概率,所以输出的维度应该为D_in = 1。

实例化了nn.Linear,将线性模型应用到数据 x 上,得到计算结果output。

Linear的初始参数是随机设置的,可以调用Linear.weight 和 Linear.bias 获取线性模型的参数。

输出输入的变量x,模型参数weight和bias,以及计算结果output的维度。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值