【Java实现KNN算法】KNN(k邻近)详解与java实现

本文详细介绍KNN算法原理及其Java实现过程,包括加载样本数据、确定K值及预测对象、计算距离、返回最近K个对象及预测结果。同时,探讨了KNN算法的应用范围。

            1.KNN算法

       1.KNN算法

        KNN法最初由Cover 和Hart 于1968 年提出, 是一个理论上比较成熟的方法。KNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的K(K=1,2,3…,n,其中,n<=D)个样本的类别来决定待分样本所属的类别。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合

     2.步骤

      KNN算法是比较简单的,理解起来也不难,具体步骤如下图:

      上图中D表示训练集,其中P表示特征值Ti所描述的的对象(或者类型);Xik表示Xi的第k个特征值(Wk类似);GroupBy()对p进行分类统计(小编借用SQL语言里面的关键字,哈哈)。哦!对了,k的取值是有学问,具体如何取,取多少,就要你针对你研究的问题多试验几次,或者你是个老司机,早已摸清待研究问题的套路。大概思路就是上面这5步,大家可以一步两步,一步两步,是魔鬼的步伐浪起来~~~~~

        3.java实现

       根据上述的五个步骤,分步实现。在没开始前,首先要对研究对象进行抽象,具体如下:
public class KNNnode implements Comparable<KNNnode>{

	/**
	 * 实现comparable接口重写compareTo()方法
	 * 目的:方便存放KNNnode对象的List进行排序,排序的目标属性为l(即与待测点距离)
	 */
	float x1,x2;  //特征值
	String type;  //特征值对应的类型
	double l ;     //与待预测点的距离
	
	public float getX1() {
		return x1;
	}

	public void setX1(float x1) {
		this.x1 = x1;
	}

	public float getX2() {
		return x2;
	}


	public void setX2(float x2) {
		this.x2 = x2;
	}


	public String getType() {
		return type;
	}


	public void setType(String type) {
		this.type = type;
	}


	public double getL() {
		return l;
	}


	public void setL(double l) {
		this.l = l;
	}

	@Override
	public String toString() {
		return "KNNnode [x1=" + x1 + ", x2=" + x2 + ", type=" + type + ", l="
				+ l + "]";
	}
	//从大到小排列
	@Override
	public int compareTo(KNNnode o) {
		// TODO Auto-generated method stub
		if(l>o.getL())
			return -1;
		if(l<o.getL())
			return 1;
		return 0;
	}
}
        上面抽象的类,描述的对象有两个特征值,即X 1, X 2 ,type是特征值对应的类型。现在就可以用KNNnode来描述或者用历史数据进行生成对象了。基础已经打好了,直接进入五步法了:
          ①加载样本数据进行训练
         /**
	 * 从txt文本中读取KNNnode所需数据并存放在List中
	 * @param url  txt文本存放路径
	 * @return  
	 */
	public List<KNNnode> ReadKNNnodeFromFile(String url){
		List<KNNnode> node = new ArrayList<KNNnode>();
		String st = "";
		File file = new File(url);
		if(file.isFile()){
			try{
				BufferedReader reader = new BufferedReader(new FileReader(file));
				while((st=reader.readLine())!=null){
					//用空格对字符串进行分割,s+可以匹配多个空格
					String val[] = st.split("\\s+");
					if(val.length==3){
						KNNnode knNnode = new KNNnode();
						knNnode.setX1(Float.parseFloat(val[0]));
						knNnode.setX2(Float.parseFloat(val[1]));
						knNnode.setType(val[2]);
						node.add(knNnode);
					}
				}
				reader.close();
			}catch(IOException e){
				e.printStackTrace();
			}catch (Exception e) {
				// TODO: handle exception
				e.printStackTrace();
			}
		}else{
			System.out.println("文件不存在!");
		}
		return node;
	}
        txt文本存放的数据是以每个对象里面的属性为一行,如下图:

        使用IO流从文本中获取数据来生成对象非常方便,只需要按照上面这种模式,一个循环就可以依次读出数据生成KNNnode类并存放在List里面。
        ②确定K以及待分类(确定/预测)的对象。这个比较简单。
                final int K = 3;
                KNNnode kn1 = new KNNnode();
		kn1.setX1(22);
		kn1.setX2(17);
        ③计算距离(相似度)
	/**
	 * 按照欧式距离公式,计算历史数据与待预测对象之间的距离
	 * @param node l;训练集样本对象
	 * @param kn1    待预测
	 * @return
	 */
	public List<KNNnode> calcul(List<KNNnode> node,KNNnode kn1){
		for(int i=0;i<node.size();i++){
			KNNnode kn2 = node.get(i);
			kn2.setL(Math.sqrt(Math.pow(kn1.getX1()-kn2.getX1(), 2)+Math.pow(kn1.getX2()-kn2.getX2(), 2)));
		}
		return node;
	}
         在抽象对象时,有一个属性l是专门为了存放训练样本与带预测对象之间的距离或者相似度。
         ④返回距离(相似度)最近(高)的K个对象
/**
	 * 对k个KNNnode的类型进行分类统计
	 * 使用Map,借助map键值对的存储方式,所以非常方便
	 * @param node
	 * @param k
	 * @return
	 */
	public Map<String,List<KNNnode>> result(List<KNNnode> node,int k){
		Map<String,List<KNNnode>> knnmap = new HashMap<String,List<KNNnode>>();
		System.out.println("---------------------K个最小的KNNnode对象-------------------");
		for(int i=0;i<node.size();i++){
			System.out.println(node.get(i).toString());
		}

		for(int i=0;i<k;i++){
			String type = node.get(i).getType().trim();
			if(knnmap.containsKey(type)){
				knnmap.get(type).add(node.get(i));
				Collections.sort(knnmap.get(type));
			}else{
				knnmap.put(type, new ArrayList<KNNnode>());
				knnmap.get(type).add(node.get(i));
			}
		}
		
		return knnmap;
	}
          其次,待预测的type值就等于频数高的type。
	public static void main(String args[]) {
		final String file_url = "D:\\knn.txt";
		final int K = 5;
		String type = "";
		List<KNNnode> node = new ArrayList<KNNnode>();
		// 待归类的node
		KNNnode kn1 = new KNNnode();
		kn1.setX1(33);
		kn1.setX2(12);
		KNNprocess knNprocess = new KNNprocess();
		node = knNprocess.ReadKNNnodeFromFile(file_url);
		node = knNprocess.calcul(node, kn1);
		node = knNprocess.getnodeDESC(node, K);
		double l = node.get(0).getL();
		int s = 0;
		Map<String, List<KNNnode>> knn = knNprocess.result(node, K);
		for(Map.Entry<String, List<KNNnode>> en:knn.entrySet()){
			int s1= en.getValue().size();
			if(s1>s){
				l = ((KNNnode)(en.getValue().get(s1-1))).getL();
				s = s1;
				type = en.getKey();
			}else
			if(s1==s&l>((KNNnode)(en.getValue().get(s1-1))).getL()){
				l = ((KNNnode)(en.getValue().get(s1-1))).getL();
				s = s1;
				type = en.getKey();
			}
		}
		System.out.println("---------------------------预测结果-------------------------");
		kn1.setType(type);
		System.out.println(kn1.toString());
	}
         综合上述5步,已经实现了待预测对象的归类了问题了。来看下我跑程序的结果:


          OK!!KNN算法使用java已经实现。

         4.KNN可以解决的问题

        当然,你看该算法肯定是要用的,怎么用你也知道。那先来看下学术界用KNN算法都解决什么问题。



         上图是我从知网上搜索截的图(冰山一角),KNN算法应用十分的广泛,而且还对KNN进行了改进,使用了权重进行算法优化,或者和其他算法进行了组合实用,是目前比较简单而且应用比较广的有监督的机器学习方法。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值