java 分而治之(施特拉森矩阵乘法)

给定两个大小分别为 nxn 的方阵 A 和 B,求它们的乘法矩阵。 
朴素方法:以下是两个矩阵相乘的简单方法。

// java code
static int multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
}
 
// This code is contributed by shivanisinghss2110

上述方法的时间复杂度为O(N 3 )。 

分而治之 :
以下是两个方阵相乘的简单分而治之方法。 
1、将矩阵 A 和 B 分为 4 个大小为 N/2 x N/2 的子矩阵,如下图所示。 
2、递归计算以下值。 ae + bg、af + bh、ce + dg 和 cf + dh。 

执行:

//Java program to find the resultant 
//product matrix for a given pair of matrices
//using Divide and Conquer Approach
 
import java.io.*;
import java.util.*;
 
class GFG {
 
  static int ROW_1 = 4,COL_1 = 4, ROW_2 = 4, COL_2 = 4;
 
  public static void printMat(int[][] a, int r, int c){
    for(int i=0;i<r;i++){
      for(int j=0;j<c;j++){
        System.out.print(a[i][j]+" ");
      }
      System.out.println("");
    }
    System.out.println("");
  }
 
  public static void print(String display, int[][] matrix,int start_row, int start_column, int end_row,int end_column)
  {
    System.out.println(display + " =>\n");
    for (int i = start_row; i <= end_row; i++) {
      for (int j = start_column; j <= end_column; j++) {
        //cout << setw(10);
        System.out.print(matrix[i][j]+" ");
      }
      System.out.println("");
    }
    System.out.println("");
  }
 
  public static void add_matrix(int[][] matrix_A,int[][] matrix_B,int[][] matrix_C, int split_index)
  {
    for (int i = 0; i < split_index; i++){
      for (int j = 0; j < split_index; j++){
        matrix_C[i][j] = matrix_A[i][j] + matrix_B[i][j];
      }
    }
  }
 
  public static void initWithZeros(int a[][], int r, int c){
    for(int i=0;i<r;i++){
      for(int j=0;j<c;j++){
        a[i][j]=0;
      }
    }
  }
 
  public static int[][] multiply_matrix(int[][] matrix_A,int[][] matrix_B)
  {
    int col_1 = matrix_A[0].length;
    int row_1 = matrix_A.length;
    int col_2 = matrix_B[0].length;
    int row_2 = matrix_B.length;
 
    if (col_1 != row_2) {
      System.out.println("\nError: The number of columns in Matrix A  must be equal to the number of rows in Matrix B\n");
      int temp[][] = new int[1][1];
      temp[0][0]=0;
      return temp;
    }
 
    int[] result_matrix_row = new int[col_2];
    Arrays.fill(result_matrix_row,0);
    int[][] result_matrix = new int[row_1][col_2];
    initWithZeros(result_matrix,row_1,col_2);
 
    if (col_1 == 1){
      result_matrix[0][0] = matrix_A[0][0] * matrix_B[0][0]; 
    }else {
      int split_index = col_1 / 2;
 
      int[] row_vector = new int[split_index];
      Arrays.fill(row_vector,0);
 
      int[][] result_matrix_00 = new int[split_index][split_index];
      int[][] result_matrix_01 = new int[split_index][split_index];
      int[][] result_matrix_10 = new int[split_index][split_index];
      int[][] result_matrix_11 = new int[split_index][split_index];
      initWithZeros(result_matrix_00,split_index,split_index);
      initWithZeros(result_matrix_01,split_index,split_index);
      initWithZeros(result_matrix_10,split_index,split_index);
      initWithZeros(result_matrix_11,split_index,split_index);
 
      int[][] a00 = new int[split_index][split_index];
      int[][] a01 = new int[split_index][split_index];
      int[][] a10 = new int[split_index][split_index];
      int[][] a11 = new int[split_index][split_index];
      int[][] b00 = new int[split_index][split_index];
      int[][] b01 = new int[split_index][split_index];
      int[][] b10 = new int[split_index][split_index];
      int[][] b11 = new int[split_index][split_index];
      initWithZeros(a00,split_index,split_index);
      initWithZeros(a01,split_index,split_index);
      initWithZeros(a10,split_index,split_index);
      initWithZeros(a11,split_index,split_index);
      initWithZeros(b00,split_index,split_index);
      initWithZeros(b01,split_index,split_index);
      initWithZeros(b10,split_index,split_index);
      initWithZeros(b11,split_index,split_index);
 
 
      for (int i = 0; i < split_index; i++){
        for (int j = 0; j < split_index; j++) {
          a00[i][j] = matrix_A[i][j];
          a01[i][j] = matrix_A[i][j + split_index];
          a10[i][j] = matrix_A[split_index + i][j];
          a11[i][j] = matrix_A[i + split_index][j + split_index];
          b00[i][j] = matrix_B[i][j];
          b01[i][j] = matrix_B[i][j + split_index];
          b10[i][j] = matrix_B[split_index + i][j];
          b11[i][j] = matrix_B[i + split_index][j + split_index];
        }
      }
 
      add_matrix(multiply_matrix(a00, b00),multiply_matrix(a01, b10),result_matrix_00, split_index);
      add_matrix(multiply_matrix(a00, b01),multiply_matrix(a01, b11),result_matrix_01, split_index);
      add_matrix(multiply_matrix(a10, b00),multiply_matrix(a11, b10),result_matrix_10, split_index);
      add_matrix(multiply_matrix(a10, b01),multiply_matrix(a11, b11),result_matrix_11, split_index);
 
      for (int i = 0; i < split_index; i++){
        for (int j = 0; j < split_index; j++) {
          result_matrix[i][j] = result_matrix_00[i][j];
          result_matrix[i][j + split_index] = result_matrix_01[i][j];
          result_matrix[split_index + i][j] = result_matrix_10[i][j];
          result_matrix[i + split_index] [j + split_index] = result_matrix_11[i][j];
        }
      }
    }
    return result_matrix;
  }
 
  public static void main (String[] args) {
    int[][] matrix_A = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
 
    System.out.println("Array A =>");
    printMat(matrix_A,4,4);
 
    int[][] matrix_B = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
 
    System.out.println("Array B =>");
    printMat(matrix_B,4,4);
 
    int[][] result_matrix =  multiply_matrix(matrix_A, matrix_B);
 
    System.out.println("Result Array =>");
    printMat(result_matrix,4,4);
  }
}
// Time Complexity: O(n^3)
//This code is contributed by shruti456rawal 

输出
数组A =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


数组 B =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


结果数组=>
         8 8 8 8
        16 16 16 16
        24 24 24 24
        16 16 16 16
        
在上述方法中,我们对大小为 N/2 x N/2 的矩阵进行 8 次乘法和 4 次加法。两个矩阵相加需要 O(N 2 ) 时间。所以时间复杂度可以写成 

T(N) = 8T(N/2) + O(N 2 )  

根据马斯特定理,上述方法的时间复杂度为 O(N 3 )
不幸的是,这与上面的简单方法相同。

简单的分而治之也导致O(N 3 ),有更好的方法吗? 

        在上面的分而治之的方法中,高时间复杂度的主要成分是8次递归调用。Strassen 方法的思想是将递归调用次数减少到 7 次。Strassen 方法与上述简单的分而治之方法类似,该方法也将矩阵划分为大小为 N/2 x N/2 的子矩阵:如上图所示,但在Strassen方法中,结果的四个子矩阵是使用以下公式计算的。

Strassen 方法的时间复杂度

两个矩阵的加法和减法需要 O(N 2 ) 时间。所以时间复杂度可以写成 

T(N) = 7T(N/2) + O(N 2 )

根据马斯特定理,上述方法的时间复杂度为
O(N Log7 ) 大约为 O(N 2.8074 )

一般来说,由于以下原因,施特拉森方法在实际应用中并不优选。 

1、Strassen 方法中使用的常数很高,对于典型应用,Naive 方法效果更好。 
2、对于稀疏矩阵,有专门为其设计的更好的方法。 
3、递归中的子矩阵占用额外的空间。 
4、由于计算机对非整数值的运算精度有限,Strassen 算法中累积的误差比 Naive 方法中更大。

执行:

/**
 ** Java Program to Implement Strassen Algorithm
 **/
  
import java.util.Scanner;
  
/** Class Strassen **/
public class Strassen
{
    /** Function to multiply matrices **/
    public int[][] multiply(int[][] A, int[][] B)
    {        
        int n = A.length;
        int[][] R = new int[n][n];
        /** base case **/
        if (n == 1)
            R[0][0] = A[0][0] * B[0][0];
        else
        {
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];
  
            /** Dividing matrix A into 4 halves **/
            split(A, A11, 0 , 0);
            split(A, A12, 0 , n/2);
            split(A, A21, n/2, 0);
            split(A, A22, n/2, n/2);
            /** Dividing matrix B into 4 halves **/
            split(B, B11, 0 , 0);
            split(B, B12, 0 , n/2);
            split(B, B21, n/2, 0);
            split(B, B22, n/2, n/2);
  
            /** 
              M1 = (A11 + A22)(B11 + B22)
              M2 = (A21 + A22) B11
              M3 = A11 (B12 - B22)
              M4 = A22 (B21 - B11)
              M5 = (A11 + A12) B22
              M6 = (A21 - A11) (B11 + B12)
              M7 = (A12 - A22) (B21 + B22)
            **/
  
            int [][] M1 = multiply(add(A11, A22), add(B11, B22));
            int [][] M2 = multiply(add(A21, A22), B11);
            int [][] M3 = multiply(A11, sub(B12, B22));
            int [][] M4 = multiply(A22, sub(B21, B11));
            int [][] M5 = multiply(add(A11, A12), B22);
            int [][] M6 = multiply(sub(A21, A11), add(B11, B12));
            int [][] M7 = multiply(sub(A12, A22), add(B21, B22));
  
            /**
              C11 = M1 + M4 - M5 + M7
              C12 = M3 + M5
              C21 = M2 + M4
              C22 = M1 - M2 + M3 + M6
            **/
            int [][] C11 = add(sub(add(M1, M4), M5), M7);
            int [][] C12 = add(M3, M5);
            int [][] C21 = add(M2, M4);
            int [][] C22 = add(sub(add(M1, M3), M2), M6);
  
            /** join 4 halves into one result matrix **/
            join(C11, R, 0 , 0);
            join(C12, R, 0 , n/2);
            join(C21, R, n/2, 0);
            join(C22, R, n/2, n/2);
        }
        /** return result **/   
        return R;
    }
    /** Function to sub two matrices **/
    public int[][] sub(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] - B[i][j];
        return C;
    }
    /** Function to add two matrices **/
    public int[][] add(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] + B[i][j];
        return C;
    }
    /** Function to split parent matrix into child matrices **/
    public void split(int[][] P, int[][] C, int iB, int jB) 
    {
        for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                C[i1][j1] = P[i2][j2];
    }
    /** Function to join child matrices intp parent matrix **/
    public void join(int[][] C, int[][] P, int iB, int jB) 
    {
        for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                P[i2][j2] = C[i1][j1];
    }    
    /** Main function **/
    public static void main (String[] args) 
    {
        Scanner scan = new Scanner(System.in);
        System.out.println("Strassen Multiplication Algorithm Test\n");
        /** Make an object of Strassen class **/
        Strassen s = new Strassen();
  
         
        int N = 4;
        /** Accept two 2d matrices **/
         
        int[][] A =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
 
        int[][] B =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
        System.out.println("\nArray A =>");
     
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(A[i][j] +" ");
            System.out.println();
        }
         
        System.out.println("\nArray B =>");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(B[i][j] +" ");
            System.out.println();
        }
  
        int[][] C = s.multiply(A, B);
  
        System.out.println("\nProduct of matrices A and  B : ");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(C[i][j] +" ");
            System.out.println();
        }
  
    }
}

输出
数组A =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


数组 B =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


结果数组=>
         8 8 8 8
        16 16 16 16
        24 24 24 24
        16 16 16 16 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn_aspnet

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值