Java实现KNN

该文详细介绍了如何使用Java实现KNN(K近邻)算法,包括数据集的读取、训练集与测试集的划分、距离计算(曼哈顿距离和欧氏距离)、获取k个最近邻记录、单条记录预测以及模型的预测准确率计算。代码中包含了实例数据集的读取、随机打乱数据集、预测过程以及评估预测准确性的方法。

算法介绍

  1. 在特征空间中统计k个距离最近的样本的标签,选择最多的标签最为自己的标签。
  2. 可以采用多种距离计算策略,如曼哈顿距离、欧氏距离。
  3. KNN是非参且惰性的。
  4. 优点:实现简单、训练快(惰性)、效果好、对异常值不敏感
  5. 缺点:时空复杂度都高、需要合适的归一化等

算法流程

变量准备

在这里准备了两种距离策略、训练集、测试集、验证集和k值。

	//两种距离
    public static final int MANHATTAN = 0;
    public static final int EUCLIDEAN = 1;
    //距离策略
    public int distanceMeasure = EUCLIDEAN;
    //设置随机种子
    public static final Random random = new Random();
    //设置k值
    int numNeighbors = 7;
    //数据集
    Instances dataset;
    //训练集、测试集、验证集数组
    int[] trainingSet;
    int[] testingSet;
    int[] predictions;

用Instances类读取数据集

使用构造方法,读取数据集并存入Instances类中。

    /**
     * 构造方法,用Instances类读取数据集
     * @param paraFilename 数据集地址
     */
    public Knn(String paraFilename) {
   
   
        try {
   
   
            FileReader fileReader = new FileReader(paraFilename);
            //这里用Instances读取数据集
            dataset = new Instances(fileReader);
            //将当前Istances类的标签下标设置为(数据集的属性数-1)?
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (FileNotFoundException e) {
   
   
            e.printStackTrace();
        } catch (IOException e) {
   
   
            e.printStackTrace();
        }
    }

打乱数据集并划分训练集与测试集

在这里将原本的数据集随机交换打乱,再按照比例将数据集划分为训练集与测试集。

    /**
     * 获得一个随机序列
     * @param paraLength 序列长度
     * @return 打乱后的序列
     */
    public static int[] getRandomIndices(int paraLength) {
   
   
        int[] resultIndices = new int[paraLength];
        //按序号赋值
        for(int i = 0; i < paraLength; i++) {
   
   
            resultIndices[i] = i;
        }
        //随机交换打乱
        int tempFirst, tempSecond, tempValue;
        for(int i = 0; i < paraLength; i++) {
   
   
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempFirst] = tempValue;
        }
        return resultIndices;
    }
    /**
     * 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
     * @param paraTrainingFraction 训练集比例
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
   
   
        //获得数据集记录数
        int tempSize = dataset.numInstances();
        //得到长度为tempSize的随机序列
        int[] tempIndices = getRandomIndices(tempSize);
        //获得训练集记录数
        int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
        //声明训练集与测试集
        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];
        //为训练集与测试集赋值
        for (int i = 0; i < tempTrainingSize; i++) {
   
   
            trainingSet[i] = tempIndices[i];
        }
        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
   
   
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }
    }

距离计算

这里规定了两种距离,并实现了两种距离的计算。

    /**
     * 计算两条记录之间对应的距离
     * @param paraI 一个记录下标
     * @param paraJ 另一条记录下标
     * @return 距离结果
     */
    public double distance(int paraI, int paraJ) {
   
   
        double resultDistance = 0;
        double tempDifference;
        //两种距离策略分开计算
        switch (distanceMeasure) {
   
   
            case MANHATTAN:
                //对除去标签以外的所有属性进行处理
                for (int i = 0; i < dataset.numAttributes() - 1; i
KNN算法的思想是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,该测试数据对应的类别就是K个数据中出现次数最多的那个分类[^2]。 使用Java实现KNN算法的大体思路如下: 1. 输入所有已知点。 2. 输入未知点。 3. 计算所有已知点到未知点的欧式距离。 4. 根据距离对所有已知点排序。 5. 选出距离未知点最近的k个点。 6. 计算k个点所在分类出现的频率。 7. 选择频率最大的类别即为未知点的类别[^3]。 以下是一个简单的Java代码示例来实现KNN算法: ```java import java.util.*; // 数据类 class Data { double[] features; String label; public Data(double[] features, String label) { this.features = features; this.label = label; } } // 算法类 class KNN { private List<Data> trainingData; public KNN(List<Data> trainingData) { this.trainingData = trainingData; } // 计算欧式距离 private double euclideanDistance(double[] features1, double[] features2) { double sum = 0; for (int i = 0; i < features1.length; i++) { sum += Math.pow(features1[i] - features2[i], 2); } return Math.sqrt(sum); } // 预测类别 public String predict(double[] testFeatures, int k) { // 存储距离和对应的标签 List<Map.Entry<Double, String>> distances = new ArrayList<>(); // 计算所有已知点到未知点的距离 for (Data data : trainingData) { double distance = euclideanDistance(data.features, testFeatures); distances.add(new AbstractMap.SimpleEntry<>(distance, data.label)); } // 根据距离排序 distances.sort(Map.Entry.comparingByKey()); // 选出距离最近的k个点 Map<String, Integer> labelCount = new HashMap<>(); for (int i = 0; i < k; i++) { String label = distances.get(i).getValue(); labelCount.put(label, labelCount.getOrDefault(label, 0) + 1); } // 选择频率最大的类别 String predictedLabel = null; int maxCount = 0; for (Map.Entry<String, Integer> entry : labelCount.entrySet()) { if (entry.getValue() > maxCount) { maxCount = entry.getValue(); predictedLabel = entry.getKey(); } } return predictedLabel; } } // 测试类 class KNNTest { public static void main(String[] args) { // 训练数据 List<Data> trainingData = new ArrayList<>(); trainingData.add(new Data(new double[]{1, 2}, "A")); trainingData.add(new Data(new double[]{2, 3}, "A")); trainingData.add(new Data(new double[]{8, 7}, "B")); trainingData.add(new Data(new double[]{7, 8}, "B")); // 创建KNN实例 KNN knn = new KNN(trainingData); // 测试数据 double[] testFeatures = new double[]{3, 4}; int k = 3; // 预测类别 String predictedLabel = knn.predict(testFeatures, k); System.out.println("预测类别: " + predictedLabel); } } ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值