Java实现KNN
算法介绍
- 在特征空间中统计k个距离最近的样本的标签,选择最多的标签最为自己的标签。
- 可以采用多种距离计算策略,如曼哈顿距离、欧氏距离。
- KNN是非参且惰性的。
- 优点:实现简单、训练快(惰性)、效果好、对异常值不敏感
- 缺点:时空复杂度都高、需要合适的归一化等
算法流程
变量准备
在这里准备了两种距离策略、训练集、测试集、验证集和k值。
//两种距离
public static final int MANHATTAN = 0;
public static final int EUCLIDEAN = 1;
//距离策略
public int distanceMeasure = EUCLIDEAN;
//设置随机种子
public static final Random random = new Random();
//设置k值
int numNeighbors = 7;
//数据集
Instances dataset;
//训练集、测试集、验证集数组
int[] trainingSet;
int[] testingSet;
int[] predictions;
用Instances类读取数据集
使用构造方法,读取数据集并存入Instances类中。
/**
* 构造方法,用Instances类读取数据集
* @param paraFilename 数据集地址
*/
public Knn(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
//这里用Instances读取数据集
dataset = new Instances(fileReader);
//将当前Istances类的标签下标设置为(数据集的属性数-1)?
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
打乱数据集并划分训练集与测试集
在这里将原本的数据集随机交换打乱,再按照比例将数据集划分为训练集与测试集。
/**
* 获得一个随机序列
* @param paraLength 序列长度
* @return 打乱后的序列
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
//按序号赋值
for(int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
}
//随机交换打乱
int tempFirst, tempSecond, tempValue;
for(int i = 0; i < paraLength; i++) {
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempFirst] = tempValue;
}
return resultIndices;
}
/**
* 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
* @param paraTrainingFraction 训练集比例
*/
public void splitTrainingTesting(double paraTrainingFraction) {
//获得数据集记录数
int tempSize = dataset.numInstances();
//得到长度为tempSize的随机序列
int[] tempIndices = getRandomIndices(tempSize);
//获得训练集记录数
int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
//声明训练集与测试集
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
//为训练集与测试集赋值
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
}
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
}
}
距离计算
这里规定了两种距离,并实现了两种距离的计算。
/**
* 计算两条记录之间对应的距离
* @param paraI 一个记录下标
* @param paraJ 另一条记录下标
* @return 距离结果
*/
public double distance(int paraI, int paraJ) {
double resultDistance = 0;
double tempDifference;
//两种距离策略分开计算
switch (distanceMeasure) {
case MANHATTAN:
//对除去标签以外的所有属性进行处理
for (int i = 0; i < dataset.numAttributes() - 1; i

该文详细介绍了如何使用Java实现KNN(K近邻)算法,包括数据集的读取、训练集与测试集的划分、距离计算(曼哈顿距离和欧氏距离)、获取k个最近邻记录、单条记录预测以及模型的预测准确率计算。代码中包含了实例数据集的读取、随机打乱数据集、预测过程以及评估预测准确性的方法。
最低0.47元/天 解锁文章
1050

被折叠的 条评论
为什么被折叠?



