参考了han同学的答案,西瓜数据集也可在han同学的github上下载。
3.5 编辑实现线性判别分析,并给出西瓜数据集 3.0α 上的结果.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# 下面这两行是为了让matplotlib画图时的汉字正常显示
import matplotlib
matplotlib.rc("font",family='YouYuan')
class LDA(object):
# 绘图,求出均值向量,根据公式3.34和3.39求出类内散度矩阵和类间散度矩阵
def fit(self, X_, y_, plot_=False):
# 取出正反例各自数据,计算均值向量
neg = y_ == 0
pos = y_ == 1
X0 = X_[neg]
X1 = X_[pos]
# 均值向量,(1, 2)
u0 = X0.mean(0, keepdims=True)
u1 = X1.mean(0, keepdims=True)
# 类内散度矩阵,公式3.33,(2, 2)
sw = np.dot((X0 - u0).T, (X0 - u0)) + np.dot((X1 - u1).T, (X1 - u1))
# 类间散度矩阵,公式3.37,(1, 2)
w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)
# 绘图
if plot_:
fig, ax = plt.subplots()
ax.spines['right'].set_color('none')

本文介绍了一个简单的线性判别分析(LDA)实现,并使用西瓜数据集3.0α进行了演示。该文详细展示了如何通过Python计算类内散度矩阵和类间散度矩阵,并给出了数据可视化结果。
最低0.47元/天 解锁文章
4332

被折叠的 条评论
为什么被折叠?



