Implementation of Strassen’s Algorithm for Matrix Multiplication

本文介绍Strassen矩阵乘法算法及其优化实现。通过设置适当的基例大小,该算法相较于传统方法,在处理大规模矩阵时展现出更好的性能。文章还提供了一个C++实现示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Strassen’s algorithm is not the most efficient algorithm for matrix multiplication, but it was the first algorithm that was theoretically faster than the naive algorithm. There is a very good explanation and implementation of Strassen’s algorithm on Wikipedia.

However, that implementation of Strassen’s algorithm cannot be used directly, because it sets the base case of the divide-and-conquer recursion to a 1×1 matrix, which incurs a huge time cost from the recursion itself. If the base case is set to a 2×2 matrix, meaning that 2×2 and 1×1 matrices are multiplied by the naive algorithm, then Strassen’s algorithm becomes more efficient for matrices larger than 512×512.

When the base case is set to a 2×2 matrix, Strassen’s algorithm surpasses the naive algorithm for matrices larger than 512×512.

When the base case is set to a 6×6 matrix, Strassen’s algorithm surpasses the naive algorithm for matrices larger than 128×128.

/*------------------------------------------------------------------------------*/
// 	matrix_mult.cc -- Implementation of matrix multiplication with
// 			  Strassen's algorithm. 
//
// Compile this file with gcc command:
//     g++ -Wall -o matrix_mult matrix_mult.cc                                                
 
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <ctype.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <cmath>
#include <cstring>

using namespace std;

// Allocate an n x n matrix of doubles.
// The matrix is an array of n row pointers; every row is its own array of
// n doubles, with all elements set to zero.
// Returns the array of row pointers. Release it with free_matrix(mat, n).
inline double** allocate_matrix(int n) 
{
	double** rows = new double*[n];

	for (int r = 0; r < n; ++r)
	{
		// The trailing () value-initialises the row, i.e. fills it with 0.0.
		rows[r] = new double[n]();
	}

	return rows;
}

/*------------------------------------------------------------------------------*/
// Release a matrix that was built as an array of n separately allocated rows
// (the layout produced by allocate_matrix).
//   M : array of row pointers; a null pointer is accepted and ignored
//   n : number of rows in M
// M is received by value, so the caller's own pointer is not modified and
// must not be used again after this call.
inline void free_matrix(double **M, int n)
{
    if (!M)
        return;             // nothing to release; never index a null pointer

    for (int i = 0; i < n; i++) 
    { 
       delete [] M[i];      // every row is its own array allocation
    } 

    delete [] M;            // finally the array of row pointers itself
}

/*------------------------------------------------------------------------------*/
// Element-wise addition of two tam x tam matrices: result = a + b.
inline void sum(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *lhs = a[row];
        const double *rhs = b[row];
        double *out = result[row];

        for (int col = 0; col < tam; ++col)
            out[col] = lhs[col] + rhs[col];
    }
}
 
/*------------------------------------------------------------------------------*/
// Element-wise subtraction of two tam x tam matrices: result = a - b.
inline void subtract(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *minuend = a[row];
        const double *subtrahend = b[row];
        double *out = result[row];

        for (int col = 0; col < tam; ++col)
            out[col] = minuend[col] - subtrahend[col];
    }
}

/*------------------------------------------------------------------------------*/
// Naive O(n^3) matrix product for n x n matrices, accumulating: C += A * B.
//   A, B : input matrices (read only)
//   C    : accumulator. It is added to, not overwritten, so it must be
//          zero-filled beforehand to obtain C = A * B. It must not share
//          storage with A or B.
//   n    : matrix dimension
//
// The loops run in i-k-j order so that the innermost loop walks one row of B
// and one row of C contiguously, instead of stepping down a column of B
// through n separately allocated rows. Each C[i][j] still receives its terms
// in increasing k, so the result is the same as with the i-j-k order.
void naive(double** A, double** B,double** C, int n)
{
	for (int i = 0; i < n; i++)
	{
		const double* a_row = A[i];
		double* c_row = C[i];

		for (int k = 0; k < n; k++)
		{
			const double a_ik = a_row[k];	// invariant over the j loop
			const double* b_row = B[k];

			for (int j = 0; j < n; j++)
				c_row[j] += a_ik * b_row[j];
		}
	}
}

/*------------------------------------------------------------------------------*/
// Strassen's method: c = a * b for tam x tam matrices.
//   a, b : input matrices (read only)
//   c    : output matrix; its previous contents are overwritten. It must not
//          share storage with a or b.
//   tam  : matrix dimension
//
// The matrices are split into four (tam/2) x (tam/2) quadrants and the
// product is assembled from seven recursive half-size products instead of
// eight. Small matrices (tam <= 4) are multiplied with the classical
// algorithm. A matrix with an odd dimension cannot be split into equal
// quadrants, so it is multiplied with the classical algorithm as well; the
// recursion therefore only pays off fully for sizes of the form m * 2^k.
void strassen(double **a, double **b, double **c, int tam) 
{ 
    // Base case: recursing all the way down to 1 x 1 costs far more than it
    // saves, so matrices of size 4 x 4 or smaller use the naive method.
    // Odd sizes are handled here too, because tam/2 would silently drop the
    // last row and column.
    if (tam <= 4 || tam % 2 != 0)
    {
        // naive() accumulates into its output, so clear c first to make the
        // result independent of whatever the caller left in it.
        for (int i = 0; i < tam; i++)
            for (int j = 0; j < tam; j++)
                c[i][j] = 0.0;

        naive(a, b, c, tam);
        return;
    }

    const int newTam = tam / 2;

    // Quadrants of a and b, copied out so they can be passed on as
    // stand-alone matrices.
    double **a11 = allocate_matrix(newTam);
    double **a12 = allocate_matrix(newTam);
    double **a21 = allocate_matrix(newTam);
    double **a22 = allocate_matrix(newTam);

    double **b11 = allocate_matrix(newTam);
    double **b12 = allocate_matrix(newTam);
    double **b21 = allocate_matrix(newTam);
    double **b22 = allocate_matrix(newTam);

    // The seven Strassen products.
    double **p1 = allocate_matrix(newTam);
    double **p2 = allocate_matrix(newTam);
    double **p3 = allocate_matrix(newTam);
    double **p4 = allocate_matrix(newTam);
    double **p5 = allocate_matrix(newTam);
    double **p6 = allocate_matrix(newTam);
    double **p7 = allocate_matrix(newTam);

    // Scratch space for the quadrant sums/differences fed to the products.
    double **aResult = allocate_matrix(newTam);
    double **bResult = allocate_matrix(newTam);

    // Dividing the matrices in 4 sub-matrices:
    for (int i = 0; i < newTam; i++) {
        for (int j = 0; j < newTam; j++) {
            a11[i][j] = a[i][j];
            a12[i][j] = a[i][j + newTam];
            a21[i][j] = a[i + newTam][j];
            a22[i][j] = a[i + newTam][j + newTam];

            b11[i][j] = b[i][j];
            b12[i][j] = b[i][j + newTam];
            b21[i][j] = b[i + newTam][j];
            b22[i][j] = b[i + newTam][j + newTam];
        }
    }

    // Calculating p1 to p7:

    sum(a11, a22, aResult, newTam);             // a11 + a22
    sum(b11, b22, bResult, newTam);             // b11 + b22
    strassen(aResult, bResult, p1, newTam);     // p1 = (a11+a22) * (b11+b22)

    sum(a21, a22, aResult, newTam);             // a21 + a22
    strassen(aResult, b11, p2, newTam);         // p2 = (a21+a22) * (b11)

    subtract(b12, b22, bResult, newTam);        // b12 - b22
    strassen(a11, bResult, p3, newTam);         // p3 = (a11) * (b12 - b22)

    subtract(b21, b11, bResult, newTam);        // b21 - b11
    strassen(a22, bResult, p4, newTam);         // p4 = (a22) * (b21 - b11)

    sum(a11, a12, aResult, newTam);             // a11 + a12
    strassen(aResult, b22, p5, newTam);         // p5 = (a11+a12) * (b22)

    subtract(a21, a11, aResult, newTam);        // a21 - a11
    sum(b11, b12, bResult, newTam);             // b11 + b12
    strassen(aResult, bResult, p6, newTam);     // p6 = (a21-a11) * (b11+b12)

    subtract(a12, a22, aResult, newTam);        // a12 - a22
    sum(b21, b22, bResult, newTam);             // b21 + b22
    strassen(aResult, bResult, p7, newTam);     // p7 = (a12-a22) * (b21+b22)

    // Combine the products straight into the four quadrants of c:
    //   c11 = p1 + p4 + p7 - p5      c12 = p3 + p5
    //   c21 = p2 + p4                c22 = p1 + p3 + p6 - p2
    for (int i = 0; i < newTam; i++) {
        for (int j = 0; j < newTam; j++) {
            c[i][j]                   = p1[i][j] + p4[i][j] + p7[i][j] - p5[i][j];
            c[i][j + newTam]          = p3[i][j] + p5[i][j];
            c[i + newTam][j]          = p2[i][j] + p4[i][j];
            c[i + newTam][j + newTam] = p1[i][j] + p3[i][j] + p6[i][j] - p2[i][j];
        }
    }

    // Deallocating memory (free):
    free_matrix(a11, newTam);
    free_matrix(a12, newTam);
    free_matrix(a21, newTam);
    free_matrix(a22, newTam);

    free_matrix(b11, newTam);
    free_matrix(b12, newTam);
    free_matrix(b21, newTam);
    free_matrix(b22, newTam);

    free_matrix(p1, newTam);
    free_matrix(p2, newTam);
    free_matrix(p3, newTam);
    free_matrix(p4, newTam);
    free_matrix(p5, newTam);
    free_matrix(p6, newTam);
    free_matrix(p7, newTam);

    free_matrix(aResult, newTam);
    free_matrix(bResult, newTam);

} // end of Strassen function

/*------------------------------------------------------------------------------*/
// Fill the n x n matrix M with pseudo-random whole numbers in [0, 99].
// One rand() value is drawn per element, row by row from left to right.
void gen_matrix(double** M,int n)
{
	for (int r = 0; r < n; ++r)
	{
		double* row = M[r];

		for (int col = 0; col < n; ++col)
			row[col] = rand() % 100;
	}
}

/*------------------------------------------------------------------------------*/
// Print the n x n matrix M to the given fstream as text.
// Every element is followed by a single space, every row ends with a
// newline, and one extra blank line is written after the matrix.
// Plain '\n' is used rather than endl so the stream is not flushed after
// every row; the bytes written are the same.
void print_matrix(fstream& fs, double** M, int n)
{
	for(int i=0;i<n;++i)
	{
		for(int j=0;j<n;++j)
		{
			fs<<M[i][j]<<' ';
		}
		fs<<'\n';
	}
	fs<<'\n';
}

/*------------------------------------------------------------------------------*/
// Record the generated matrices and the final product in a text file.
//   A, B : the two n x n input matrices
//   C    : the n x n product
//   n    : matrix dimension
//   file : path of the file to create or overwrite
// If no path is given or the file cannot be opened, a message is printed on
// stderr and nothing is written.
void mat_mult_log(double** A, double** B,double** C,int n,char* file)
{
	if(file==NULL)
	{
		fprintf(stderr, "No output file given.\n");
		return;
	}

	fstream fs;
	fs.open(file,fstream::out);
	if(!fs.is_open())
	{
		// Without this check every write below would fail silently.
		fprintf(stderr, "Cannot open output file `%s'.\n", file);
		return;
	}

	fs<<"Random Matrix A:"<<'\n';
	print_matrix(fs,A,n);
	fs<<"Random Matrix B:"<<'\n';
	print_matrix(fs,B,n);
	fs<<"C=A * B"<<'\n';
	print_matrix(fs,C,n);

	fs.close();
}

/*------------------------------------------------------------------------------*/

// Program entry point: multiply two random square matrices.
//
// Command line options:
//   -s         use Strassen's algorithm (default: the naive algorithm)
//   -n <exp>   matrix dimension is 2^exp, 0 <= exp <= 30 (default dimension: 2)
//   -o <file>  write A, B and the product C to <file>
//
// Returns 0 on success and 1 when the command line is invalid.
int main(int argc, char** argv)
{
	srand(time(NULL));

	int mdim=2;	// matrix dimension
	char* output=NULL;
	bool is_strassen=false;
	int c;

	while ((c = getopt (argc, argv, "sn:o:")) != -1)
	{
		switch (c)
		{
		case 's':
			is_strassen=true;
			break;
		case 'n':
		{
			// Parse the exponent strictly: the whole argument must be a
			// number, and it is capped at 30 so that 1 << exponent fits
			// in an int.
			char* end = NULL;
			const long exponent = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' || exponent < 0 || exponent > 30)
			{
				fprintf (stderr,
					"Option -n expects an integer exponent between 0 and 30.\n");
				return 1;
			}
			mdim = 1 << static_cast<int>(exponent); // 2^n dimensions
			break;
		}
		case 'o':
			output = optarg; // file that receives A, B and C
			break;
		case '?':
			// Both -n and -o take an argument.
			if (optopt == 'n' || optopt == 'o')
				fprintf (stderr, "Option -%c requires an argument.\n", optopt);
			else if (isprint (optopt))
				fprintf (stderr, "Unknown option `-%c'.\n", optopt);
			else
				fprintf (stderr,
					"Unknown option character `\\x%x'.\n",
					optopt);
			return 1;
		default:
			abort ();
		}
	}

	// create new matrices
	double** A=allocate_matrix(mdim);
	double** B=allocate_matrix(mdim);
	double** C=allocate_matrix(mdim);	// zero-filled, as naive() requires
	gen_matrix(A,mdim);
	gen_matrix(B,mdim);

	// matrices multiplication
	if(is_strassen)
		strassen(A,B,C,mdim);
	else
		naive(A,B,C,mdim);

	if(output!=NULL)
		mat_mult_log(A,B,C,mdim,output);

	free_matrix(A,mdim);
	free_matrix(B,mdim);
	free_matrix(C,mdim);

	return 0;
}


#include <stdio.h> #include <stdlib.h> // 矩阵动态内存分配 int** allocate_matrix(int size) { int** matrix = (int**)malloc(size * sizeof(int*)); for (int i = 0; i < size; i++) { matrix[i] = (int*)malloc(size * sizeof(int)); } return matrix; } // 矩阵内存释放 void free_matrix(int** matrix, int size) { for (int i = 0; i < size; i++) { free(matrix[i]); } free(matrix); } // 标准矩阵乘法(用于小规模矩阵) void standard_mult(int** A, int** B, int** C, int n) { for (int i = 0; i < n; i++) { for (int k = 0; k < n; k++) { C[i][k] = 0; for (int j = 0; j < n; j++) { C[i][k] += A[i][j] * B[j][k]; } } } } // 矩阵加法 C = A + B void matrix_add(int** A, int** B, int** C, int size) { for (int i = 0; i < size; i++) { for (int j = 0; j < size; j++) { C[i][j] = A[i][j] + B[i][j]; } } } // 矩阵减法 C = A - B void matrix_sub(int** A, int** B, int** C, int size) { for (int i = 0; i < size; i++) { for (int j = 0; j < size; j++) { C[i][j] = A[i][j] - B[i][j]; } } } // 合并子矩阵到结果矩阵 void merge_matrix(int** C11, int** C12, int** C21, int** C22, int** C, int half) { for (int i = 0; i < half; i++) { for (int j = 0; j < half; j++) { C[i][j] = C11[i][j]; C[i][j + half] = C12[i][j]; C[i + half][j] = C21[i][j]; C[i + half][j + half] = C22[i][j]; } } } // Strassen核心算法 void strassen(int** A, int** B, int** C, int n) { if (n <= 64) { // 优化阈值参考[^3] standard_mult(A, B, C, n); return; } int half = n / 2; // 分配子矩阵内存 int** A11 = allocate_matrix(half); int** A12 = allocate_matrix(half); int** A21 = allocate_matrix(half); int** A22 = allocate_matrix(half); int** B11 = allocate_matrix(half); int** B12 = allocate_matrix(half); int** B21 = allocate_matrix(half); int** B22 = allocate_matrix(half); // 分割矩阵 for (int i = 0; i < half; i++) { for (int j = 0; j < half; j++) { A11[i][j] = A[i][j]; A12[i][j] = A[i][j + half]; A21[i][j] = A[i + half][j]; A22[i][j] = A[i + half][j + half]; B11[i][j] = B[i][j]; B12[i][j] = B[i][j + half]; B21[i][j] = B[i + half][j]; B22[i][j] = B[i + half][j + half]; } } // 分配临时矩阵内存 int** S1 = allocate_matrix(half); int** S2 = 
allocate_matrix(half); int** S3 = allocate_matrix(half); int** S4 = allocate_matrix(half); int** S5 = allocate_matrix(half); int** S6 = allocate_matrix(half); int** S7 = allocate_matrix(half); int** S8 = allocate_matrix(half); int** S9 = allocate_matrix(half); int** S10 = allocate_matrix(half); // 计算中间矩阵(参考[^2]) matrix_sub(B12, B22, S1, half); // S1 = B12 - B22 matrix_add(A11, A12, S2, half); // S2 = A11 + A12 matrix_add(A21, A22, S3, half); // S3 = A21 + A22 matrix_sub(B21, B11, S4, half); // S4 = B21 - B11 matrix_add(A11, A22, S5, half); // S5 = A11 + A22 matrix_add(B11, B22, S6, half); // S6 = B11 + B22 matrix_sub(A12, A22, S7, half); // S7 = A12 - A22 matrix_add(B21, B22, S8, half); // S8 = B21 + B22 matrix_sub(A11, A21, S9, half); // S9 = A11 - A21 matrix_add(B11, B12, S10, half); // S10 = B11 + B12 // 递归计算7个乘积矩阵 int** P1 = allocate_matrix(half); int** P2 = allocate_matrix(half); int** P3 = allocate_matrix(half); int** P4 = allocate_matrix(half); int** P5 = allocate_matrix(half); int** P6 = allocate_matrix(half); int** P7 = allocate_matrix(half); strassen(A11, S1, P1, half); // P1 = A11 * S1 strassen(S2, B22, P2, half); // P2 = S2 * B22 strassen(S3, B11, P3, half); // P3 = S3 * B11 strassen(A22, S4, P4, half); // P4 = A22 * S4 strassen(S5, S6, P5, half); // P5 = S5 * S6 strassen(S7, S8, P6, half); // P6 = S7 * S8 strassen(S9, S10, P7, half); // P7 = S9 * S10 // 计算结果子矩阵 int** C11 = allocate_matrix(half); int** C12 = allocate_matrix(half); int** C21 = allocate_matrix(half); int** C22 = allocate_matrix(half); matrix_add(P5, P4, C11, half); // C11 = P5 + P4 - P2 + P6 matrix_sub(C11, P2, C11, half); matrix_add(C11, P6, C11, half); matrix_add(P1, P2, C12, half); // C12 = P1 + P2 matrix_add(P3, P4, C21, half); // C21 = P3 + P4 matrix_add(P5, P1, C22, half); // C22 = P5 + P1 - P3 - P7 matrix_sub(C22, P3, C22, half); matrix_sub(C22, P7, C22, half); // 合并结果 merge_matrix(C11, C12, C21, C22, C, half); // 释放所有临时矩阵内存 free_matrix(A11, half); free_matrix(A12, half); // ... 
释放所有分配的内存(此处省略实际代码) } int main() { int n = 8; // 必须是2的幂 int** A = allocate_matrix(n); int** B = allocate_matrix(n); int** C = allocate_matrix(n); // 初始化测试矩阵(示例数据) for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { A[i][j] = (i == j) ? 1 : 0; // 单位矩阵 B[i][j] = i * n + j; // 递增序列 } } strassen(A, B, C, n); // 输出结果(示例) printf("Result matrix:\n"); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { printf("%6d ", C[i][j]); } printf("\n"); } // 释放内存 free_matrix(A, n); free_matrix(B, n); free_matrix(C, n); return 0; }根据上述代码,写出实验内容: 主要讲所如何实现的 复杂性分析: 主要两个复杂性分析 实验结果: 参考文献: 7-10个
最新发布
05-12
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值