《机器学习》课后习题3.5 编辑实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.

本文介绍了一个简单的线性判别分析(LDA)实现,并使用西瓜数据集3.0α进行了演示。该文详细展示了如何通过Python计算类内散度矩阵和类间散度矩阵,并给出了数据可视化结果。

参考了han同学的答案,西瓜数据集也可在han同学的github上下载。

3.5 编辑实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# 下面这两行是为了让matplotlib画图时的汉字正常显示
import matplotlib
matplotlib.rc("font",family='YouYuan')

class LDA(object):

    # 绘图,求出均值向量,根据公式3.34和3.39求出类内散度矩阵和类间散度矩阵
    def fit(self, X_, y_, plot_=False):
        # 取出正反例各自数据,计算均值向量

        neg = y_ == 0
        pos = y_ == 1
        X0 = X_[neg]
        X1 = X_[pos]

        # 均值向量,(1, 2)
        u0 = X0.mean(0, keepdims=True)
        u1 = X1.mean(0, keepdims=True)

        # 类内散度矩阵,公式3.33,(2, 2)
        sw = np.dot((X0 - u0).T, (X0 - u0)) + np.dot((X1 - u1).T, (X1 - u1))
        # 类间散度矩阵,公式3.37,(1, 2)
        w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)

        # 绘图
        if plot_:
            fig, ax = plt.subplots()
            ax.spines['right'].set_color('none')
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值