首先，实现矩阵的 LU 分解（带选主元的高斯消去）：
package com.zhp.third.LU分解;
import java.util.ArrayList;
import com.zhp.common.Methods;
/**
 * LU decomposition by Gaussian elimination with partial (row) pivoting.
 *
 * @author 郑海鹏
 * @since 2014/10/18 21:23
 * @email 284967632@qq.com
 */
public class LU {
    public static void main(String[] args) {
        double[][] matrix = new double[][] { { 2, -1, 1, 1 }, { 4, 1, -1, 5 }, { 1, 1, 1, 0 } };
        ArrayList<double[][]> lu = getLU(matrix);
        Methods.print(lu.get(0));
        System.out.println();
        Methods.print(lu.get(1));
    }

    /**
     * Factors {@code matrix} so that {@code L * U == P * matrix}.
     *
     * <p>At each elimination step the pivot is the entry of largest magnitude in the current
     * column of the partially reduced matrix. Rows are swapped accordingly, and P records those
     * swaps. The input array is not modified.
     *
     * @param matrix a rectangular matrix with at least one row (row x col)
     * @return a list whose element 0 is L (row x row, unit lower-triangular), element 1 is U
     *     (row x col, upper-triangular/trapezoidal), and element 2 is the row-permutation
     *     matrix P (row x row)
     */
    public static ArrayList<double[][]> getLU(double matrix[][]) {
        int row = matrix.length;
        int col = matrix[0].length;
        double[][] matrixL = new double[row][row];
        double[][] matrixU = new double[row][];
        // perm[i] = index of the input row that currently sits in row i of U.
        int[] perm = new int[row];
        for (int i = 0; i < row; i++) {
            matrixU[i] = matrix[i].clone(); // work on a copy; the caller's array stays intact
            matrixL[i][i] = 1;
            perm[i] = i;
        }

        // A pivot exists only for columns that also have a diagonal entry.
        int steps = Math.min(row, col);
        for (int k = 0; k < steps; k++) {
            // Choose the pivot from the *reduced* matrix U, not from the original input:
            // after earlier steps the original values no longer describe column k.
            int maxRow = k;
            for (int r = k + 1; r < row; r++) {
                if (Math.abs(matrixU[r][k]) > Math.abs(matrixU[maxRow][k])) {
                    maxRow = r;
                }
            }

            if (maxRow != k) {
                // Columns < k of rows k..row-1 are exactly zero, so whole rows can be swapped.
                double[] tmpRow = matrixU[k];
                matrixU[k] = matrixU[maxRow];
                matrixU[maxRow] = tmpRow;
                // The multipliers already stored in L belong to the swapped rows and must
                // move with them; otherwise L * U no longer equals P * matrix.
                for (int c = 0; c < k; c++) {
                    double tmp = matrixL[k][c];
                    matrixL[k][c] = matrixL[maxRow][c];
                    matrixL[maxRow][c] = tmp;
                }
                int tmpIndex = perm[k];
                perm[k] = perm[maxRow];
                perm[maxRow] = tmpIndex;
            }

            double pivot = matrixU[k][k];
            if (pivot == 0) {
                // The largest candidate is zero, so the column is already zero below the
                // diagonal. Skipping avoids 0/0 = NaN spreading into the remaining rows.
                continue;
            }

            // Gaussian elimination: subtract a scaled copy of the pivot row from each row below.
            for (int r = k + 1; r < row; r++) {
                double scale = matrixU[r][k] / pivot;
                matrixL[r][k] = scale;
                matrixU[r][k] = 0; // exact zero instead of a floating-point residue
                for (int c = k + 1; c < col; c++) {
                    matrixU[r][c] -= matrixU[k][c] * scale;
                }
            }
        }

        double[][] matrixP = new double[row][row];
        for (int i = 0; i < row; i++) {
            matrixP[i][perm[i]] = 1;
        }

        ArrayList<double[][]> lu = new ArrayList<>();
        lu.add(matrixL);
        lu.add(matrixU);
        lu.add(matrixP);
        return lu;
    }
}
其次，利用 LU 分解求解线性方程组（输入为增广矩阵 [A | b]）：
package com.zhp.third.解线性方程组;
import java.util.ArrayList;
import com.zhp.common.Methods;
import com.zhp.third.LU分解.LU;
/**
 * Solves square linear systems supplied as augmented matrices, using the U factor from
 * {@link LU#getLU} followed by back substitution.
 *
 * @author 郑海鹏
 * @since 2014/10/18 21:23
 * @email 284967632@qq.com
 */
public class LinearEquation {
    public static void main(String[] args) {
        double[][] augmented = { { 2, -1, 1, 1 }, { 4, 1, -1, 5 }, { 1, 1, 1, 0 } };
        Methods.print(getResult(augmented));
    }

    /**
     * Returns the solution of a linear system.
     *
     * @param matrix augmented matrix [A | b]; row i holds the coefficients of equation i
     *     followed by its right-hand side in column {@code matrix.length}
     * @return the unknowns x, one entry per row
     */
    public static double[] getResult(double[][] matrix) {
        double[][] upper = LU.getLU(matrix).get(1);
        return getResultByU(upper);
    }

    /**
     * Back substitution on an upper-triangular augmented matrix.
     *
     * @param upper n rows; columns 0..n-1 are upper-triangular, column n is the right-hand side
     * @return x such that the triangular part times x equals the right-hand side
     */
    private static double[] getResultByU(double[][] upper) {
        int n = upper.length;
        double[] x = new double[n];
        // Walk from the last equation upwards; each row has exactly one new unknown on its diagonal.
        for (int i = n - 1; i >= 0; i--) {
            double known = 0;
            for (int j = i + 1; j < n; j++) {
                known += upper[i][j] * x[j];
            }
            x[i] = (upper[i][n] - known) / upper[i][i];
        }
        return x;
    }
}
最后，求矩阵的逆：
package com.zhp.third.矩阵的逆;
import com.zhp.common.Methods;
import com.zhp.third.LU分解.LU;
import com.zhp.third.解线性方程组.LinearEquation;
/**
* @author 郑海鹏
* @since 2014/10/18 21:23
* @email 284967632@qq.com
*/
public class Inverse {
public static void main(String[] args) {
double[][] a = new double[][] { { 1, 1, 2 }, { 1, 2, 1 }, { 2, 1, 1 } };
double[][] result = getInverseMatrix(a);
Methods.round(result); // 四舍五入掉舍入误差。
Methods.print(result);
}
/**
* 获得逆矩阵
*
* @exception 本方法不检查传入的矩阵有没有逆矩阵
*/
public static double[][] getInverseMatrix(double[][] matrix) {
// 因为 A = LU, 得 I = AA' = LUA', 得 L(UA') = LB = I
// 利用getRightMatrix() 求得B;
// 又:UA' = B
// 再次利用 getRightMatrix() 求得A';
double[][] L = LU.getLU(matrix).get(0);
double[][] U = LU.getLU(matrix).get(1);
// 构建单位矩阵
double[][] I = new double[matrix.length][matrix.length];
for(int i = 0; i < matrix.length; i++){
I[i][i] = 1;
}
double[][] B = getRightMatrix(L, I);
double[][] inverse = getRightMatrix(U, B);
return inverse;
}
/**
* [leftMatrix][rightMatrix] = [resultMatrix]
*
* @return rightMatrix
*/
public static double[][] getRightMatrix(double[][] leftMatrix, double[][] resultMatix) {
// 构成等式的前提是:左边矩阵的行数 == 结果矩阵的行数
if (leftMatrix.length != resultMatix.length) {
System.out.println("错误!左边矩阵的行数不等于结果矩阵的行数!");
return null;
}
int row = leftMatrix[0].length; // 行数等于左边矩阵的列数
int col = resultMatix[0].length; // 列数等于结果矩阵的列数
double[][] rightMatrix = new double[row][col];
double[] colMatrix = new double[row]; // 这是右边矩阵的一列,col个它构成rightMatrix
double[] colResult = new double[row]; // 这是结果矩阵的一列
// 按列求得colMatrix
for (int colIndex = 0; colIndex < col; colIndex++) {
// 取出这一列的结果矩阵
for (int i = 0; i < row; i++) {
colResult[i] = resultMatix[colIndex][i];
}
// 获得这一列的右矩阵
colMatrix = getRightMatrix(leftMatrix, colResult);
// 把这一列的值放到右矩阵中
for (int i = 0; i < row; i++) {
rightMatrix[colIndex][i] = colMatrix[i];
}
}
return rightMatrix;
}
/**
* [leftMatrix][rightMatrix] = [resultMatrix] 意义是:一个二维矩阵 x 一个列向量 = 矩阵在列向量上的值
*
* @return rightMatrix
*/
public static double[] getRightMatrix(double[][] leftMatrix, double[] resultMatrix) {
// 构建增广矩阵
int row = leftMatrix.length;
int col = leftMatrix[0].length + 1;
double[][] augMatrix = new double[row][col];
for (int i = 0; i < row; i++) {
for (int j = 0; j < col - 1; j++) {
augMatrix[i][j] = leftMatrix[i][j];
}
augMatrix[i][col - 1] = resultMatrix[i];
}
// 增广矩阵的解就是所求
double[] rightMatrix = LinearEquation.getResult(augMatrix);
return rightMatrix;
}
}