Characteristics of kNN:
- Simple. There is no training phase, which is why kNN is also called lazy learning. It is like an open-book exam: the answer is looked up directly in the existing data.
- Primal. Judging by similarity is how humans and other animals naturally recognize things; it is effectively wired into our genes. Of course, similarity can also mislead: people have mistaken a neighbor's Alocasia (a toxic ornamental that resembles taro) for taro and been poisoned after eating it.
- Effective. Never underestimate kNN; on many datasets it is hard to design an algorithm that beats it.
- Adaptable. It works for both classification and regression, and for many kinds of data.
- Extensible. Designing different distance measures can produce surprisingly good results (see the weighted-voting sketch after the run results below).
- The data usually needs to be normalized first (a min-max sketch follows this list).
- High complexity. This is kNN's most important drawback: for each test instance the cost is O((m + k)n), where n is the number of training instances, m the number of condition attributes, and k the number of neighbors. See computeNearests() in the code below.
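As a concrete illustration of the normalization point in the list above, here is a minimal min-max scaling sketch for a Weka Instances object. It is not part of the original program; the class and method names are hypothetical, and it assumes all condition attributes are numeric with the class label last.

import weka.core.Instances;

public class NormalizationSketch {
    // Min-max normalize every condition attribute in place, mapping it to [0, 1].
    public static void minMaxNormalize(Instances paraDataset) {
        for (int j = 0; j < paraDataset.numAttributes() - 1; j++) {
            // Find the minimum and maximum of attribute j.
            double tempMin = Double.MAX_VALUE;
            double tempMax = -Double.MAX_VALUE;
            for (int i = 0; i < paraDataset.numInstances(); i++) {
                double tempValue = paraDataset.instance(i).value(j);
                tempMin = Math.min(tempMin, tempValue);
                tempMax = Math.max(tempMax, tempValue);
            }
            double tempRange = tempMax - tempMin;
            if (tempRange == 0) {
                continue; // A constant attribute carries no information; leave it as is.
            }
            // Rescale every value of attribute j.
            for (int i = 0; i < paraDataset.numInstances(); i++) {
                double tempValue = paraDataset.instance(i).value(j);
                paraDataset.instance(i).setValue(j, (tempValue - tempMin) / tempRange);
            }
        }
    }
}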
Code:
package machinelearning.knn;
import weka.core.Instances;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
public class KnnClassification {
    // Manhattan distance: the sum of absolute coordinate differences, |x| + |y|.
    public static final int MANHATTAN = 0;
    // Euclidean distance.
    public static final int EUCLIDEAN = 1;
    // The distance measure in use.
    public int distanceMeasure = EUCLIDEAN;
    // A shared random number generator.
    public static final Random random = new Random();
    // The number of neighbors, i.e., k.
    int numNeighbors = 7;
    // The whole dataset.
    Instances dataset;
    // The training set, represented by instance indices.
    int[] trainingSet;
    // The testing set, represented by instance indices.
    int[] testingSet;
    // The predicted labels for the testing set.
    int[] predictions;
public KnnClassification(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
            // The last attribute is the class label.
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (Exception e) {
System.out.println("Error occurred while trying to read \'" + paraFilename
+ "\' in KnnClassification constructor.\r\n" + e);
System.exit(0);
}
}
    /**
     * Get a random permutation of indices for shuffling the data.
     *
     * @param paraLength the number of indices
     * @return an array of randomly ordered indices
     */
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
        // Step 1. Initialize with the identity permutation.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
}
        // Step 2. Randomly swap pairs. (A Fisher-Yates shuffle would give a uniform permutation; simple pair swaps are adequate for splitting.)
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
            // Generate two random positions.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
            // Swap the two elements.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
}
return resultIndices;
}
    /**
     * Split the data into a training set and a testing set.
     *
     * @param paraTrainingFraction the fraction of instances used for training
     */
public void splitTrainingTesting(double paraTrainingFraction) {
        int tempSize = dataset.numInstances(); // The number of instances in the dataset.
int[] tempIndices = getRandomIndices(tempSize);
int tempTrainingSize = (int) (tempSize * paraTrainingFraction);
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
}
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
}
}
    /**
     * Predict the whole testing set. The results are stored in predictions.
     */
public void predict() {
predictions = new int[testingSet.length];
for (int i = 0; i < predictions.length; i++) {
predictions[i] = predict(testingSet[i]);
}
}
    /**
     * Predict the class of a given instance.
     *
     * @param paraIndex the index of the instance to predict
     * @return the predicted class
     */
private int predict(int paraIndex) {
int[] tempNeighbors = computeNearests(paraIndex);
int resultPrediction = simpleVoting(tempNeighbors);
return resultPrediction;
}
    /**
     * The distance between two instances.
     *
     * @param paraI the index of the first instance
     * @param paraJ the index of the second instance
     * @return the distance
     */
public double distance(int paraI, int paraJ) {
double resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
}
}
break;
            case EUCLIDEAN:
                // The square root is omitted; the squared distance preserves the neighbor ordering.
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
resultDistance += tempDifference * tempDifference;
}
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}
return resultDistance;
}
    /**
     * Get the accuracy of the classifier on the testing set.
     *
     * @return the fraction of correctly predicted test instances
     */
public double getAccuracy() {
double tempCorrect = 0;
for (int i = 0; i < predictions.length; i++) {
if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
tempCorrect++;
}
}
return tempCorrect / testingSet.length;
}
    /**
     * Compute the k nearest neighbors of the given instance.
     *
     * @param paraCurrent the index of the current (test) instance
     * @return the indices of the nearest training instances
     */
private int[] computeNearests(int paraCurrent) {
int[] resultNearests = new int[numNeighbors];
boolean[] tempSelected = new boolean[trainingSet.length];
double tempMinimalDistance;
int tempMinimalIndex = 0;
double[] tempDistances = new double[trainingSet.length];
        // Compute the distance from the current instance to every training instance.
        for (int i = 0; i < trainingSet.length; i++) {
tempDistances[i] = distance(paraCurrent, trainingSet[i]);
}
        // Select the indices of the k nearest neighbors.
for (int i = 0; i < numNeighbors; i++) {
tempMinimalDistance = Double.MAX_VALUE;
for (int j = 0; j < trainingSet.length; j++) {
if (tempSelected[j]) {
continue;
}
if (tempDistances[j] < tempMinimalDistance) {
tempMinimalDistance = tempDistances[j];
tempMinimalIndex = j;
}
}
resultNearests[i] = trainingSet[tempMinimalIndex];
tempSelected[tempMinimalIndex] = true;
}
System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
return resultNearests;
}
    /**
     * Simple unweighted majority voting among the neighbors.
     *
     * @param paraNeighbors the indices of the neighbors
     * @return the class with the most votes
     */
private int simpleVoting(int[] paraNeighbors) {
int[] tempVotes = new int[dataset.numClasses()];
for (int i = 0; i < paraNeighbors.length; i++) {
tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
}
int tempMaximalVotingIndex = 0;
int tempMaximalVoting = 0;
for (int i = 0; i < dataset.numClasses(); i++) {
if (tempVotes[i] > tempMaximalVoting) {
tempMaximalVoting = tempVotes[i];
tempMaximalVotingIndex = i;
}
}
return tempMaximalVotingIndex;
}
    public void setDistanceMeasure(int paraType) {
        if (paraType == MANHATTAN || paraType == EUCLIDEAN) {
            distanceMeasure = paraType;
        } else {
            System.out.println("Unsupported distance measure: " + paraType);
        }
    }
    public void setNumNeighbors(int paraNumNeighbors) {
        if (paraNumNeighbors <= 0 || paraNumNeighbors > dataset.numInstances()) {
            System.out.println("The number of neighbors is out of range: " + paraNumNeighbors);
            return;
        }
        this.numNeighbors = paraNumNeighbors;
    }
    public static void main(String[] args) {
KnnClassification tempClassifier = new KnnClassification("D:\\研究生学习\\iris.arff");
tempClassifier.splitTrainingTesting(0.8);
tempClassifier.predict();
System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
}
}
Run results:
The nearest neighbors of 120 are: [143, 140, 124, 144, 112, 139, 102]
The nearest neighbors of 3 are: [29, 2, 45, 12, 38, 42, 34]
The nearest neighbors of 64 are: [82, 79, 88, 99, 59, 92, 89]
The nearest neighbors of 37 are: [34, 9, 1, 12, 29, 45, 2]
The nearest neighbors of 148 are: [136, 115, 147, 140, 137, 124, 144]
The nearest neighbors of 30 are: [29, 34, 9, 45, 12, 1, 11]
The nearest neighbors of 126 are: [123, 127, 138, 146, 83, 63, 72]
The nearest neighbors of 117 are: [131, 105, 109, 122, 125, 107, 118]
The nearest neighbors of 55 are: [66, 96, 94, 78, 95, 99, 84]
The nearest neighbors of 47 are: [2, 42, 6, 29, 38, 12, 45]
The nearest neighbors of 90 are: [94, 96, 89, 99, 67, 95, 92]
The nearest neighbors of 71 are: [97, 82, 92, 61, 99, 74, 67]
The nearest neighbors of 132 are: [128, 104, 103, 111, 112, 140, 147]
The nearest neighbors of 49 are: [7, 39, 0, 28, 17, 40, 34]
The nearest neighbors of 134 are: [103, 83, 111, 137, 119, 72, 108]
The nearest neighbors of 35 are: [1, 2, 40, 28, 34, 9, 7]
The nearest neighbors of 10 are: [48, 27, 36, 19, 5, 16, 20]
The nearest neighbors of 130 are: [107, 102, 125, 129, 105, 122, 108]
The nearest neighbors of 15 are: [33, 14, 5, 16, 32, 48, 19]
The nearest neighbors of 8 are: [38, 42, 13, 12, 45, 2, 29]
The nearest neighbors of 133 are: [83, 72, 123, 127, 63, 111, 77]
The nearest neighbors of 18 are: [5, 48, 20, 16, 31, 36, 33]
The nearest neighbors of 69 are: [80, 89, 81, 92, 82, 53, 67]
The nearest neighbors of 135 are: [105, 102, 107, 122, 125, 109, 118]
The nearest neighbors of 25 are: [34, 9, 1, 12, 45, 29, 7]
The nearest neighbors of 46 are: [19, 21, 48, 4, 27, 32, 44]
The nearest neighbors of 110 are: [147, 115, 77, 137, 141, 139, 127]
The nearest neighbors of 116 are: [137, 103, 147, 111, 128, 112, 104]
The nearest neighbors of 145 are: [141, 147, 139, 112, 115, 140, 128]
The nearest neighbors of 149 are: [127, 138, 142, 101, 70, 83, 121]
The accuracy of the classifier is: 0.9666666666666667
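The run above already reaches about 97% accuracy with plain majority voting. As one example of the extensibility mentioned in the feature list, simpleVoting() could be replaced by distance-weighted voting, where closer neighbors count more. The method below is a sketch of mine, not part of the original post; it assumes the caller also passes the neighbors' distances, which computeNearests() would have to return alongside the indices.

    // A hypothetical alternative to simpleVoting(): each neighbor votes with a
    // weight inversely proportional to its distance, so closer neighbors count more.
    private int weightedVoting(int[] paraNeighbors, double[] paraDistances) {
        double[] tempVotes = new double[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            // The small constant avoids division by zero when a distance is exactly 0.
            double tempWeight = 1.0 / (paraDistances[i] + 1e-6);
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()] += tempWeight;
        }
        // Return the class with the largest accumulated weight.
        int tempMaximalIndex = 0;
        for (int i = 1; i < tempVotes.length; i++) {
            if (tempVotes[i] > tempVotes[tempMaximalIndex]) {
                tempMaximalIndex = i;
            }
        }
        return tempMaximalIndex;
    }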
This post introduced the basic characteristics of the k-nearest-neighbors (kNN) algorithm, such as its simplicity, intuitive roots, strong performance, and adaptability, as well as its main drawback: high computational complexity. A complete code example showed how to apply and tune kNN in practice, including the choice of distance measure and the number of neighbors.
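To choose the number of neighbors in practice, one can simply compare accuracies over several random splits using the public methods defined above. This driver class is a sketch of mine, not part of the original post (note that the neighbor printout in computeNearests() makes its output verbose):

public class KnnTuningSketch {
    public static void main(String[] args) {
        // Adjust the path to your local copy of iris.arff.
        KnnClassification tempClassifier = new KnnClassification("D:\\研究生学习\\iris.arff");
        for (int tempK = 1; tempK <= 15; tempK += 2) {
            double tempSum = 0;
            // Average over ten random splits to reduce the variance of the estimate.
            for (int i = 0; i < 10; i++) {
                tempClassifier.splitTrainingTesting(0.8);
                tempClassifier.setNumNeighbors(tempK);
                tempClassifier.predict();
                tempSum += tempClassifier.getAccuracy();
            }
            System.out.println("k = " + tempK + ", average accuracy = " + tempSum / 10);
        }
    }
}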