K-Nearest Neighbor算法实现
要求
数据
数据集中一共12000多条数据,每个数据包含16个特征,1个标签(该条数据对应的种子类别),一共有7类种子。
每个特征都为定距数据,即:取值范围为连续取值的数值数据。
部分特征是通过其他特征计算出来。
各类种子的个数如下:
Seker(2027), Barbunya(1322), Bombay(522), Cali(1630), Dermosan(3546), Horoz(1928) ,Sira(2636)。
本实验中将其按37分拆分为测试集:4086条, 训练集:9525条。
KNN模型介绍
KNN算法是一种模式识别方法,根据对象进行分类。一个样本与数据集中的k个样本最相似, 如果这k个样本中的大多数属于某一个类别, 则该样本也属于这个类别。也就是说,该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法。
该方法的不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最邻近点。
KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
算法流程:
总体来说,KNN分类算法包括以下4个步骤:
- 准备数据,对数据进行预处理。
- 计算测试样本点(也就是待分类点)到其他每个样本点的距离
- 对每个距离进行排序,然后选择出与当前点距离最小的K个点。
- 对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类。
距离度量公式如下:
算法框图如下:
代码简介:
public static Dataset BuildData(String pathname, int sort){
// 创建数据集
Dataset data = new DefaultDataset();
File file = new File(pathname);
try{
BufferedReader textFile = new BufferedReader(new FileReader(file));
String lineDta = "";
while ((lineDta = textFile.readLine()) != null){
double[] values = new double[16];
int i = 0;
StringTokenizer st = new StringTokenizer(lineDta, ",");
while (st.hasMoreElements()) {
values[i] = Double.parseDouble((String) st.nextElement());
i++;
if(i == 16){
int sortE = Integer.parseInt((String) st.nextElement());
if( sortE == sort){
Instance instance = new DenseInstance(values, sort);
data.add(instance);
}
}
}
}
}catch (FileNotFoundException e){
System.out.println("没有找到指定文件");
}catch (IOException e){
System.out.println("文件读写出错");
}
return data;
}
该方法从文件中根据各分类建立数据集, 方便后续从各分类随机37分组织训练集和测试集.
public static void Acc(Dataset data, Classifier knn)
{
int correct = 0, wrong = 0;
for (Instance inst : data) {
Object predictedClassValue = knn.classify(inst);
Object realClassValue = inst.classValue();
if (predictedClassValue.equals(realClassValue))
correct++;
else
wrong++;
}
double acc = correct * 1.0 / (correct + wrong);
System.out.println("正确率为: " + acc);
}
该方法读取测试集及训练好的knn分类器, 并输出正确率.
public double p;
@Override
public double measure(Instance x, Instance y) {
Instance temp = x.add(y.multiply(-1));
temp = temp.multiply(temp);
temp = temp.sqrt();
if(this.p == 0.5){
temp = temp.sqrt();
double sum = 0;
for(double i:temp){
sum += i;
}
return sum * sum;
} else if (this.p == 1) {
double sum = 0;
for(double i:temp){
sum += i;
}
return sum;
} else{
temp = temp.multiply(temp);
double sum = 0;
for(double i:temp){
sum += i;
}
return Math.sqrt(sum);
}
}
该方法重写了距离测量方法, 根据实验要求分别对q = 0.5 q = 1 和 q = 2 做了实现.
结论:
本次实验的结果如下:
将其绘制为带折线的散点图如下:
由图可以看到, p值取0.5时, 正确率均高于取1和2时, 且在k值取10时, 正确率高于取其他值, 由此可以初步确认在p值取0.5, k值取10左右时, KNN的分类正确率最高。