K临近(KNN)算法是一种原理比较简单的机器学习算法,其原理是将待分类数据与所有样本数据计算距离,根据距离由近到远选取K个临近点,根据临近点占比和距离权重对待分类点进行分类。
由于需要做距离计算,样本数据每个特征必须为数值型数据。加入我们需要对不同鸟进行分类,从翼展、身高、体重三个方面对老鹰、鸽子、麻雀三种鸟进行分类计算。下面给出一组假设的样本数据:
分类 |
翼展 |
体重 |
身高 |
老鹰 |
2米 |
5.0kg |
1.0米 |
鸽子 |
0.5米 |
0.5kg |
0.3米 |
麻雀 |
0.2米 |
0.05kg |
0.1米 |
从数据中可以看出,由于不同特征的值具跨度范围不一致,如果直接进行计算,容易造成权重失衡,为了消除权重失衡需要对每个特征内部进行归一化,即特征内每个值除以其中的最大值。那么归一化后老鹰(1.0,1.0,1.0),鸽子(0.25,0.1,0.3),麻雀(0.1,0.01,0.1)。我们可以将这三个特征数据想象为一个个三维空间中的点,那么待分类对象就是计算一个三维坐标距离样本点的距离。假设一个待分类数据(x,y,z),采用KNN算法进行分类,通过欧式距离可以计算出它离某个样本点(x1,y1,z1)的距离。
计算公式:距离=sqrt((x - x1)^2 + (y - y1)^2 + (z - z1)^2)。
实际实现为了降低计算消耗可以忽略开方运算,只做平方计算,消除值为负数的差值即可。
实现代码:
distance = Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(sample[j + 1]), 2);
从原理和实现上不难看出,KNN算法没有训练过程,拿到样本数据后就可以直接使用,虽然计算简单,由于需要对每个样本进行距离计算,当样本数量过大后,将会消耗极大的计算时间和内存空间。针对这种问题,可以采用先取出距离较近的一些点,再进行距离计算。即根据待分类数据(x,y,z),我们增加一个参数,查找半径,当样本数据中超过K个数据处于半径范围内,则停止查找。
实现代码:
private List<String[]> findNearestNeibor(List<String[]> modelList, String[] testData, double radius, int k) {
List<String[]> result = new ArrayList<String[]>();
double step = radius;
while(true) {
for(int i = 0; i < modelList.size(); i++) {
String[] modelSample = modelList.get(i);
List<Boolean> tempResult = new ArrayList<Boolean>();
for(int j = 0; j < testData.length; j++) {
double sampleMin = Double.parseDouble(testData[j]) - step;
double sampleMax = Double.parseDouble(testData[j]) + step;
double modelSampleIndex = Double.parseDouble(modelSample[j + 1]);
if (modelSampleIndex >= sampleMin && modelSampleIndex <= sampleMax) {
tempResult.add(true);
}else {
tempResult.add(false);
}
}
if (!tempResult.contains(false)) {
result.add(modelSample);
}
}
if (result.size() >= k) {
return result;
}else {
step += radius;
}
}
当查找到大于K个值后,再进行距离计算,找出最近的K个值并给出结果。假设K=1时,即取离待分类点最近的样本点作为分类结果。
实现代码:
private String getResultTag(List<String[]> nearestList, String[] testData) {
String result = new String();
double min = testData.length;
for(int i = 0; i < nearestList.size(); i++) {
String[] nearSample = nearestList.get(i);
double distance = 0.0;
for(int j = 1; j < testData.length; j++) {
distance += Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(nearSample[j]), 2);
}
if (distance < min) {
result = nearSample[0];
min = distance;
}
}
return result;
}
接下来,进行算法测试,随机生成一个包含10000个样本三种分类的文本文件,分类A的特征一在0.9左右,特征二0.5左右,特征三0.3左右;分类B的特征一在0.3左右,特征二0.6左右,特征三0.9左右;分类C的特征一在0.6左右,特征二0.9左右,特征三0.3左右;
如图:
同样,为了提高计算速度,默认K为1情况下,采用一边读取一边计算距离,当完成整个样本文件读取后,即完成计算。
实现代码:
public String predict(File model, String[] testData) {
String result = new String();
double min = testData.length;
try {
BufferedReader reader = new BufferedReader(new FileReader(model));
String line;
while ((line = reader.readLine()) != null) {
String[] sample = line.split(",");
double distance = 0.0;
for(int j = 0; j < testData.length; j++) {
distance += Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(sample[j + 1]), 2);
}
if (distance < min) {
result = sample[0];
min = distance;
}
}
reader.close();
}catch (Exception e) {
e.printStackTrace();
}
return result;
}
测试代码及测试结果:
public static void main(String[] args) throws Exception{
KNN knn = new KNN();
String[] testData = new String[] {"0.32","0.65","0.83"};
long time1 = System.currentTimeMillis();
String result = knn.predict(new File("C:/Users/admin/Desktop/test/sample.csv"), testData);
long time2 = System.currentTimeMillis();
System.out.println("计算用时:" + (time2 - time1) + "毫秒");
System.out.println(result);
}