决策数之线性回归(方法二)

本文介绍了一个使用Java实现的线性回归算法示例,通过具体的代码实现展示了如何加载训练数据、进行模型训练并评估预测准确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

       在之前写了一篇关于线性回归分类的方法,这里是记录了用Java的方法,也是在看了一位博友的进行优化的(借鉴借鉴了),主要也就是让想学机器学习的朋友好好了解一下,一起来共同学习一下而已,顺便将这些记录下来。

package xianxing;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

public class xianxinghuigui {
    private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
    private int row;//训练数据  行数
    private int column;//训练数据 列数
    
    private double [] theta;  //参数theta
    
    private double alpha;   //训练步长
    private int iteration;//迭代次数
    
//    public xianxinghuigui(String fileName)
//    {   
//        int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
//        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数
//        
//        trainData = new double[rowoffile][columnoffile];
//        this.row=rowoffile;
//        this.column=columnoffile;
//        
//        this.alpha = 0.001;//步长默认为0.001
//        this.iteration=100;//迭代次数默认为 100000
//        
//        theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
//        initialize_theta();
//        
//        loadTrainDataFromFile(fileName);
//    }
    
    //初始化最开始的线性回归的方程(将参数都设置为1)
    private void initialize_theta() {
    	{
            for(int i=0;i<theta.length;i++)
                theta[i]=1.0;
        }
	}
	public xianxinghuigui(String fileName,double alpha,int iteration)
    {    	
        int rowoffile=getRowNumber(fileName);       //获取输入训练数据文本的   行数
        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

        trainData = new double[rowoffile][columnoffile];
        this.row=rowoffile;
        this.column=columnoffile;
        
        this.alpha = alpha;
        this.iteration=iteration;
        //假设的线性方程
        theta = new double [column];  //设置回归方程的参数个数
        initialize_theta();
        loadTrainDataFromFile(fileName); //将每行的数据进行存储
    }
    
    //得到文件的行数
    private int getRowNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            while (reader.readLine() != null) 
                count++;
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;
        
    }
    
    //获取文本的列数
    private int getColumnNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = reader.readLine();
            if(tempString!=null)
                count = tempString.split(",").length;
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;
    }
  
    //进行对每列求参数
    public void trainTheta()
    {
        int iteration = this.iteration; //迭代次数
        while( (iteration--)>0 )
        {
                //对每个theta i 求 偏导数
            double [] partial_derivative = compute_partial_derivative();//偏导数
                //更新每个theta
            for(int i =0; i< theta.length;i++)
                theta[i]-= alpha * partial_derivative[i];
        }
    }
    
    //求偏导数
    private double [] compute_partial_derivative()
    {
        double [] partial_derivative = new double[theta.length];
        for(int j =0;j<theta.length;j++)  //遍历,对每个theta求偏导数(按列进行)
        {
            partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta i 求 偏导
        }
        return partial_derivative;
    }
    
    //对列数据进行求导(先求第一列,然后依次往后进行)
    private double compute_partial_derivative_for_theta(int j)
    {
        double sum=0.0;
        for(int i=0;i<row;i++)//遍历 每一行数据
        {
            sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
        }
        return sum/row;
    }
    
    //按列进行求导数(先求第一列)
    private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j)
    {
        double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是y
        double result = 0.0;        
        for(int k=0;k< (oneRow.length-1);k++)   //最后一列不用求导,是判断的类型
            result+=theta[k]*oneRow[k];
        result-=oneRow[oneRow.length-1];
        result*=oneRow[j];
        return result;
    }
    
    //取得某一行的数据
    private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1)
    {
        return trainData[i];
    }
    
    //将每列的数据进行存储到一个数组中
    private void loadTrainDataFromFile(String fileName)
    {     
        try {
			String encoding = "UTF-8"; //设置编码
			File file = new File(fileName);
			if (file.isFile() && file.exists()) { // 判断文件是否存在
				InputStreamReader read = new InputStreamReader(
						new FileInputStream(file), encoding);     // 考虑到编码格式
				BufferedReader bufferedReader = new BufferedReader(read);
				String lineTxt = null;
				int index=0;
				while ((lineTxt = bufferedReader.readLine()) != null) {  //读取的行数内容不是空
					String[] everyzifu=lineTxt.split(",");
					for(int i=0;i<column;i++){
						if("Iris-setosa".equals(everyzifu[i])){
							trainData[index][i]=1.0;
						}
						else if("Iris-versicolor".equals(everyzifu[i])){
							trainData[index][i]=0.0;
						}
						else{
						trainData[index][i]=Double.parseDouble(everyzifu[i]);
						}						
					}	
					index++;
				}
				read.close();
			} else {
				System.out.println("找不到指定的文件");
			}
		} catch (Exception e) {
			System.out.println("读取文件内容出错");
			e.printStackTrace();
		}
    }
    
    //打印数据
    public void printTrainData()
    {
        System.out.println("Train Data:\n");
        for(int i=0;i<column-1;i++)
            System.out.printf("%10s","x"+i+" ");
        System.out.printf("%10s","y"+" \n");
        for(int i=0;i<row;i++)
        {
            for(int j=0;j<column;j++)
            {
                System.out.printf("%10s",trainData[i][j]+" ");
            }
            System.out.println();
        }
        System.out.println();
    }
    
    public void printTheta()
    {
    	System.out.println("线性回归方程为:");
    	System.out.print("y=");
    	for(int i=0;i<theta.length;i++){
    		if(i!=theta.length-1){
    			System.out.print("("+theta[i]+")X+");
    		}
    		else{
    			System.out.print("("+theta[i]+")");
    		}
       }
    }

    
    //加载测试集
	public void loadCeShiData(String filePath) {
		 System.out.println();
		 System.out.println("预测值\t\t\t\t实际值:\t\t\t\t结果");
          double canshu1=theta[0];  //线性回归的每个参数
          double canshu2=theta[1]; 
          double canshu3=theta[2]; 
          double canshu4=theta[3]; 
          double canshu5=theta[4]; 
          double totalpanduarightnnumber=0;
          double totalpanduanerrornumber=0;
		try {
			String encoding = "UTF-8"; //设置编码
			File file = new File(filePath);
			if (file.isFile() && file.exists()) { // 判断文件是否存在
				InputStreamReader read = new InputStreamReader(
						new FileInputStream(file), encoding);     // 考虑到编码格式
				BufferedReader bufferedReader = new BufferedReader(read);
				String lineTxt = null;
				while ((lineTxt = bufferedReader.readLine()) != null) {  //读取的行数内容不是空
					  String[] panduan=lineTxt.split(",");    //对测试集数据进行处理
					  double ceshidata1=Double.parseDouble(panduan[0]);
					  double ceshidata2=Double.parseDouble(panduan[1]);
					  double ceshidata3=Double.parseDouble(panduan[2]);
					  double ceshidata4=Double.parseDouble(panduan[3]);
					  double ceshiresult=ceshidata1*canshu1+ceshidata2*canshu2
							      +ceshidata3*canshu3+ceshidata4*canshu4+canshu5;
//					  System.out.println("result:"+ceshiresult);
					  String panduanresult="";
					  if(ceshiresult>=1.5){						 
						  panduanresult="Iris-setosa";	
						  System.out.print(panduanresult+"\t\t\t");//输出预测结果
						  System.out.print(panduan[4]+"\t\t"); //输出实际测试数据的结果
						  if(panduanresult.equals(panduan[4])){  //比较两者的结果
							  System.out.println("\tright");
							  totalpanduarightnnumber++;
						  }
						  else{
							  System.out.println("\terror");
							  totalpanduanerrornumber++;
						  }
					  }
					  else{
						  panduanresult="Iris-versicolor"; 
						  System.out.print(panduanresult+"\t\t\t"); //输出预测结果
						  System.out.print(panduan[4]+"\t\t"); //输出测试数据实际结果
						  if(panduanresult.equals(panduan[4])){  //比较两者
							  System.out.println("\tright");
							  totalpanduarightnnumber++;
						  }
						  else{
							  System.out.println("\terror"); 
							  totalpanduanerrornumber++;
						  }
					  }
				}
				//输出精准度(结果保留两位小数)
				double d = (totalpanduarightnnumber)/(totalpanduanerrornumber+totalpanduarightnnumber)*100;
				String result = String.format("%.2f", d);				
				System.out.println("预测的精确度是:"+result+"%");
				read.close();
			} else {
				System.out.println("找不到指定的文件");
			}
		} catch (Exception e) {
			System.out.println("读取文件内容出错");
			e.printStackTrace();
		}	
		
	}

}

后面这是一个主类,主要是调用前面类的方法,来方便进行使用。 

package xianxing;

public class TestLinearRegression {
	  public static void main(String[] args) {
		  //调用构造方法,进行初始化训练数据
		    xianxinghuigui m = new xianxinghuigui("H:\\Iris.txt",0.001,1000);
	         //显示训练的数据
		 //   m.printTrainData();
		    //进行计算线性回归的方程
	         m.trainTheta();
	         //打印线性方程的参数
	         m.printTheta();
	      //加载测试集
	         m.loadCeShiData("H:\\text.txt");
	    }

}

这就是关键的代码了,关于线性回归和决策树的一些知识可以参考下之前自己写的那两篇博客,希望能帮助到一些和我一样在慢慢学习机器学习的朋友,一起努力,加油!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值