KNN算法的实现

  • 相关知识点
  1. K近邻算法(KNN):是一种基于实例的监督学习算法。其基本思想是:对于每一个待分类的样本,找出其在训练集中的K个最近邻样本,然后根据这些邻居的类别进行投票,返回出现频率最高的类别作为预测结果。
  2. 欧几里得距离:用于衡量两点间的直线距离,常用于KNN算法中计算样本之间的相似度。计算公式为:

其中 xix_ixi​ 和 yiy_iyi​ 是样本的第i个特征值,n是特征的维度。

  1. 训练集与测试集划分:通常将数据集划分为训练集和测试集,训练集用于模型的训练,测试集用于评估模型的性能。常见的划分比例是 80% 用于训练,20% 用于测试。
  2. 准确率(Accuracy):分类模型评估的一个重要指标,表示正确分类的样本占总样本的比例:

  • 实验分析
  1. 数据加载:使用 pandas.read_csv 函数加载鸢尾花数据集,并将其分为特征(X)和标签(y)。特征包括花萼长度、花萼宽度、花瓣长度、花瓣宽度,标签为花的类别。
  2. 数据划分:将数据集划分为训练集和测试集。通过 train_test_split 将 80% 的数据作为训练集,20% 作为测试集。这是为了避免过拟合并验证模型在未知数据上的表现。
  3. KNN分类:对于每个测试集的样本,计算它与训练集样本之间的欧几里得距离。按照距离从小到大排序,选取最近的 K 个邻居。通过邻居的标签进行投票,返回出现次数最多的标签作为预测类别。
  4. 准确率计算:根据 KNN 的预测结果,统计预测正确的样本数,计算准确率。
  5. 单个样本预测:从测试集中选取一个样本,展示其特征,真实标签,并对其进行分类预测,显示预测结果。
  • 实验代码
  • import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split
    import operator
    
    # 1. 定义KNN算法分类函数
    def classify0(in_data, feature_group, label, k):
        """
        k近邻算法的分类器
        @param in_data: 测试数据点
        @param feature_group: 训练数据的特征集合
        @param label: 训练数据对应的类别标签
        @param k: 最近邻的数量
        @return: 返回预测类别
        """
        dim_len = feature_group.shape[0]  # 获取训练数据的数量
        # 计算测试数据与训练数据的欧几里得距离
        sub_data = np.tile(in_data, (dim_len, 1)) - feature_group  # 广播机制扩展测试点
        sq_sub_data = sub_data ** 2  # 每个特征值平方
        sum_sq_sub_data = sq_sub_data.sum(axis=1)  # 按行求和
        geom_distance = sum_sq_sub_data ** 0.5  # 开方得到欧几里得距离
    
        # 按距离从小到大排序,获取排序后的索引值
        index_sorted_gd = geom_distance.argsort()
    
        # 对k个最近的标签计数
        label_freq = {}
        for i in range(k):
            label_value = label[index_sorted_gd[i]]
            label_freq[label_value] = label_freq.get(label_value, 0) + 1
    
        # 对标签计数进行排序,并返回出现次数最多的标签
        sorted_class_label = sorted(label_freq.items(), key=operator.itemgetter(1), reverse=True)
        return sorted_class_label[0][0]
    
    
    # 2. 加载鸢尾花数据集
    def load_iris_data(file_path):
        """
        加载Iris数据集
        @param file_path: Iris数据集文件路径
        @return: 特征和标签
        """
        try:
            data = pd.read_csv(file_path, index_col=0)  # 加载数据集并移除索引列
            print("数据集加载成功!数据预览:")
            print(data.head())
            X = data.iloc[:, :-1].values  # 提取特征
            y = data.iloc[:, -1].values  # 提取标签
            return X, y
        except FileNotFoundError:
            print("错误:找不到文件,请检查路径是否正确。")
            exit()
    
    
    # 3. 评估KNN算法性能
    def evaluate_knn(X, y, k=3):
        """
        使用KNN算法对数据集进行分类,并计算准确率
        @param X: 数据集特征
        @param y: 数据集标签
        @param k: 最近邻的数量
        @return: 准确率
        """
        # 随机划分训练集和测试集(80%训练集,20%测试集)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
        correct = 0
        for i in range(len(X_test)):
            prediction = classify0(X_test[i], X_train, y_train, k)
            if prediction == y_test[i]:
                correct += 1
    
        accuracy = correct / len(y_test)
        return accuracy
    
    
    # 4. 主函数:运行实验
    if __name__ == "__main__":
        # 加载数据
        file_path = "iris.csv"  # 修改为实际文件路径
        X, y = load_iris_data(file_path)
    
        # 设置K值
        k = 3
    
        # 评估分类准确率
        accuracy = evaluate_knn(X, y, k)
        print(f"\nKNN分类算法在测试集上的准确率为: {accuracy:.2f}")
    
        # 示例:预测单个测试样本
        print("\n预测一个测试样本的类别")
        example_index = 0
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        prediction = classify0(X_test[example_index], X_train, y_train, k)
        print(f"测试样本特征: {X_test[example_index]}")
        print(f"真实类别: {y_test[example_index]}")
        print(f"预测类别: {prediction}")
    
    

  • 运行截图
  • 实验总结
  • 模型效果:KNN算法能够在鸢尾花数据集上有效地进行分类。通过合适的K值,模型可以达到较高的准确率。在本次实验中,选择的K值为3,通常K值较小会增加模型的灵活性,但也容易过拟合,较大的K值则可能导致欠拟合。

    计算开销:KNN是一个懒惰学习算法,即没有明确的训练过程,所有的计算都发生在预测阶段。在处理大数据集时,KNN可能会非常慢,因为每个预测都需要计算与所有训练样本的距离。

    优化建议:为了提高KNN算法的性能,可以考虑使用更高效的距离计算方法,或者通过降维(如PCA)减少特征空间的维度,减少计算开销。此外,选择合适的K值至关重要,应根据具体情况调优。
  • 出现的问题:最初编写代码与试运行时结果一直运行不了,提醒是关于数据处理的问题
  • 解决方案:仔细检查代码逻辑后,发现是因为没有数据处理的步骤来处理缺失值,给代码加上相关代码逻辑处理后能够正常运行了。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值