《机器学习》课后习题3.5 编辑实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.

本文介绍了一个简单的线性判别分析(LDA)实现,并使用西瓜数据集3.0α进行了演示。该文详细展示了如何通过Python计算类内散度矩阵和类间散度矩阵,并给出了数据可视化结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

参考了han同学的答案,西瓜数据集也可在han同学的github上下载。

3.5 编辑实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# 下面这两行是为了让matplotlib画图时的汉字正常显示
import matplotlib
matplotlib.rc("font",family='YouYuan')

class LDA(object):

    # 绘图,求出均值向量,根据公式3.34和3.39求出类内散度矩阵和类间散度矩阵
    def fit(self, X_, y_, plot_=False):
        # 取出正反例各自数据,计算均值向量

        neg = y_ == 0
        pos = y_ == 1
        X0 = X_[neg]
        X1 = X_[pos]

        # 均值向量,(1, 2)
        u0 = X0.mean(0, keepdims=True)
        u1 = X1.mean(0, keepdims=True)

        # 类内散度矩阵,公式3.33,(2, 2)
        sw = np.dot((X0 - u0).T, (X0 - u0)) + np.dot((X1 - u1).T, (X1 - u1))
        # 类间散度矩阵,公式3.37,(1, 2)
        w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)

        # 绘图
        if plot_:
            fig, ax = plt.subplots()
            ax.spines['right'].set_color('none')
            ax.spines['top'].set_color('none')
            ax.spines['left'].set_position(('data', 0))
            ax.spines['bottom'].set_position(('data', 0))

            # 画样本点
            plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
            plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

            plt.xlabel('密度')
            plt.ylabel('含糖量')
            plt.legend(loc='upper right')

            # 画线
            x_temp = np.linspace(-0.05, 0.15)
            y_temp = x_temp * w[0, 1] / w[0, 0]
            plt.plot(x_temp, y_temp, '#808080', linewidth=1)

            wu = w / np.linalg.norm(w)

            # 画正负样本点的投影,真的没看懂哈哈哈
            X0_project = np.dot(X0, np.dot(wu.T, wu))
            plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
            for i in range(X0.shape[0]):
                plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

            X1_project = np.dot(X1, np.dot(wu.T, wu))
            plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
            for i in range(X1.shape[0]):
                plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--r', linewidth=1)

            # 均值向量的投影点
            ax.annotate(r'u0 投影点',
                        xy=(u0_project[:, 0], u0_project[:, 1]),
                        xytext=(u0_project[:, 0] - 0.2, u0_project[:, 1] - 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )

            ax.annotate(r'u1 投影点',
                        xy=(u1_project[:, 0], u1_project[:, 1]),
                        xytext=(u1_project[:, 0] - 0.1, u1_project[:, 1] + 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )
            
            plt.axis("equal")  # 两坐标轴的单位刻度长度保存一致
            plt.show()
            
        self.w = w
        self.u0 = u0
        self.u1 = u1
        return self

    def predict(self, X):
        # 各样本在的投影
        project = np.dot(X, self.w.T)
        # 均值投影
        wu0 = np.dot(self.w, self.u0.T)
        wu1 = np.dot(self.w, self.u1.T)

        return (np.abs(project - wu1) < np.abs(project - wu0)).astype(int)


if __name__=='__main__':
    data_path = r'C:\***\ch3--线性模型\3.3\watermelon3_0_Ch.csv'
    data = pd.read_csv(data_path).values
    # print(data)

    X = data[:, 7:9].astype(float)
    y = data[:, 9]

    y[y == '是'] = 1
    y[y == '否'] = 0
    y = y.astype(int)

    lda = LDA()
    lda.fit(X, y, plot_=True)
    # 根据LDA的进行预测
    print(lda.predict(X)[:, 0])
    # 样本标记
    print(y)

可视化呈现

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值