新手学习-机器学习。 用KNN 算法预测足彩赔率:
当前简单模型,只用K值。未加权重,初级距离计算公式是 -欧式距离
代码都有注释:
主体类:
public static void main(String[] args) throws Exception {
//导入CSV 工具类
String path = "F:/deeplearnsuanfa/test.csv";
List<String> dataList=CvsUtil.getCvs(new File(path));
dataList.remove(0);
//获取训练集 测试集 和
Map<String,List<String>> result = ListUtil.trainTestUtil(dataList, 0.2);
//定义 xTrain yTrain xTest yTest
Map<String,List> trainAndTestData = ListUtil.getData(result.get("train"), 53);
List<List<Double>> xTrain = trainAndTestData.get("data");
List<String> yTrain = trainAndTestData.get("lable");
trainAndTestData = ListUtil.getData(result.get("test"), 53);
List<List<Double>> xTest = trainAndTestData.get("data");
List<String> yTest = trainAndTestData.get("lable");
KnnAlgorithms.getKnn(xTrain, xTest, yTrain, yTest, 5, 53);
}
2:CVS读取类
/**
* 获取CVS 数据
* path 路径
*/
public static List<String> getCvs(File file) throws IOException {
List<String> dataList=new ArrayList<String>();
BufferedReader br=null;
try {
br = new BufferedReader(new FileReader(file));
String line = "";
while ((line = br.readLine()) != null) {
dataList.add(line);
}
}catch (Exception e) {
}finally{
if(br!=null){
try {
br.close();
br=null;
} catch (IOException e) {
e.printStackTrace();
}
}
}
return dataList;
}
3:因为JAVA 暂时没有找到和PYTHON numpy 对应list 操作库。只好用for写操作list
/**
* 集合操作工具类
* @author join
*
*/
public class ListUtil {
/**
* 分成 测试集和 训练集
* list 传来的CVS
* number 测试集的占比
*/
public static Map<String,List<String>> trainTestUtil(List<String> list,Double number ){
//随机打乱数据集
Collections.shuffle(list);
//大的数据集 里面包括两个数据集
Map<String,List<String>> result = new HashMap<String,List<String>>();
//训练集
List<String> train = new ArrayList<String>();
//测试集
List<String> test = new ArrayList<String>();
if(list.size()>0) {
//得到测试集的 数量下标
int testNumber = list.size() -(int)(list.size() *number);
for(int i = 0;i<list.size();i++) {
//则是训练集
if(i<testNumber) {
train.add(list.get(i));
}else {
test.add(list.get(i));
}
}
}else {
return null;
}
result.put("train", train);
result.put("test", test);
return result;
}
/**
* 得到 每一个 cvs 的 lable 和 训练数据
* @param list
* @return
*/
public static Map<String,List> getData(List<String> list,int lableNumber) {
Map<String,List> result = new HashMap<String,List>();
List<List<Double>> datas = new ArrayList<List<Double>>();
List<String> lableData = new ArrayList<String>();
for(int i=0;i<list.size();i++) {
String [] data = list.get(i).split(",");
List<Double> da = new ArrayList<Double>();
for(int k=0;k<data.length;k++) {
if(k == lableNumber) {
lableData.add(data[k]);
}else {
if(data[k] != null || !data[k].equals("")) {
da.add(Double.valueOf(data[k]));
}else {
da.add(0.3);
}
}
}
//放入数据中
datas.add(da);
}
result.put("data", datas);
result.put("lable", lableData);
return result;
}
/**
* xTest 和 所有的 xTrain的距离
* @param xTrain 训练集
* @param xTest 测试集
* @param yTrain 训练集标签
* @throws Exception
*/
public static List<List<Map<String,Object>>> allDistance(List<List<Double>> xTrain,List<List<Double>> xTest,List<String> yTrain,int dataNumber) throws Exception {
List<List<Map<String,Object>>> result = new ArrayList<List<Map<String,Object>>>();
//FOR 测试集
for(List<Double> test:xTest) {
//每个训练集到测试集的样本的距离 和 训练集距离的标签
List<Map<String,Object>> testList = new ArrayList<Map<String,Object>>();
//FOR 训练集
for(int i = 0;i<xTrain.size();i++) {
//每个训练集到测试集的样本的距离 和 训练集距离的标签
Map<String,Object> map = new HashMap<String,Object>();
Double euclideanDistance = MathUtil.getEuclideanDistance(xTrain.get(i),test);
String lable =yTrain.get(i);
map.put("lable", lable);
map.put("distance", euclideanDistance);
testList.add(map);
}
//这里做排序
Collections.sort(testList, new Comparator<Map<String, Object>>() {
public int compare(Map<String, Object> o1, Map<String, Object> o2) {
Double name1 = Double.valueOf(o2.get("distance").toString()) ;//name1是从你list里面拿出来的一个
Double name2 = Double.valueOf(o1.get("distance").toString()) ; //name1是从你list里面拿出来的第二个name
return name1.compareTo(name2);
}
});
result.add(testList);
}
return result;
}
}
4:数学公式类:
/**
* 数据公式 工具类
* @author join
*
*/
public class MathUtil{
/**
* 欧式距离公式
* 公式:取得参数 互相减 的平方 和 再开放
* train 训练数据集的 一个样本
* test 测试数据的一个样本
* @return
* @throws Exception
*/
public static Double getEuclideanDistance(List<Double> train,List<Double> test) throws Exception {
Double sum = 0.00;
if(train.size() != test.size()) {
throw new Exception("两个集合的大小不一致");
}
for(int i=0;i<train.size();i++) {
sum += (train.get(i)-test.get(i)) *(train.get(i)-test.get(i));
}
return Math.sqrt(sum);
}
}
数据是从网上爬的比较小:代码粘贴即可用。供学习。也请各位大佬指教一下.