给定两个大小均为 $N \times N$ 的方阵 A 和 B,求它们的乘积矩阵 $C = AB$。
朴素方法:以下是两个矩阵相乘的简单方法。
// java code
static int multiply(int A[][N], int B[][N], int C[][N])
{
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
{
C[i][j] = 0;
for (int k = 0; k < N; k++)
{
C[i][j] += A[i][k]*B[k][j];
}
}
}
}
// This code is contributed by shivanisinghss2110
上述方法的时间复杂度为 $O(N^3)$。
分而治之 :
以下是两个方阵相乘的简单分而治之方法。
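1、将矩阵 A 和 B 各分为 4 个大小为 $N/2 \times N/2$ 的子矩阵,记为 $A = \begin{pmatrix} a & b \\ c & d \end{pmatrix}$,$B = \begin{pmatrix} e & f \\ g & h \end{pmatrix}$。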
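2、递归计算以下 4 个子矩阵:$ae + bg$、$af + bh$、$ce + dg$ 和 $cf + dh$。它们分别是乘积矩阵的左上、右上、左下和右下子矩阵,即 $AB = \begin{pmatrix} ae + bg & af + bh \\ ce + dg & cf + dh \end{pmatrix}$。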
实现:
//Java program to find the resultant
//product matrix for a given pair of matrices
//using Divide and Conquer Approach
import java.io.*;
import java.util.*;
class GFG {
static int ROW_1 = 4,COL_1 = 4, ROW_2 = 4, COL_2 = 4;
public static void printMat(int[][] a, int r, int c){
for(int i=0;i<r;i++){
for(int j=0;j<c;j++){
System.out.print(a[i][j]+" ");
}
System.out.println("");
}
System.out.println("");
}
public static void print(String display, int[][] matrix,int start_row, int start_column, int end_row,int end_column)
{
System.out.println(display + " =>\n");
for (int i = start_row; i <= end_row; i++) {
for (int j = start_column; j <= end_column; j++) {
//cout << setw(10);
System.out.print(matrix[i][j]+" ");
}
System.out.println("");
}
System.out.println("");
}
public static void add_matrix(int[][] matrix_A,int[][] matrix_B,int[][] matrix_C, int split_index)
{
for (int i = 0; i < split_index; i++){
for (int j = 0; j < split_index; j++){
matrix_C[i][j] = matrix_A[i][j] + matrix_B[i][j];
}
}
}
public static void initWithZeros(int a[][], int r, int c){
for(int i=0;i<r;i++){
for(int j=0;j<c;j++){
a[i][j]=0;
}
}
}
public static int[][] multiply_matrix(int[][] matrix_A,int[][] matrix_B)
{
int col_1 = matrix_A[0].length;
int row_1 = matrix_A.length;
int col_2 = matrix_B[0].length;
int row_2 = matrix_B.length;
if (col_1 != row_2) {
System.out.println("\nError: The number of columns in Matrix A must be equal to the number of rows in Matrix B\n");
int temp[][] = new int[1][1];
temp[0][0]=0;
return temp;
}
int[] result_matrix_row = new int[col_2];
Arrays.fill(result_matrix_row,0);
int[][] result_matrix = new int[row_1][col_2];
initWithZeros(result_matrix,row_1,col_2);
if (col_1 == 1){
result_matrix[0][0] = matrix_A[0][0] * matrix_B[0][0];
}else {
int split_index = col_1 / 2;
int[] row_vector = new int[split_index];
Arrays.fill(row_vector,0);
int[][] result_matrix_00 = new int[split_index][split_index];
int[][] result_matrix_01 = new int[split_index][split_index];
int[][] result_matrix_10 = new int[split_index][split_index];
int[][] result_matrix_11 = new int[split_index][split_index];
initWithZeros(result_matrix_00,split_index,split_index);
initWithZeros(result_matrix_01,split_index,split_index);
initWithZeros(result_matrix_10,split_index,split_index);
initWithZeros(result_matrix_11,split_index,split_index);
int[][] a00 = new int[split_index][split_index];
int[][] a01 = new int[split_index][split_index];
int[][] a10 = new int[split_index][split_index];
int[][] a11 = new int[split_index][split_index];
int[][] b00 = new int[split_index][split_index];
int[][] b01 = new int[split_index][split_index];
int[][] b10 = new int[split_index][split_index];
int[][] b11 = new int[split_index][split_index];
initWithZeros(a00,split_index,split_index);
initWithZeros(a01,split_index,split_index);
initWithZeros(a10,split_index,split_index);
initWithZeros(a11,split_index,split_index);
initWithZeros(b00,split_index,split_index);
initWithZeros(b01,split_index,split_index);
initWithZeros(b10,split_index,split_index);
initWithZeros(b11,split_index,split_index);
for (int i = 0; i < split_index; i++){
for (int j = 0; j < split_index; j++) {
a00[i][j] = matrix_A[i][j];
a01[i][j] = matrix_A[i][j + split_index];
a10[i][j] = matrix_A[split_index + i][j];
a11[i][j] = matrix_A[i + split_index][j + split_index];
b00[i][j] = matrix_B[i][j];
b01[i][j] = matrix_B[i][j + split_index];
b10[i][j] = matrix_B[split_index + i][j];
b11[i][j] = matrix_B[i + split_index][j + split_index];
}
}
add_matrix(multiply_matrix(a00, b00),multiply_matrix(a01, b10),result_matrix_00, split_index);
add_matrix(multiply_matrix(a00, b01),multiply_matrix(a01, b11),result_matrix_01, split_index);
add_matrix(multiply_matrix(a10, b00),multiply_matrix(a11, b10),result_matrix_10, split_index);
add_matrix(multiply_matrix(a10, b01),multiply_matrix(a11, b11),result_matrix_11, split_index);
for (int i = 0; i < split_index; i++){
for (int j = 0; j < split_index; j++) {
result_matrix[i][j] = result_matrix_00[i][j];
result_matrix[i][j + split_index] = result_matrix_01[i][j];
result_matrix[split_index + i][j] = result_matrix_10[i][j];
result_matrix[i + split_index] [j + split_index] = result_matrix_11[i][j];
}
}
}
return result_matrix;
}
public static void main (String[] args) {
int[][] matrix_A = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };
System.out.println("Array A =>");
printMat(matrix_A,4,4);
int[][] matrix_B = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };
System.out.println("Array B =>");
printMat(matrix_B,4,4);
int[][] result_matrix = multiply_matrix(matrix_A, matrix_B);
System.out.println("Result Array =>");
printMat(result_matrix,4,4);
}
}
// Time Complexity: O(n^3)
//This code is contributed by shruti456rawal
输出
Array A =>
1 1 1 1
2 2 2 2
3 3 3 3
2 2 2 2

Array B =>
1 1 1 1
2 2 2 2
3 3 3 3
2 2 2 2

Result Array =>
8 8 8 8
16 16 16 16
24 24 24 24
16 16 16 16
在上述方法中,我们对大小为 $N/2 \times N/2$ 的矩阵进行 8 次乘法和 4 次加法。两个矩阵相加需要 $O(N^2)$ 时间,所以时间复杂度的递推式可以写成
$$T(N) = 8T(N/2) + O(N^2)$$
根据主定理(Master Theorem),这里 $a = 8$、$b = 2$,且 $\log_2 8 = 3 > 2$,因此上述方法的时间复杂度为 $O(N^3)$。
不幸的是,这与上面的简单方法相同。
简单的分而治之同样得到 $O(N^3)$,有更好的方法吗?
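在上面的分而治之方法中,导致高时间复杂度的主要原因是 8 次递归调用。Strassen 方法的思想是将递归调用次数减少到 7 次。与上述简单的分而治之方法类似,Strassen 方法也将矩阵划分为大小为 $N/2 \times N/2$ 的子矩阵。不同之处在于,结果的四个子矩阵由以下公式计算,其中记 $A = \begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}$,$B$ 与 $C$ 同理:
$$
\begin{aligned}
M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
M_2 &= (A_{21} + A_{22})\,B_{11} \\
M_3 &= A_{11}\,(B_{12} - B_{22}) \\
M_4 &= A_{22}\,(B_{21} - B_{11}) \\
M_5 &= (A_{11} + A_{12})\,B_{22} \\
M_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
M_7 &= (A_{12} - A_{22})(B_{21} + B_{22})
\end{aligned}
\qquad
\begin{aligned}
C_{11} &= M_1 + M_4 - M_5 + M_7 \\
C_{12} &= M_3 + M_5 \\
C_{21} &= M_2 + M_4 \\
C_{22} &= M_1 - M_2 + M_3 + M_6
\end{aligned}
$$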
Strassen 方法的时间复杂度
两个矩阵的加法和减法需要 $O(N^2)$ 时间,所以时间复杂度的递推式可以写成
$$T(N) = 7T(N/2) + O(N^2)$$
根据主定理(Master Theorem),上述方法的时间复杂度为
$O(N^{\log_2 7})$,约为 $O(N^{2.8074})$。
一般来说,由于以下原因,Strassen 方法在实际应用中并不是首选:
1、Strassen 方法隐藏的常数因子较大,对于典型规模的应用,朴素(Naive)方法的效果更好。
2、对于稀疏矩阵,有专门为其设计的更好的方法。
3、递归中的子矩阵占用额外的空间。
4、由于计算机对非整数值(浮点数)的运算精度有限,Strassen 算法累积的舍入误差比朴素方法更大。
实现:
/**
** Java Program to Implement Strassen Algorithm
**/
import java.util.Scanner;
/** Class Strassen **/
public class Strassen
{
/** Function to multiply matrices **/
public int[][] multiply(int[][] A, int[][] B)
{
int n = A.length;
int[][] R = new int[n][n];
/** base case **/
if (n == 1)
R[0][0] = A[0][0] * B[0][0];
else
{
int[][] A11 = new int[n/2][n/2];
int[][] A12 = new int[n/2][n/2];
int[][] A21 = new int[n/2][n/2];
int[][] A22 = new int[n/2][n/2];
int[][] B11 = new int[n/2][n/2];
int[][] B12 = new int[n/2][n/2];
int[][] B21 = new int[n/2][n/2];
int[][] B22 = new int[n/2][n/2];
/** Dividing matrix A into 4 halves **/
split(A, A11, 0 , 0);
split(A, A12, 0 , n/2);
split(A, A21, n/2, 0);
split(A, A22, n/2, n/2);
/** Dividing matrix B into 4 halves **/
split(B, B11, 0 , 0);
split(B, B12, 0 , n/2);
split(B, B21, n/2, 0);
split(B, B22, n/2, n/2);
/**
M1 = (A11 + A22)(B11 + B22)
M2 = (A21 + A22) B11
M3 = A11 (B12 - B22)
M4 = A22 (B21 - B11)
M5 = (A11 + A12) B22
M6 = (A21 - A11) (B11 + B12)
M7 = (A12 - A22) (B21 + B22)
**/
int [][] M1 = multiply(add(A11, A22), add(B11, B22));
int [][] M2 = multiply(add(A21, A22), B11);
int [][] M3 = multiply(A11, sub(B12, B22));
int [][] M4 = multiply(A22, sub(B21, B11));
int [][] M5 = multiply(add(A11, A12), B22);
int [][] M6 = multiply(sub(A21, A11), add(B11, B12));
int [][] M7 = multiply(sub(A12, A22), add(B21, B22));
/**
C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6
**/
int [][] C11 = add(sub(add(M1, M4), M5), M7);
int [][] C12 = add(M3, M5);
int [][] C21 = add(M2, M4);
int [][] C22 = add(sub(add(M1, M3), M2), M6);
/** join 4 halves into one result matrix **/
join(C11, R, 0 , 0);
join(C12, R, 0 , n/2);
join(C21, R, n/2, 0);
join(C22, R, n/2, n/2);
}
/** return result **/
return R;
}
/** Function to sub two matrices **/
public int[][] sub(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] - B[i][j];
return C;
}
/** Function to add two matrices **/
public int[][] add(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] + B[i][j];
return C;
}
/** Function to split parent matrix into child matrices **/
public void split(int[][] P, int[][] C, int iB, int jB)
{
for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
C[i1][j1] = P[i2][j2];
}
/** Function to join child matrices intp parent matrix **/
public void join(int[][] C, int[][] P, int iB, int jB)
{
for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
P[i2][j2] = C[i1][j1];
}
/** Main function **/
public static void main (String[] args)
{
Scanner scan = new Scanner(System.in);
System.out.println("Strassen Multiplication Algorithm Test\n");
/** Make an object of Strassen class **/
Strassen s = new Strassen();
int N = 4;
/** Accept two 2d matrices **/
int[][] A = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };
int[][] B = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };
System.out.println("\nArray A =>");
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(A[i][j] +" ");
System.out.println();
}
System.out.println("\nArray B =>");
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(B[i][j] +" ");
System.out.println();
}
int[][] C = s.multiply(A, B);
System.out.println("\nProduct of matrices A and B : ");
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(C[i][j] +" ");
System.out.println();
}
}
}
输出
Strassen Multiplication Algorithm Test


Array A =>
1 1 1 1
2 2 2 2
3 3 3 3
2 2 2 2

Array B =>
1 1 1 1
2 2 2 2
3 3 3 3
2 2 2 2

Product of matrices A and B : 
8 8 8 8
16 16 16 16
24 24 24 24
16 16 16 16