通过初等行变换求逆矩阵代码实现

模拟手动推导逆矩阵过程(初等行变换)

1. 辅助函数

template<class T>
void __device__ __host__ MathUtils<T>::printArray(T* src, int rows, int cols, const char* comments, bool isInteger)
{
	// Prints a flattened row-major rows x cols matrix, preceded by a "<comments>:" header line.
	//   src       : flattened matrix, element (r, c) at src[r * cols + c]
	//   isInteger : print each element with "%d" (after conversion to int) instead of "%.10lf"
	// Each element is explicitly converted to the type its format specifier expects. printf is
	// variadic, so passing a T that does not match "%d"/"%lf" (e.g. a double printed with "%d")
	// would otherwise be undefined behavior.
	printf("%s:\n", comments);
	for (int row = 0; row < rows; row++)
	{
		for (int col = 0; col < cols; col++)
		{
			const T value = src[row * cols + col];
			if (isInteger)
			{
				printf("%d\t", static_cast<int>(value));
			}
			else
			{
				printf("%.10lf\t", static_cast<double>(value));
			}
		}
		printf("\n");
	}
}

template<class T>
void __device__ __host__ MathUtils<T>::embedCopy(T* dest, T* src, int destRow, int destCol, int srcRow, int srcCol, int rowOffset, int colOffset)
{
	// Copies between two flattened row-major matrices, in one of two modes.
	//  * Extract: if dest is narrower OR shorter than src, every element of dest
	//    (destRow x destCol) is filled from the window of src whose top-left
	//    corner is (rowOffset, colOffset).
	//  * Embed: otherwise, all of src (srcRow x srcCol) is written into dest
	//    starting at dest position (rowOffset, colOffset).
	// No bounds checking is done; the caller must ensure the window/offsets fit.
	const bool extracting = (destCol < srcCol) || (destRow < srcRow);
	if (extracting)
	{
		for (int r = 0; r < destRow; ++r)
		{
			const int outBase = r * destCol;
			const int inBase = (r + rowOffset) * srcCol + colOffset;
			for (int c = 0; c < destCol; ++c)
				dest[outBase + c] = src[inBase + c];
		}
		return;
	}

	for (int r = 0; r < srcRow; ++r)
	{
		const int outBase = (r + rowOffset) * destCol + colOffset;
		const int inBase = r * srcCol;
		for (int c = 0; c < srcCol; ++c)
			dest[outBase + c] = src[inBase + c];
	}
}

template<class T>
void __device__ __host__ MathUtils<T>::swapRowForMat(double* mat, int row1, int row2, int rows, int cols)
{
	// Swaps rows row1 and row2 of a flattened row-major matrix with `cols` columns,
	// in place. `rows` is accepted for interface symmetry but is not needed here.
	(void)rows;
	double* first = mat + row1 * cols;
	double* second = mat + row2 * cols;
	for (int c = 0; c < cols; ++c)
	{
		const double held = first[c];
		first[c] = second[c];
		second[c] = held;
	}
}

template<class T>
bool __device__ __host__ MathUtils<T>::setMatIdentity(double* src, int rows,int cols)
{
	// Overwrites a flattened row-major rows x cols matrix with the identity pattern:
	// 1.0 where row == col, 0.0 everywhere else. Always returns true.
	for (int r = 0; r < rows; ++r)
	{
		double* line = src + r * cols;
		for (int c = 0; c < cols; ++c)
			line[c] = 0.;
		// Only rows that actually reach the diagonal get a 1 (handles rows > cols).
		if (r < cols)
			line[r] = 1.0;
	}
	return true;
}

2. 主体

src、dest 均为按行主序（row-major）展平后的一维数组，矩阵为 rows×rows 的方阵，维度由 rows 指定。

template<class T>
bool __device__ __host__ MathUtils<T>::getMatInvert(double* src, double* dest, int rows)
{
	// Inverts the rows x rows matrix `src` (flattened, row-major) by Gauss-Jordan
	// elimination on the augmented matrix [src | I], writing the inverse into `dest`.
	// `src` is only read. `dest` is overwritten with the identity first (it is used to
	// build the right half of the augmented matrix), so on failure its contents are
	// not the inverse.
	// Returns false if the scratch allocation fails or if the matrix is judged
	// singular (no pivot with magnitude >= 1e-8 in some column); true otherwise.
	const int width = rows * 2; // column count of the augmented matrix
	double* srcCopy = static_cast<double*>(malloc(sizeof(double) * rows * rows * 2));
	if (srcCopy == nullptr)
	{
		// malloc can fail (on the device it draws from a limited heap); never dereference NULL.
		printf("Allocation failed in getMatInvert!\n");
		return false;
	}
	MathUtils::embedCopy(srcCopy, src, rows, width, rows, rows);
	MathUtils::setMatIdentity(dest, rows, rows);
	MathUtils::embedCopy(srcCopy, dest, rows, width, rows, rows, 0, rows);
	MathUtils::printArray(srcCopy, rows, width, "stacked");

	for (int row = 0; row < rows; row++)
	{
		// 1. Partial pivoting: pick the candidate row with the largest |value| in this
		//    column. The magnitude is computed by hand instead of calling `abs`, which
		//    may resolve to int abs(int) on the host and truncate |x| < 1 to 0.
		int pivotRow = row;
		double pivotMag = -1.0;
		for (int r = row; r < rows; r++)
		{
			const double v = srcCopy[r * width + row];
			const double mag = v < 0.0 ? -v : v;
			if (mag > pivotMag)
			{
				pivotMag = mag;
				pivotRow = r;
			}
		}
		if (pivotMag < 1e-8)
		{
			printf("Singular mat detected!");
			free(srcCopy);
			return false;
		}
		if (pivotRow != row)
		{
			MathUtils::swapRowForMat(srcCopy, row, pivotRow, rows, width);
		}

		// Normalize the pivot row so the pivot becomes exactly 1. Columns left of the
		// pivot are already zero, so start at the pivot column.
		const double pivotElem = srcCopy[row * width + row];
		for (int col = row; col < width; col++)
		{
			srcCopy[row * width + col] /= pivotElem;
		}
		//MathUtils::printArray(srcCopy, rows, width, "divide pivot");

		// 2. Eliminate the entries BELOW the pivot. The pivot row is zero left of
		//    column `row`, so earlier columns are unaffected and can be skipped.
		for (int subRow = row + 1; subRow < rows; subRow++)
		{
			const double sub = srcCopy[subRow * width + row];
			for (int col = row; col < width; col++)
			{
				srcCopy[subRow * width + col] -= sub * srcCopy[row * width + col];
			}
		}
		//MathUtils::printArray(srcCopy, rows, width, "make elems below pivot zero");
	}
	//MathUtils::printArray(srcCopy, rows, width, "Up triangled");

	// 3. Back substitution: clear the entries ABOVE each pivot, bottom-up. Rows below
	//    `row` are already reduced to unit rows in the left half.
	for (int row = rows - 2; row >= 0; row--)
	{
		for (int subRow = rows - 1; subRow > row; subRow--)
		{
			const double sub = srcCopy[row * width + subRow];
			for (int col = row + 1; col < width; col++)
			{
				srcCopy[row * width + col] -= sub * srcCopy[subRow * width + col];
			}
			//MathUtils::printArray(srcCopy, rows, width, "sub row");
		}
	}

	// The right half of the augmented matrix now holds the inverse.
	MathUtils::embedCopy(dest, srcCopy, rows, rows, rows, width, 0, rows);
	free(srcCopy);
	return true;
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值