一、数据准备
Logistic回归常用于解决二分类问题。
为了便于描述,我们分别从两个多元高斯分布 N₁(μ₁,Σ₁ )、N₂(μ₂,Σ₂)中生成数据 x₁ 和 x₂,这两个多元高斯分布分别表示两个类别,分别设置其标签为 y₁ 和 y₂。
PyTorch 的 torch.distributions 提供了 MultivariateNormal 构建多元高斯分布。
下面代码设置两组不同的均值向量和协方差矩阵,μ₁(mul)和 μ₂(mul)是二维均值向量,Σ₁(sigmal)和Σ₂(sigma2)是2*2的协方差矩阵。
前面定义的均值向量和协方差矩阵作为差数传入 MultivariateNormal,就实例化了两个多元高斯分布 m₁和 m₂。
调用 m₁和 m₂ 的sample方法分别生成100个样本。
设置样本对应的标签 y,分别用 0 和 1 表示不同高斯分布的数据,也就是正样本和负样本。
使用 cat 函数将 x₁(m1)和 x₂(m2)组合在一起。
打乱样本和标签的顺序,将数据重新随机排列,这是十分重要的步骤,否则算法的每次迭代只会学习到同一个类别的信息,容易造成模型过拟合。
将生成的样本用 plt.scatter 绘制出来。
绘制结果如图:
可以明显的看出多元高斯分布生成的样本聚成了两个簇,并且簇的中心分别处于不同的位置(多元高斯分布的均值向量决定了其位置)。
右上角簇的样本分布比较稀疏,而左下角簇的样本分布紧凑(多元高斯分布的协方差矩阵决定了分布形状)。
【可调整
mu1 = -3 * torch.ones(2)
mu2 = 3 * torch.ones(2)
的参数,观察变化!
】
二、线性方程
Logistic回归用输入变量x的线性函数表示样本为正类的对数概率。nn.Linear 实现了 y = xAᵀ + b,我们可以直接调用它来实现Logistic回归的线性部分。
定义线性模型的输入维度D_in 和输出维度 D_out,因为前面定义的多元高斯分布 m₁(m1)和 m₂(m2)产生的变量是二维的,所以线性模型的输入维度应该定义为D_in = 2 ;而Logistic回归是二分类模型,预测的是变量为正类的概率,所以输出的维度应该为D_in = 1。
实例化了nn.Linear,将线性模型应用到数据 x 上,得到计算结果output。
Linear的初始参数是随机设置的,可以调用Linear.weight 和 Linear.bias 获取线性模型的参数。
输出输入的变量x,模型参数weight和bias,以及计算结果output的维度。