kNN(k近邻)算法是一种分类算法:给定带标签的训练集,计算新样本与训练集中每个样本的距离,取出距离最近的k个样本;这k个样本中多数属于哪一类,新样本就被判为哪一类。
1. 自定义类实现kNN算法
import numpy as np from sklearn.model_selection import train_test_split from sklearn import datasets
准备数据
# Load the iris dataset and hold out 10% of it as a test set
# (fixed random_state so the split is reproducible).
iris = datasets.load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=100
)
定义kNN类
class kNN:
    """Brute-force k-nearest-neighbour classifier for a single query sample.

    Parameters
    ----------
    k : int
        Number of nearest neighbours that vote; must be at least 1.
    X_train : array-like
        Training samples; each row is compared against ``x``.
    y_train : numpy.ndarray
        Training labels. Indexed with an integer array of neighbour
        positions, so it must support NumPy-style fancy indexing.
    x : array-like
        The new sample to classify.

    Raises
    ------
    ValueError
        If ``k`` is less than 1.
    """

    def __init__(self, k, X_train, y_train, x):
        # A negative k would silently slice off the farthest samples
        # ([:-1]) instead of selecting k neighbours; reject it up front.
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k!r}")
        self.k = k
        self.X = X_train
        self.y = y_train
        self.x = x

    def distance(self):
        """Return the Euclidean distance from ``self.x`` to every training sample.

        Returns a list in the same order as the training samples.
        """
        import numpy as np

        # Use np.sum explicitly instead of `from numpy import sum`,
        # which shadowed the builtin sum.
        return [np.sqrt(np.sum((self.x - x_train) ** 2)) for x_train in self.X]

    def sort(self):
        """Return the labels of the k nearest training samples, nearest first."""
        import numpy as np

        # kind="stable" keeps training order among equal distances, so the
        # chosen neighbours (and hence the vote) are deterministic on ties.
        nearest = np.argsort(self.distance(), kind="stable")[:self.k]
        return self.y[nearest]

    def knn_classify(self):
        """Return the majority label among the k nearest neighbours.

        Counter.most_common orders equal counts by first appearance. Labels
        arrive nearest-first, so a tied vote goes to the label whose member
        is closest to ``self.x``.
        """
        from collections import Counter

        # Compute the neighbours once (the original called sort() twice).
        votes = Counter(self.sort())
        return votes.most_common(1)[0][0]
实例化并对测试集逐个预测
# Classify each test sample with the hand-written kNN (k=6) and print
# the prediction next to the true label, tab-separated.
print('预测值\t实际值')
for x, y_ in zip(X_test, y_test):
    knn = kNN(6, X_train, y_train, x)
    yhat = knn.knn_classify()
    # The original line was missing this closing parenthesis (SyntaxError).
    print(f'{yhat}\t{y_}')
预测值	实际值
2	2
0	0
2	2
0	0
2	2
2	2
0	0
0	0
2	2
0	0
0	0
2	2
0	0
0	0
2	2
2.scikit-learn中的kNN算法
准备数据
from sklearn import datasets

# Load the iris dataset bundled with scikit-learn; split off features and labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target
划分训练集和测试集
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing; the fixed seed makes it reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=100
)
数据标准化(StandardScaler:减去均值,再除以标准差)
from sklearn.preprocessing import StandardScaler

# Learn the scaling statistics from the training split only, then apply
# the same transformation to the test split.
standardscaler = StandardScaler()
X_train_s = standardscaler.fit_transform(X_train)
X_test_s = standardscaler.transform(X_test)
预测
from sklearn.neighbors import KNeighborsClassifier

# 6-nearest-neighbour classifier trained on the standardized features;
# fit() returns the estimator itself, so construction and training chain.
kNN_classify = KNeighborsClassifier(n_neighbors=6).fit(X_train_s, y_train)
yhat = kNN_classify.predict(X_test_s)

print(f'预测值:{yhat}')
print(f'实际值:{y_test}')
# score() returns mean accuracy on the given test data and labels.
print(f'准确率:{kNN_classify.score(X_test_s, y_test)}')
预测值:[2 0 2 0 1 2 0 0 2 0 0 2 0 0 2 1 1 1 2 2 2 0 2 0 1 2 1 0 1 2]
实际值:[2 0 2 0 2 2 0 0 2 0 0 2 0 0 2 1 1 1 2 2 2 0 2 0 1 2 1 0 1 2]
准确率:0.9666666666666667