K-Nearest Neighbor算法实现java

文章介绍了K-NearestNeighbor(KNN)算法的实现过程,包括数据集的构建、分类流程、距离度量和分类器评估。通过实验,发现当p值取0.5,k值取10时,KNN的分类效果最佳。代码示例展示了数据读取和距离计算的实现,并提供了测试集的正确率计算方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

K-Nearest Neighbor算法实现

要求

在这里插入图片描述

数据

数据集中一共12000多条数据,每个数据包含16个特征,1个标签(该条数据对应的种子类别),一共有7类种子。
每个特征都为定距数据,即:取值范围为连续取值的数值数据。
部分特征是通过其他特征计算出来。
各类种子的个数如下:
Seker(2027), Barbunya(1322), Bombay(522), Cali(1630), Dermosan(3546), Horoz(1928) ,Sira(2636)。

本实验中将其按37分拆分为测试集:4086条, 训练集:9525条。

KNN模型介绍

KNN算法是一种模式识别方法,根据对象进行分类。一个样本与数据集中的k个样本最相似, 如果这k个样本中的大多数属于某一个类别, 则该样本也属于这个类别。也就是说,该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法。
该方法的不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最邻近点。

KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

算法流程:

总体来说,KNN分类算法包括以下4个步骤:

  1. 准备数据,对数据进行预处理。
  2. 计算测试样本点(也就是待分类点)到其他每个样本点的距离
  3. 对每个距离进行排序,然后选择出与当前点距离最小的K个点。
  4. 对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类。

距离度量公式如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ufPJdbkN-1674289743074)(C:\Users\He\AppData\Roaming\Typora\typora-user-images\image-20221217041659272.png)]

算法框图如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tFRZnxV8-1674289743075)(C:\Users\He\AppData\Roaming\Typora\typora-user-images\image-20221217042234278.png)]

代码简介:

public static Dataset BuildData(String pathname, int sort){
        // 创建数据集
        Dataset data = new DefaultDataset();
        File file = new File(pathname);
        try{
            BufferedReader textFile = new BufferedReader(new FileReader(file));
            String lineDta = "";

            while ((lineDta = textFile.readLine()) != null){
                double[] values = new double[16];
                int i = 0;
                StringTokenizer st = new StringTokenizer(lineDta, ",");
                while (st.hasMoreElements()) {
                    values[i] = Double.parseDouble((String) st.nextElement());
                    i++;
                    if(i == 16){
                        int sortE = Integer.parseInt((String) st.nextElement());
                        if( sortE == sort){
                            Instance instance = new DenseInstance(values, sort);
                            data.add(instance);
                        }

                    }
                }
            }
        }catch (FileNotFoundException e){
            System.out.println("没有找到指定文件");
        }catch (IOException e){
            System.out.println("文件读写出错");
        }
        return data;
    }

该方法从文件中根据各分类建立数据集, 方便后续从各分类随机37分组织训练集和测试集.

public static void Acc(Dataset data, Classifier knn)
    {
        int correct = 0, wrong = 0;

        for (Instance inst : data) {
            Object predictedClassValue = knn.classify(inst);
            Object realClassValue = inst.classValue();
            if (predictedClassValue.equals(realClassValue))
                correct++;
            else
                wrong++;
        }
        double acc = correct * 1.0 / (correct + wrong);
        System.out.println("正确率为: " + acc);
    }

该方法读取测试集及训练好的knn分类器, 并输出正确率.

public double p;
    @Override
    public double measure(Instance x, Instance y) {
        Instance temp = x.add(y.multiply(-1));
        temp = temp.multiply(temp);
        temp = temp.sqrt();
        if(this.p == 0.5){
            temp = temp.sqrt();
            double sum = 0;
            for(double i:temp){
                sum += i;
            }
            return sum * sum;
        } else if (this.p == 1) {
            double sum = 0;
            for(double i:temp){
                sum += i;
            }
            return sum;
        } else{
            temp = temp.multiply(temp);
            double sum = 0;
            for(double i:temp){
                sum += i;
            }
            return Math.sqrt(sum);
        }
    }

该方法重写了距离测量方法, 根据实验要求分别对q = 0.5 q = 1 和 q = 2 做了实现.

结论:

本次实验的结果如下:
在这里插入图片描述

将其绘制为带折线的散点图如下:

[外链图片转存中...(img-mhANHQIQ-1674289743076)]

由图可以看到, p值取0.5时, 正确率均高于取1和2时, 且在k值取10时, 正确率高于取其他值, 由此可以初步确认在p值取0.5, k值取10左右时, KNN的分类正确率最高。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值