【漫话机器学习系列】038.KNN算法(Does K-NN Learn)

KNN算法:K-近邻算法(K-Nearest Neighbors Algorithm)

KNN(K-Nearest Neighbors)是一种简单且广泛使用的监督学习算法,常用于分类和回归问题。它基于“相似样本具有相似输出”的思想,通过计算样本点之间的距离来进行预测。


算法的基本思想

KNN算法的核心是:

  1. 定义距离:计算待预测样本与训练样本之间的距离。
  2. 选择邻居:选择距离最近的 K 个样本。
  3. 输出结果
    • 分类问题:选择 K 个邻居中出现次数最多的类别作为预测结果。
    • 回归问题:返回 K 个邻居的平均值或加权平均值作为预测结果。

KNN算法的步骤

  1. 数据准备
    • 收集并整理训练数据集。
  2. 距离计算
    • 对于待预测样本,计算其与训练集中每个样本的距离。
    • 常见距离公式:
      • 欧几里得距离: 

        d(x, y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2}
      • 曼哈顿距离: 

        d(x, y) = \sum_{i=1}^n |x_i - y_i|
      • 其他距离:如余弦相似度、闵可夫斯基距离等。
  3. 选择邻居
    • 选出距离最近的 K 个样本。
  4. 投票或加权
    • 分类:统计 K 个邻居中每个类别的样本数量,选择数量最多的类别。
    • 回归:计算 K 个邻居的平均值或加权平均值。
  5. 输出预测结果

K值的选择

  • 较小的 K
    • 更注重局部模式,容易导致模型对噪声敏感(过拟合)。
  • 较大的 K
    • 模型更平滑,鲁棒性更强,但可能忽略局部模式(欠拟合)。
  • 一般选择
    • K 值通常为奇数,以避免分类问题中的投票平局。
    • 使用交叉验证选择最优 K。

KNN的优缺点

优点
  1. 简单直观,易于实现。
  2. 不需要训练过程,适合多分类问题。
  3. 对异常数据不敏感。
缺点
  1. 计算复杂度高,特别是在大数据集上。
  2. 存储成本较高,需保存所有训练数据。
  3. 对维度敏感,容易受到“维度灾难”的影响。
  4. 对样本分布不均匀的数据效果较差。

应用场景

  1. 分类问题
    • 如手写数字识别、文本分类、图像分类等。
  2. 回归问题
    • 如预测房价、气温等连续变量。
  3. 推荐系统
    • 基于用户相似度的推荐算法。

Python实现

分类示例
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

# 预测
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"准确率: {accuracy}")

运行结果

准确率: 1.0

回归示例


from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

# 生成数据
X, y = make_regression(n_samples=100, n_features=1, noise=15, random_state=42)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建KNN回归器
knn_regressor = KNeighborsRegressor(n_neighbors=3)

# 训练模型
knn_regressor.fit(X_train, y_train)

# 预测
y_pred = knn_regressor.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差: {mse}")

运行结果

均方误差: 287.67115500435614


总结

KNN是一种基于实例的非参数学习算法,其核心思想是“就近原则”。虽然它简单直观,但计算复杂度高且对高维数据效果较差。在实际应用中,需根据数据分布、特征维度和问题特点合理调整 K 值,并结合其他技术(如降维和数据归一化)提高性能。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值