正规方程:
$$A = (X^{T}X)^{-1}X^{T}Y$$
之前已经证明过了。
用 JAMA 包做矩阵计算。
结果自己造的数据矩阵不可逆……(原因:若对 X·X^T 求逆,当 X 为 m×3 且 m>3 时,X·X^T 是秩不超过 3 的 m×m 矩阵,必然奇异;应当对 X^T·X 求逆。)
package com.zy.ml;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import Jama.Matrix;
/**
 * Multivariate linear regression solved in closed form with the normal
 * equation (ordinary least squares): {@code A = (X^T * X)^-1 * X^T * Y}.
 *
 * <p>Every training line must contain at least three comma-separated numbers:
 * two feature values followed by the target value. The design matrix X gets a
 * leading column of ones so that the first coefficient is the intercept.
 *
 * @author yzhang
 */
public class LinearRegressionLSM {

    /** Number of comma-separated values required on each training line. */
    private static final int FIELDS_PER_LINE = 3;

    /** Raw, non-blank training lines in the order they were read. */
    static List<String> trainData = new ArrayList<String>();

    /** Number of training samples currently held in {@link #trainData}. */
    static int m = 0;

    /**
     * Reads the training file and appends every non-blank line to
     * {@link #trainData}, then updates {@link #m}.
     *
     * @param file path of the comma-separated training file
     * @throws IllegalStateException if the file cannot be opened or read
     */
    public void getTrainDate(String file) {
        try {
            LineIterator li = FileUtils.lineIterator(new File(file));
            try {
                while (li.hasNext()) {
                    String line = li.nextLine();
                    // A trailing newline or spacer line carries no sample; skip it
                    // instead of failing later while parsing.
                    if (line.trim().length() > 0) {
                        trainData.add(line);
                    }
                }
            } finally {
                // Always release the underlying reader, even if iteration fails.
                li.close();
            }
        } catch (IOException e) {
            // Fail here with the cause attached rather than continuing with a
            // null iterator.
            throw new IllegalStateException("Cannot read training data from " + file, e);
        }
        m = trainData.size();
    }

    /**
     * Splits one stored training line into its fields, once, and validates the
     * field count.
     *
     * @param index zero-based position of the sample in {@link #trainData}
     * @return the comma-separated fields of that sample
     * @throws IllegalArgumentException if the line has fewer than three fields
     */
    private static String[] fieldsOf(int index) {
        String line = trainData.get(index);
        String[] fields = line.split(",");
        if (fields.length < FIELDS_PER_LINE) {
            throw new IllegalArgumentException("Training sample #" + (index + 1)
                    + " must have " + FIELDS_PER_LINE + " comma-separated values: " + line);
        }
        return fields;
    }

    /**
     * Builds the m x 3 design matrix: a constant 1.0 (intercept term) followed
     * by the two feature values of each sample.
     *
     * @return the design matrix as a row-major array
     * @throws NumberFormatException if a feature value is not a valid number
     */
    private double[][] getX() {
        double[][] x = new double[m][FIELDS_PER_LINE];
        for (int i = 0; i < m; i++) {
            String[] fields = fieldsOf(i);
            x[i][0] = 1.0;
            x[i][1] = Double.parseDouble(fields[0]);
            x[i][2] = Double.parseDouble(fields[1]);
        }
        return x;
    }

    /**
     * Builds the m x 1 target column from the third field of each sample.
     *
     * @return the target values as a column array
     * @throws NumberFormatException if a target value is not a valid number
     */
    private double[][] getY() {
        double[][] y = new double[m][1];
        for (int i = 0; i < m; i++) {
            y[i][0] = Double.parseDouble(fieldsOf(i)[2]);
        }
        return y;
    }

    /**
     * Computes the least-squares coefficients from the loaded training data.
     *
     * @return a 3 x 1 matrix holding the intercept and the two feature weights
     * @throws IllegalStateException if no training data has been loaded
     * @throws RuntimeException if {@code X^T * X} is singular (for example when
     *         there are fewer than three samples or the features are collinear);
     *         raised by Jama
     */
    public Matrix getTheta() {
        if (m == 0) {
            throw new IllegalStateException("No training data loaded; call getTrainDate first");
        }
        Matrix xm = new Matrix(getX());
        Matrix ym = new Matrix(getY());
        Matrix xt = xm.transpose();
        // The Gram matrix must be X^T * X (3 x 3). X * X^T would be m x m with
        // rank at most 3, hence singular for m > 3.
        Matrix gram = xt.times(xm);
        // Solve (X^T X) A = X^T Y directly instead of forming the explicit inverse.
        return gram.solve(xt.times(ym));
    }

    /**
     * Fits the model and prints the coefficients, one matrix row per line.
     *
     * @param args optional; {@code args[0]} is the training file path,
     *        defaulting to {@code /home/test.txt}
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : "/home/test.txt";
        LinearRegressionLSM lr = new LinearRegressionLSM();
        lr.getTrainDate(path);
        Matrix A = lr.getTheta();
        double[][] array1 = A.getArray();
        for (int i = 0; i < array1.length; i++) {
            for (int j = 0; j < array1[i].length; j++) {
                System.out.print(array1[i][j] + " ");
            }
            System.out.println(" ");
        }
    }
}