maven
<dependency>
<groupId>nz.ac.waikato.cms.weka</groupId>
<artifactId>weka-stable</artifactId>
<version>3.8.3</version>
</dependency>
以线性查询(LinearNNSearch)为例;若要使用其他近邻搜索实现(例如 KDTree、BallTree、CoverTree),把 LinearNNSearch 替换成对应的实现类即可。
/**
 * Finds the nearest neighbours of {@code target} inside {@code sourceIns} using a
 * {@code LinearNNSearch} configured with {@code EuclideanDistance}.
 *
 * @param sourceIns the data set in which to look for neighbours
 * @param target the instance whose neighbours are wanted
 * @param kNN how many neighbours to look for
 * @return the neighbours found, or {@code null} if the search threw an exception
 */
public static Instances getKNeighbour(Instances sourceIns, Instance target, int kNN)
{
    Instances neighbours = null;
    try {
        LinearNNSearch lnn = new LinearNNSearch();
        lnn.setDistanceFunction(new EuclideanDistance());
        lnn.setInstances(sourceIns);
        // Let the search see the query instance before searching.
        // NOTE(review): presumably this updates the distance function's value ranges — confirm.
        lnn.addInstanceInfo(target);
        neighbours = lnn.kNearestNeighbours(target, kNN);
    } catch (Exception e) {
        // Best-effort contract is kept (null on failure), but the diagnostic now names this
        // method and includes the cause instead of a message copied from another method.
        System.err.println("getKNeighbour(Instances sourceIns, Instance target, int kNN) failed: " + e);
    }
    return neighbours;
}
把 ARFF 文件读取成 Instances:
/**
 * Loads the data set stored in the given file via {@code ConverterUtils.DataSource}.
 *
 * @param fileName location of the data file (the accompanying example passes an ARFF file)
 * @return the data set read from the file
 * @throws Exception if the data source cannot be created or read
 */
private static Instances getFileInstances(String fileName) throws Exception {
    return new ConverterUtils.DataSource(fileName).getDataSet();
}
跑一个例子:
/**
 * Demo entry point: loads a data set, builds a single query point and prints its three
 * nearest neighbours.
 *
 * @param args optional; when present, {@code args[0]} is used as the data file path instead
 *     of the built-in default
 * @throws Exception if the data set cannot be loaded
 */
public static void main(String[] args) throws Exception {
    // The original hard-coded path stays as the default; a command-line argument overrides it.
    String path = args.length > 0 ? args[0] : "D:\\poi_test.arff";
    Instances instances = getFileInstances(path);
    // Mark the last attribute as the class on the searched data too, mirroring the header below.
    // NOTE(review): Weka distance functions are expected to skip the class attribute only when
    // the class index is set on the searched data — confirm against the EuclideanDistance docs.
    instances.setClassIndex(instances.numAttributes() - 1);
    System.out.println("instances:" + instances);

    // Header describing the query point: x, y and a numeric class column.
    // The single-argument constructor creates a numeric attribute.
    // NOTE(review): the two-argument (String, int) form appears to take an attribute index,
    // not a type constant — confirm against the Attribute Javadoc.
    ArrayList<Attribute> atts = new ArrayList<>();
    atts.add(new Attribute("x"));
    atts.add(new Attribute("y"));
    atts.add(new Attribute("class"));
    Instances df = new Instances("predictData", atts, 0);
    df.setClassIndex(df.numAttributes() - 1);

    // Size the instance from the header so it always has one slot per declared attribute;
    // the class slot is left unset because the label of the query point is unknown.
    Instance target = new DenseInstance(df.numAttributes());
    target.setDataset(df);
    target.setValue(0, 116.41);
    target.setValue(1, 39.95);
    System.out.println("target:" + target);
    df.add(target);

    System.out.println("=============knn=================:");
    Instances result = getKNeighbour(instances, target, 3);
    System.out.println("result:" + result);
}
数据:
@relation distance
@attribute x NUMERIC
@attribute y NUMERIC
@attribute class NUMERIC
@data
116.415668,39.953726,0
116.438614,39.930745,0
116.426367,39.928047,0