1.KNN算法
1.KNN算法1.KNN算法
KNN法最初由Cover 和Hart 于1968 年提出, 是一个理论上比较成熟的方法。KNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的K(K=1,2,3…,n,其中,n<=D)个样本的类别来决定待分样本所属的类别。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
2.步骤
KNN算法是比较简单的,理解起来也不难,具体步骤如下图:
上图中D表示训练集,其中P表示特征值Ti所描述的的对象(或者类型);Xik表示Xi的第k个特征值(Wk类似);GroupBy()对p进行分类统计(小编借用SQL语言里面的关键字,哈哈)。哦!对了,k的取值是有学问,具体如何取,取多少,就要你针对你研究的问题多试验几次,或者你是个老司机,早已摸清待研究问题的套路。大概思路就是上面这5步,大家可以一步两步,一步两步,是魔鬼的步伐浪起来~~~~~
3.java实现
根据上述的五个步骤,分步实现。在没开始前,首先要对研究对象进行抽象,具体如下:
public class KNNnode implements Comparable<KNNnode>{
/**
* 实现comparable接口重写compareTo()方法
* 目的:方便存放KNNnode对象的List进行排序,排序的目标属性为l(即与待测点距离)
*/
float x1,x2; //特征值
String type; //特征值对应的类型
double l ; //与待预测点的距离
public float getX1() {
return x1;
}
public void setX1(float x1) {
this.x1 = x1;
}
public float getX2() {
return x2;
}
public void setX2(float x2) {
this.x2 = x2;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public double getL() {
return l;
}
public void setL(double l) {
this.l = l;
}
@Override
public String toString() {
return "KNNnode [x1=" + x1 + ", x2=" + x2 + ", type=" + type + ", l="
+ l + "]";
}
//从大到小排列
@Override
public int compareTo(KNNnode o) {
// TODO Auto-generated method stub
if(l>o.getL())
return -1;
if(l<o.getL())
return 1;
return 0;
}
}
上面抽象的类,描述的对象有两个特征值,即X
1,
X
2
,type是特征值对应的类型。现在就可以用KNNnode来描述或者用历史数据进行生成对象了。基础已经打好了,直接进入五步法了:
①加载样本数据进行训练
/**
* 从txt文本中读取KNNnode所需数据并存放在List中
* @param url txt文本存放路径
* @return
*/
public List<KNNnode> ReadKNNnodeFromFile(String url){
List<KNNnode> node = new ArrayList<KNNnode>();
String st = "";
File file = new File(url);
if(file.isFile()){
try{
BufferedReader reader = new BufferedReader(new FileReader(file));
while((st=reader.readLine())!=null){
//用空格对字符串进行分割,s+可以匹配多个空格
String val[] = st.split("\\s+");
if(val.length==3){
KNNnode knNnode = new KNNnode();
knNnode.setX1(Float.parseFloat(val[0]));
knNnode.setX2(Float.parseFloat(val[1]));
knNnode.setType(val[2]);
node.add(knNnode);
}
}
reader.close();
}catch(IOException e){
e.printStackTrace();
}catch (Exception e) {
// TODO: handle exception
e.printStackTrace();
}
}else{
System.out.println("文件不存在!");
}
return node;
}
txt文本存放的数据是以每个对象里面的属性为一行,如下图:
②确定K以及待分类(确定/预测)的对象。这个比较简单。
final int K = 3;
KNNnode kn1 = new KNNnode();
kn1.setX1(22);
kn1.setX2(17); ③计算距离(相似度)
/**
* 按照欧式距离公式,计算历史数据与待预测对象之间的距离
* @param node l;训练集样本对象
* @param kn1 待预测
* @return
*/
public List<KNNnode> calcul(List<KNNnode> node,KNNnode kn1){
for(int i=0;i<node.size();i++){
KNNnode kn2 = node.get(i);
kn2.setL(Math.sqrt(Math.pow(kn1.getX1()-kn2.getX1(), 2)+Math.pow(kn1.getX2()-kn2.getX2(), 2)));
}
return node;
}
在抽象对象时,有一个属性l是专门为了存放训练样本与带预测对象之间的距离或者相似度。
④返回距离(相似度)最近(高)的K个对象
/**
* 对k个KNNnode的类型进行分类统计
* 使用Map,借助map键值对的存储方式,所以非常方便
* @param node
* @param k
* @return
*/
public Map<String,List<KNNnode>> result(List<KNNnode> node,int k){
Map<String,List<KNNnode>> knnmap = new HashMap<String,List<KNNnode>>();
System.out.println("---------------------K个最小的KNNnode对象-------------------");
for(int i=0;i<node.size();i++){
System.out.println(node.get(i).toString());
}
for(int i=0;i<k;i++){
String type = node.get(i).getType().trim();
if(knnmap.containsKey(type)){
knnmap.get(type).add(node.get(i));
Collections.sort(knnmap.get(type));
}else{
knnmap.put(type, new ArrayList<KNNnode>());
knnmap.get(type).add(node.get(i));
}
}
return knnmap;
}
其次,待预测的type值就等于频数高的type。
public static void main(String args[]) {
final String file_url = "D:\\knn.txt";
final int K = 5;
String type = "";
List<KNNnode> node = new ArrayList<KNNnode>();
// 待归类的node
KNNnode kn1 = new KNNnode();
kn1.setX1(33);
kn1.setX2(12);
KNNprocess knNprocess = new KNNprocess();
node = knNprocess.ReadKNNnodeFromFile(file_url);
node = knNprocess.calcul(node, kn1);
node = knNprocess.getnodeDESC(node, K);
double l = node.get(0).getL();
int s = 0;
Map<String, List<KNNnode>> knn = knNprocess.result(node, K);
for(Map.Entry<String, List<KNNnode>> en:knn.entrySet()){
int s1= en.getValue().size();
if(s1>s){
l = ((KNNnode)(en.getValue().get(s1-1))).getL();
s = s1;
type = en.getKey();
}else
if(s1==s&l>((KNNnode)(en.getValue().get(s1-1))).getL()){
l = ((KNNnode)(en.getValue().get(s1-1))).getL();
s = s1;
type = en.getKey();
}
}
System.out.println("---------------------------预测结果-------------------------");
kn1.setType(type);
System.out.println(kn1.toString());
}
综合上述5步,已经实现了待预测对象的归类了问题了。来看下我跑程序的结果:
OK!!KNN算法使用java已经实现。
4.KNN可以解决的问题
当然,你看该算法肯定是要用的,怎么用你也知道。那先来看下学术界用KNN算法都解决什么问题。
上图是我从知网上搜索截的图(冰山一角),KNN算法应用十分的广泛,而且还对KNN进行了改进,使用了权重进行算法优化,或者和其他算法进行了组合实用,是目前比较简单而且应用比较广的有监督的机器学习方法。
demo下载:KNN算法java实现demo

本文详细介绍KNN算法原理及其Java实现过程,包括加载样本数据、确定K值及预测对象、计算距离、返回最近K个对象及预测结果。同时,探讨了KNN算法的应用范围。
1028

被折叠的 条评论
为什么被折叠?



