JAVA-Knn算法-测试集验证集测试准确率

新手学习-机器学习。 用KNN 算法预测足彩赔率:

当前简单模型,只用K值。未加权重,初级距离计算公式是 -欧式距离

 

代码都有注释:

主体类:

public static void main(String[] args) throws Exception {
        
        
        //导入CSV 工具类
        String path = "F:/deeplearnsuanfa/test.csv";
        List<String> dataList=CvsUtil.getCvs(new File(path));
        dataList.remove(0);
        
        //获取训练集 测试集 和 
        Map<String,List<String>> result = ListUtil.trainTestUtil(dataList, 0.2);
        
        //定义 xTrain yTrain  xTest yTest
        Map<String,List> trainAndTestData = ListUtil.getData(result.get("train"), 53);
        
        List<List<Double>> xTrain = trainAndTestData.get("data");
        List<String> yTrain = trainAndTestData.get("lable");
        
        trainAndTestData = ListUtil.getData(result.get("test"), 53);
        
        List<List<Double>> xTest = trainAndTestData.get("data");
        List<String> yTest = trainAndTestData.get("lable");
        
        KnnAlgorithms.getKnn(xTrain, xTest, yTrain, yTest, 5, 53);
}

2:CVS读取类

    /**
     * 获取CVS 数据
     * path 路径
     */
    public static List<String>  getCvs(File file) throws IOException {
         List<String> dataList=new ArrayList<String>();
            BufferedReader br=null;
            try { 
                br = new BufferedReader(new FileReader(file));
                String line = ""; 
                while ((line = br.readLine()) != null) { 
                    dataList.add(line);
                }
            }catch (Exception e) {
            }finally{
                if(br!=null){
                    try {
                        br.close();
                        br=null;
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
     
            return dataList;
        }

3:因为JAVA 暂时没有找到和PYTHON numpy 对应list 操作库。只好用for写操作list

/**
 * 集合操作工具类
 * @author join
 *
 */
public class ListUtil {
    
    
    /**
     * 分成 测试集和 训练集
     * list 传来的CVS
     * number 测试集的占比
     */
    public static Map<String,List<String>> trainTestUtil(List<String> list,Double number ){
        //随机打乱数据集
        Collections.shuffle(list);
        
        //大的数据集  里面包括两个数据集
        Map<String,List<String>> result = new HashMap<String,List<String>>();
        //训练集
        List<String> train = new ArrayList<String>();
        //测试集
        List<String> test = new ArrayList<String>();
        
        if(list.size()>0) {
            //得到测试集的 数量下标
            int testNumber = list.size() -(int)(list.size() *number);
            
            for(int i = 0;i<list.size();i++) {
                //则是训练集
                if(i<testNumber) {
                    train.add(list.get(i));
                }else {
                    test.add(list.get(i));
                }
            }
            
        }else {
            return null;
        }
        result.put("train", train);
        result.put("test", test);
        return result;
    }
    
    /**
     * 得到 每一个 cvs 的 lable 和 训练数据
     * @param list
     * @return
     */
    public static Map<String,List> getData(List<String> list,int lableNumber) {
        
        Map<String,List> result = new HashMap<String,List>();
        
        List<List<Double>> datas = new ArrayList<List<Double>>();
        List<String> lableData = new ArrayList<String>();
        
        
        for(int i=0;i<list.size();i++) {
            String [] data = list.get(i).split(",");
            List<Double> da = new ArrayList<Double>();
            for(int k=0;k<data.length;k++) {
                if(k == lableNumber) {
                    lableData.add(data[k]);
                }else {
                    if(data[k] != null || !data[k].equals("")) {
                        da.add(Double.valueOf(data[k]));
                    }else {
                        da.add(0.3);
                    }
                }
            }
            //放入数据中
            datas.add(da);
        }
        result.put("data", datas);
        result.put("lable", lableData);
        return result;
    }
    
    /**
     * xTest 和 所有的 xTrain的距离
     * @param xTrain 训练集
     * @param xTest     测试集
     * @param yTrain    训练集标签
     * @throws Exception 
     */
    public static List<List<Map<String,Object>>> allDistance(List<List<Double>> xTrain,List<List<Double>> xTest,List<String> yTrain,int dataNumber) throws Exception {
        List<List<Map<String,Object>>> result = new ArrayList<List<Map<String,Object>>>();
        //FOR  测试集
        for(List<Double> test:xTest) {
            //每个训练集到测试集的样本的距离  和 训练集距离的标签 
            List<Map<String,Object>> testList = new ArrayList<Map<String,Object>>();
            //FOR 训练集
            for(int i = 0;i<xTrain.size();i++) {
                //每个训练集到测试集的样本的距离  和 训练集距离的标签 
                Map<String,Object> map = new HashMap<String,Object>();
                Double euclideanDistance = MathUtil.getEuclideanDistance(xTrain.get(i),test);
                String lable =yTrain.get(i);
                map.put("lable", lable);
                map.put("distance", euclideanDistance);
                testList.add(map);
            }
            //这里做排序 
            Collections.sort(testList, new Comparator<Map<String, Object>>() {
                public int compare(Map<String, Object> o1, Map<String, Object> o2) {
                    Double name1 = Double.valueOf(o2.get("distance").toString()) ;//name1是从你list里面拿出来的一个 
                    Double name2 = Double.valueOf(o1.get("distance").toString()) ; //name1是从你list里面拿出来的第二个name
                    return name1.compareTo(name2);
                }
            });
            result.add(testList);
        }
        return result;
            }

}

 

4:数学公式类:

/**
 * 数据公式 工具类
 * @author join
 *
 */

public class MathUtil{
    
    /**
     * 欧式距离公式
     * 公式:取得参数 互相减 的平方 和  再开放 
     * train 训练数据集的 一个样本
     * test  测试数据的一个样本
     * @return
     * @throws Exception 
     */
    public static Double getEuclideanDistance(List<Double> train,List<Double> test) throws Exception {
        
        Double sum = 0.00;
        
        if(train.size() != test.size()) {
            throw new Exception("两个集合的大小不一致");
        }
        
        for(int i=0;i<train.size();i++) {
            sum += (train.get(i)-test.get(i)) *(train.get(i)-test.get(i));
        }
        return Math.sqrt(sum);
    }
    

}

 

数据是从网上爬的比较小:代码粘贴即可用。供学习。也请各位大佬指教一下.

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值