KNN算法就是把待分类数据放在训练集里找出离他最近的K个元素(欧氏距离),然后看看其中哪个类最多,就将这个元素分为这个类。在本实验中,使用数字数据集。每个数字含有一个二维数组表示其中的像素点,可以认为拥有M*N个特征,只不过每个特征只有0和1两种值,表示该像素点是否绘制。
将下载的训练集和测试集放在项目根目录下,因为测试集中每个元素也是已标记数据,所以每次分类后可以判断分类是否正确,从而得出一个正确率。
在拿到待分类元素的K个邻居后,最简单的处理是每个邻居具有相等的投票权,考虑增大离得近的元素的影响力,也就是为他们的投票权设置权值。这里我的权值设置是离得最近的具有K票,第二近的具有K-1票,依次递减,比较容易理解,直接看代码。
1.封装的数字类:
public class Number {
private int[][] data=new int[32][32];
private int kind;
public int[][] getData() {
return data;
}
public void setData(int[][] data) {
this.data = data;
}
public int getKind() {
return kind;
}
public void setKind(int kind) {
this.kind = kind;
}
}
2.test类:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
public class Test {
private int k;
private List<Number> testDatas;
public void putTestData(String path){//放入测试数据
File folder=new File(path);
File[] files=folder.listFiles();
testDatas=new ArrayList<>();
for(File file:files){
testDatas.add(txt2Number(file));
}
}
public void work(String path){//开始测试
File folder=new File(path);
File[] files=folder.listFiles();
int numberAll=0;//测试总数
int numCorrect=0;//测试正确数
double result=0;//正确率
for(File file:files){
Number num=txt2Number(file);
int[] minDistances=new int[k];
int[] resultKinds=new int[k];
for (int i = 0; i < k; i++) {
minDistances[i]=Integer.MAX_VALUE;
resultKinds[i]=0;
}
for(Number nu:testDatas){
int currentDis=calcu(num, nu);
int currentKind=nu.getKind();
for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个
if (currentDis<minDistances[i]) {
resultKinds[i]=currentKind;
minDistances[i]=currentDis;
break;
}
}
}
int []kinds=new int[10];//10个类别的个数
for (int i = 0; i < k; i++) {
kinds[resultKinds[i]]+=add(minDistances, i);//加权后累加
}
int resultKind=0;
int resultKindNum=0;
for (int i = 0; i < 10; i++) {
if (kinds[i]>resultKindNum) {
resultKind=i;
resultKindNum=kinds[i];
}
}
numberAll++;
if (resultKind==num.getKind()) {
numCorrect++;
}
}
result=((double)(numCorrect*100))/numberAll;
System.out.println("k是:"+getK()+" 测试总数:"+numberAll+" "
+ "正确数:"+numCorrect+" 正确率"+result);
}
public void workOne(String path){//测试单个
File fileTest=new File(path);
Number num=txt2Number(fileTest);
int[] minDistances=new int[k];
int[] resultKinds=new int[k];
for (int i = 0; i < k; i++) {
minDistances[i]=Integer.MAX_VALUE;
resultKinds[i]=0;
}
for(Number nu:testDatas){
int currentDis=calcu(num, nu);
int currentKind=nu.getKind();
for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个
if (currentDis<minDistances[i]) {
resultKinds[i]=currentKind;
minDistances[i]=currentDis;
break;
}
}
}
int []kinds=new int[10];//10个类别的个数
for (int i = 0; i < k; i++) {
kinds[resultKinds[i]]++;
}
int resultKind=0;
int resultKindNum=0;
for (int i = 0; i < 10; i++) {
if (kinds[i]>resultKindNum) {
resultKind=i;
resultKindNum=kinds[i];
}
}
System.out.println("识别文件"+path+"为:"+resultKind+" 实际类型为:"+num.getKind());
}
public int calcu(Number a,Number b){//计算两张图的欧氏距离,为了简化计算不开根号
int result=0;
for(int i=0;i<32;i++){
for (int j = 0; j < 32; j++) {
int[][] d1=a.getData();
int[][] d2=b.getData();
int dis=d1[i][j]-d2[i][j];
result+=dis*dis;
}
}
return result;
}
public int getK() {
return k;
}
public void setK(int k) {
this.k = k;
}
public Number txt2Number(File file){//txt文件转Number对象
Number num=new Number();
int[][] data=new int[32][32];
String fileName=file.getName();
int kind =Integer.valueOf(fileName.substring(0,1));
num.setKind(kind);
try {
BufferedReader reader=new BufferedReader(new FileReader(file));
String s=null;
for (int i = 0; i < 32; i++) {
s=reader.readLine();
for (int j = 0; j < 32; j++) {
data[i][j]=Integer.valueOf(s.substring(j, j+1));
}
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
num.setData(data);
return num;
}
public int add(int []a,int index){//获取这个邻居元素的权值
int re=0;
for (int i = 0; i < a.length; i++) {
if (a[index]<=a[i]) {
re++;
}
}
return re;
}
public static void main (String[] args) {
Test test=new Test();
test.putTestData("testDigits");
test.setK(10);
test.work("trainingDigits");
test.setK(1);
test.work("trainingDigits");
}
}
实验结果: