模拟手动推导逆矩阵过程(初等行变换)
1. 辅助函数
/**
 * Prints a matrix stored as a flat row-major array, preceded by a caption line.
 *
 * Each row is written on its own line with tab-separated entries.
 *
 * @param src       Flat array holding rows * cols elements (element (r, c) is
 *                  read from src[r * cols + c]).
 * @param rows      Number of rows to print.
 * @param cols      Number of columns to print.
 * @param comments  Caption printed first, followed by ":" and a newline.
 * @param isInteger When true each entry is printed with "%d" after conversion
 *                  to int; otherwise with "%.10lf" after conversion to double.
 */
template<class T>
void __device__ __host__ MathUtils<T>::printArray(T* src, int rows, int cols, const char* comments, bool isInteger)
{
    printf("%s:\n", comments);
    for (int row = 0; row < rows; row++)
    {
        for (int col = 0; col < cols; col++)
        {
            const T value = src[row * cols + col];
            // printf is variadic: the argument type must match the format
            // specifier exactly, so convert explicitly instead of passing a raw T
            // (a double fed to "%d", or an int fed to "%lf", is undefined behaviour).
            if (isInteger)
            {
                printf("%d\t", static_cast<int>(value));
            }
            else
            {
                printf("%.10lf\t", static_cast<double>(value));
            }
        }
        printf("\n");
    }
}
/**
 * Copies a rectangular block between two flat row-major matrices.
 *
 * The direction of the copy is chosen from the relative sizes:
 *  - If dest has fewer rows or fewer columns than src, dest is filled
 *    completely from the window of src whose top-left corner is
 *    (rowOffset, colOffset).
 *  - Otherwise the whole of src is written into dest with its top-left corner
 *    placed at (rowOffset, colOffset).
 *
 * No bounds checking is performed; the caller must make sure the window fits.
 *
 * @param dest      Destination matrix, destRow x destCol.
 * @param src       Source matrix, srcRow x srcCol.
 * @param destRow   Row count of dest.
 * @param destCol   Column count of dest.
 * @param srcRow    Row count of src.
 * @param srcCol    Column count of src.
 * @param rowOffset Row of the window's top-left corner (in src when extracting,
 *                  in dest when embedding).
 * @param colOffset Column of the window's top-left corner (same convention).
 *                  NOTE(review): both offsets presumably default to 0 in the
 *                  class declaration, which is not visible here - confirm.
 */
template<class T>
void __device__ __host__ MathUtils<T>::embedCopy(T* dest, T* src, int destRow, int destCol, int srcRow, int srcCol, int rowOffset, int colOffset)
{
    const bool destIsSmaller = (destCol < srcCol) || (destRow < srcRow);

    if (!destIsSmaller)
    {
        // Embed: drop every element of src into dest, shifted by the offsets.
        for (int r = 0; r < srcRow; ++r)
        {
            const int destBase = (r + rowOffset) * destCol;
            const int srcBase = r * srcCol;
            for (int c = 0; c < srcCol; ++c)
            {
                dest[destBase + c + colOffset] = src[srcBase + c];
            }
        }
        return;
    }

    // Extract: fill all of dest from the shifted window inside src.
    for (int r = 0; r < destRow; ++r)
    {
        const int destBase = r * destCol;
        const int srcBase = (r + rowOffset) * srcCol + colOffset;
        for (int c = 0; c < destCol; ++c)
        {
            dest[destBase + c] = src[srcBase + c];
        }
    }
}
/**
 * Exchanges two rows of a flat row-major matrix of doubles, in place.
 *
 * @param mat  Matrix storage with `cols` entries per row.
 * @param row1 Zero-based index of the first row (not range-checked).
 * @param row2 Zero-based index of the second row (not range-checked).
 * @param rows Row count of the matrix; accepted but not used by this function.
 * @param cols Number of entries in each row.
 */
template<class T>
void __device__ __host__ MathUtils<T>::swapRowForMat(double* mat, int row1, int row2, int rows, int cols)
{
    const int firstBase = row1 * cols;
    const int secondBase = row2 * cols;

    for (int offset = 0; offset < cols; ++offset)
    {
        const double held = mat[firstBase + offset];
        mat[firstBase + offset] = mat[secondBase + offset];
        mat[secondBase + offset] = held;
    }
}
/**
 * Overwrites a flat row-major matrix so that entries on the main diagonal are
 * 1.0 and every other entry is 0.0.
 *
 * @param src  Matrix storage of rows * cols doubles; fully overwritten.
 * @param rows Number of rows.
 * @param cols Number of columns (need not equal rows).
 * @return Always true.
 */
template<class T>
bool __device__ __host__ MathUtils<T>::setMatIdentity(double* src, int rows,int cols)
{
    // `flat` walks the storage in the same order as the nested loops, so it is
    // always equal to r * cols + c.
    int flat = 0;
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < cols; ++c)
        {
            if (r == c)
            {
                src[flat] = 1.0;
            }
            else
            {
                src[flat] = 0.;
            }
            ++flat;
        }
    }
    return true;
}
2. 主体
src、dest 均为按行优先展平后的一维数组,表示 rows×rows 的方阵;rows 为方阵的阶数。
/**
 * Inverts a square matrix with Gauss-Jordan elimination (elementary row
 * operations) applied to the augmented matrix [src | I].
 *
 * Steps: build the rows x 2*rows augmented matrix in a scratch buffer, reduce
 * the left half to a unit upper-triangular form (swapping in a lower row when a
 * pivot is numerically zero), back-substitute to turn the left half into the
 * identity, then copy the right half into dest.
 *
 * @param src  Input matrix, rows x rows, flat row-major. Only read.
 * @param dest Output matrix, rows x rows, flat row-major. It is first
 *             overwritten with the identity and, on success, with the inverse.
 *             src is copied to scratch before dest is written.
 * @param rows Order of the matrix.
 * @return true on success (including the trivial rows == 0 case).
 *         false when src/dest is null, rows is negative or the scratch
 *         allocation fails (dest untouched in these cases), or when no pivot
 *         with magnitude >= 1e-8 can be found for some column (dest is left
 *         holding the identity).
 */
template<class T>
bool __device__ __host__ MathUtils<T>::getMatInvert(double* src, double* dest, int rows)
{
    // Reject arguments that would make the allocation size or indexing meaningless.
    if (!src || !dest || rows < 0)
    {
        return false;
    }
    if (rows == 0)
    {
        // Nothing to invert; also avoids relying on malloc(0) behaviour.
        return true;
    }

    const double pivotEps = 1e-8;   // pivots with magnitude below this count as zero
    const int stride = rows * 2;    // row length of the augmented matrix

    double* srcCopy = static_cast<double*>(malloc(sizeof(double) * rows * stride));
    if (!srcCopy)
    {
        // malloc can fail (in device code the heap is a fixed-size pool);
        // bail out instead of dereferencing a null pointer.
        return false;
    }

    // Assemble [src | I] in the scratch buffer.
    MathUtils::embedCopy(srcCopy, src, rows, stride, rows, rows);
    MathUtils::setMatIdentity(dest, rows, rows);
    MathUtils::embedCopy(srcCopy, dest, rows, stride, rows, rows, 0, rows);
    // NOTE(review): unconditional debug dump kept for output compatibility;
    // consider removing it or putting it behind a debug flag.
    MathUtils::printArray(srcCopy, rows, stride, "stacked");

    // Forward elimination.
    for (int row = 0; row < rows; row++)
    {
        double* pivotRow = srcCopy + row * stride;

        //1. Find a usable pivot, then divide the row by it.
        double pivotElem = pivotRow[row];
        int changeTimes = 0;
        // Explicit range test instead of abs(): an unqualified abs() may bind
        // to the integer overload and truncate |x| < 1 to 0. NaN fails both
        // comparisons and is therefore not treated as zero.
        while (pivotElem > -pivotEps && pivotElem < pivotEps)
        {
            ++changeTimes;
            if (changeTimes > rows - row - 1)
            {
                // Every candidate row below has been tried.
                printf("Singular mat detected!");
                free(srcCopy);
                return false;
            }
            MathUtils::swapRowForMat(srcCopy, row, row + changeTimes, rows, stride);
            pivotElem = pivotRow[row];
        }
        for (int col = row; col < stride; col++)
        {
            pivotRow[col] /= pivotElem;
        }

        //2. Zero out this column in every row below the pivot row.
        for (int subRow = row + 1; subRow < rows; subRow++)
        {
            double* target = srcCopy + subRow * stride;
            const double sub = target[row];
            for (int col = 0; col < stride; col++)
            {
                target[col] -= sub * pivotRow[col];
            }
        }
    }

    // Back substitution: clear the entries above each pivot, bottom row first,
    // so every row used as a subtrahend is already fully reduced.
    for (int row = rows - 2; row >= 0; row--)
    {
        double* target = srcCopy + row * stride;
        for (int subRow = rows - 1; subRow > row; subRow--)
        {
            const double* reduced = srcCopy + subRow * stride;
            const double sub = target[subRow];
            for (int col = row + 1; col < stride; col++)
            {
                target[col] -= sub * reduced[col];
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    MathUtils::embedCopy(dest, srcCopy, rows, rows, rows, stride, 0, rows);
    free(srcCopy);
    return true;
}