import numpy as np
import pandas as pd
对数据集操作
data=pd.read_csv(r"F:\数据集\Iris数据集\iris.csv")
data.drop(["Unnamed: 0","Species"],axis=1,inplace=True)
data.drop_duplicates(inplace=True)
编写KNN类
class KNN:
"""
使用python实现KNN算法。(回归预测)
该算法用于回归预测,根据前三个特征属性,寻找最近的K个邻居,
然后再根据K个邻居的第四个特征属性,去预测当前样本的 第四个特征值
"""
def __init__(self,k):
"""
初始化方法
Parameters:
____________
k:邻居的个数
"""
self.k=k
def fit(self,X,y):
"""
训练方法
Parameters:
__________
X:类数组类型(特征矩阵),形状为[样本数量,特征数量]
待训练的样本特征(属性)
y:类数组类型(目标标签),形状为[样本数量]
每个样本的目标值(标签)
"""
self.X=np.asarray(X)
self.y=np.asarray(y)
def predict(self,X):
"""
根据参数传递的X,对样本呢进行预测
Parameters:
___________
X:类数组类型,形状为[样本数量,特征数量]
待测试的样本特征(属性)
return:数组类型
预测的结果值
"""
X=np.asarray(X)
result=[]
for x in X:
dis=np.sqrt(np.sum((x-self.X)**2,axis=1))
index=dis.argsort()
index=index[:self.k]
result.append(np.mean(self.y[index]))
return np.array(result)
def predict2(self,X):
"""
根据参数传递的X,对样本呢进行预测(考虑权重)
权重的计算方式:
使用每个节点(邻居)距离的倒数/所有节点距离倒数之和
Parameters:
___________
X:类数组类型,形状为[样本数量,特征数量]
待测试的样本特征(属性)
return:数组类型
预测的结果值
"""
X=np.asarray(X)
result=[]
for x in X:
dis=np.sqrt(np.sum((x-self.X)**2,axis=1))
index=dis.argsort()
index=index[:self.k]
s=np.sum(1/(dis[index]+0.001))
weight=(1/(dis[index]+0.001))/s
result.append(np.sum(self.y[index]*weight))
return np.array(result)
训练与测试
t=data.sample(len(data),random_state=0)
train_X=t.iloc[:120,:-1]
train_y=t.iloc[:120,-1]
test_X=t.iloc[120:,:-1]
test_y=t.iloc[120:,-1]
knn=KNN(k=3)
knn.fit(train_X,train_y)
result=knn.predict(test_X)
display(result)
np.mean(np.sum(result-test_y)**2)
display(test_y.values)
array([0.2 , 2.06666667, 0.2 , 1.93333333, 1.26666667,
1.2 , 1.23333333, 2. , 1.13333333, 1.93333333,
2.03333333, 1.83333333, 1.83333333, 0.2 , 1.16666667,
2.26666667, 1.63333333, 0.3 , 1.46666667, 1.26666667,
1.66666667, 1.33333333, 0.26666667, 0.23333333, 0.2 ,
2.03333333, 1.26666667, 2.2 , 0.23333333])
array([0.2, 1.6, 0.2, 2.3, 1.3, 1.2, 1.3, 1.8, 1. , 2.3, 2.3, 1.5, 1.7,
0.2, 1. , 2.1, 2.3, 0.2, 1.3, 1.3, 1.8, 1.3, 0.2, 0.4, 0.1, 1.8,
1. , 2.2, 0.2])
可视化
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams["font.family"]="SimHei"
mpl.rcParams["axes.unicode_minus"]=False
plt.figure(figsize=(8,8))
plt.plot(result,"ro-",label="预测值")
plt.plot(test_y.values,"go--",label="真实值")
plt.title("KNN 连续值预测展示")
plt.xlabel("节点序号")
plt.ylabel("花瓣宽度")
plt.legend()
plt.show()
