strassen算法之python

最近在看《算法导论》第四章,发现矩阵乘法很有意思,书上面讲述了暴力法,直接递归分治法以及strassen算法,有兴趣的同学看一看一下这本书。前面两种方法理解起来比较简单,下面我讲述一下自己对strassen算法理解。
首先,明确一下strassen算法的使用是有限制的,它针对的是n*n矩阵,且n为2的幂数。当所求的矩阵不满足要求时,我们可以通过补充0元素来构建满足该算法要求的矩阵。strassen算法的优化在于将需要8次递归调用减为7次递归调用,即原本需要8次分割之后子矩阵之间相乘,strassen只需要7次。原因是最初矩阵分割的子矩阵通过加减得到新子矩阵,这些新子矩阵之间进行7次相乘,再进行矩阵之间的加减就可以构建出最终矩阵的子矩阵。所以在计算较大矩阵相乘时计算时间会比前两种短,前两种的运行时间与n^3
成正比,而strassen算法运行时间与n^log2(7),log2(7)约为2.81。
python代码如下:
#strassen方法限定了矩阵的维数n x n,只能为2次幂,由原来要计算8次缩减为7次.

def matrix_add(matrix_a, matrix_b):#将两个矩阵进行相加


    rows = len(matrix_a)#得到其中一个矩阵的行数
    columns = len(matrix_a[0])#得到矩阵的列数
    matrix_c = [list() for i in range(rows)]#生成行数个空列表
    for i in range(rows):
        for j in range(columns):
             matrix_c_temp = matrix_a[i][j] + matrix_b[i][j]
             matrix_c[i].append(matrix_c_temp)#将两个数组对应位置的元
### 回答1: Strassen矩阵乘法是一种高效的矩阵乘法算法,它可以在较短的时间内计算出两个矩阵的乘积。在Python中,可以使用递归的方式实现Strassen矩阵乘法算法。以下是一个示例代码: ```python import numpy as np def strassen(A, B): n = len(A) if n == 1: return A * B else: # 将矩阵A和B分成四个子矩阵 A11, A12, A21, A22 = A[:n//2, :n//2], A[:n//2, n//2:], A[n//2:, :n//2], A[n//2:, n//2:] B11, B12, B21, B22 = B[:n//2, :n//2], B[:n//2, n//2:], B[n//2:, :n//2], B[n//2:, n//2:] # 计算七个子矩阵P1-P7 P1 = strassen(A11 + A22, B11 + B22) P2 = strassen(A21 + A22, B11) P3 = strassen(A11, B12 - B22) P4 = strassen(A22, B21 - B11) P5 = strassen(A11 + A12, B22) P6 = strassen(A21 - A11, B11 + B12) P7 = strassen(A12 - A22, B21 + B22) # 计算结果矩阵C的四个子矩阵 C11 = P1 + P4 - P5 + P7 C12 = P3 + P5 C21 = P2 + P4 C22 = P1 - P2 + P3 + P6 # 将四个子矩阵合并成结果矩阵C C = np.zeros((n, n)) C[:n//2, :n//2], C[:n//2, n//2:], C[n//2:, :n//2], C[n//2:, n//2:] = C11, C12, C21, C22 return C ``` 该函数接受两个矩阵A和B作为输入,并返回它们的乘积。在函数内部,首先检查矩阵的大小是否为1,如果是,则直接返回它们的乘积。否则,将矩阵A和B分成四个子矩阵,并递归地计算七个子矩阵P1-P7。然后,将四个子矩阵合并成结果矩阵C,并返回它。 ### 回答2: Strassen矩阵乘法法是一种用于矩阵乘法计算的分治算法,它采用递归和矩阵分解的方法将两个大矩阵分解成四个子矩阵,以较小的子矩阵计算矩阵乘积,最后再将结果组合成一个大的矩阵。 Python中可以通过递归的方式实现Strassen矩阵乘法,步骤如下: 1. 定义一个函数,接收两个矩阵A和B作为参数。 2. 检查矩阵的大小是否符合要求,如果不符合则进行矩阵补零。 3. 根据Strassen算法,将矩阵A和B分解成四个子矩阵,称为A11、A12、A21、A22和B11、B12、B21、B22。 4. 用递归的方式计算P1、P2、P3、P4、P5、P6、P7,其中: - P1 = (A11 + A22)(B11 + B22) - P2 = (A21 + A22)B11 - P3 = A11(B12 - B22) - P4 = A22(B21 - B11) - P5 = (A11 + A12)B22 - P6 = (A21 - A11)(B11 + B12) - P7 = (A12 - A22)(B21 + B22) 这种计算方法避免了逐个计算矩阵元素的低效率。 5. 根据P1至P7的值计算矩阵C11、C12、C21、C22。 6. 根据C11、C12、C21、C22将矩阵C组合成一个大的矩阵。 这样就完成了矩阵乘法的计算。需要注意的是,Strassen算法对于矩阵大小的要求比较特殊,要求矩阵大小为2的幂次方。因此,在程序中需要对矩阵进行补零或者截取而使其满足大小要求。 以下是一个简单的Strassen矩阵乘法的Python实现: ```python def strassen_matrix_mul(A, B): size = len(A) if size == 1: return [[A[0][0]*B[0][0]]] # Padding A and B to make their sizes power of 2 while size % 2 != 0: A.append([0] * size) B.append([0] * size) size += 1 for i in range(size): A[i].append(0) B[i].append(0) mid = size // 2 # Partition matrices into submatrices A11 = [A[i][0:mid] for i in range(0,mid)] A12 = [A[i][mid:size] for i in range(0,mid)] A21 = [A[i][0:mid] for i in range(mid:size)] A22 = [A[i][mid:size] for i in range(mid:size)] B11 = [B[i][0:mid] for i in range(0,mid)] B12 = [B[i][mid:size] for i in range(0,mid)] B21 = [B[i][0:mid] for i in range(mid:size)] B22 = [B[i][mid:size] for i in range(mid:size)] # Calculate P1 to P7 P1 = strassen_matrix_mul(add(A11, A22), add(B11, B22)) P2 = strassen_matrix_mul(add(A21, A22), B11) P3 = strassen_matrix_mul(A11, subtract(B12, B22)) P4 = strassen_matrix_mul(A22, subtract(B21, B11)) P5 = strassen_matrix_mul(add(A11, A12), B22) P6 = strassen_matrix_mul(subtract(A21, A11), add(B11, B12)) P7 = strassen_matrix_mul(subtract(A12, A22), add(B21, B22)) # Calculate submatrices of C C11 = subtract(add(add(P1, P4), P7), P5) C12 = add(P3, P5) C21 = add(P2, P4) C22 = subtract(add(add(P1, P3), P6), P2) # Combine submatrices of C into a single matrix C = [] for i in range(0, mid): row = C11[i] + C12[i] C.append(row) for i in range(0, mid): row = C21[i] + C22[i] C.append(row) return C def add(A, B): return [[A[i][j] + B[i][j] for j in range(0,len(A))] for i in range(0,len(A))] def subtract(A, B): return [[A[i][j] - B[i][j] for j in range(0,len(A))] for i in range(0,len(A))] ``` 对于输入的矩阵A和B,可以通过strassen_matrix_mul函数计算它们的乘积,并返回结果矩阵C。其中,add和subtract函数是辅助函数,用于对矩阵进行加法和减法计算。 在实际运用中,Strassen算法的效率很高,但是在一些情况下,它并不是最优解,因此需要结合具体的应用场景进行选择。 ### 回答3: Strassen矩阵乘法是一种基于分治策略的矩阵乘法算法,在某些情况下可以比普通的矩阵乘法算法更快地计算矩阵乘积。Python是一种动态类型、面向对象、解释性的高级编程语言,因其易用性和丰富的库文件而受到广泛关注。 在Python中实现Strassen矩阵乘法,首先需要将矩阵分解为更小的子矩阵。然后,通过逐层分治的方式,将每个子矩阵乘以自己的转置矩阵,再将结果组合起来,得到原始矩阵的乘积。 下面是一个简单的Python代码实现: ```python def strassen_multiply(a, b): n = len(a) if n == 1: return [[a[0][0] * b[0][0]]] else: # divide matrices into submatrices a11, a12, a21, a22 = split_matrix(a) b11, b12, b21, b22 = split_matrix(b) # compute products of submatrices m1 = strassen_multiply(add_matrices(a11, a22), add_matrices(b11, b22)) m2 = strassen_multiply(add_matrices(a21, a22), b11) m3 = strassen_multiply(a11, subtract_matrices(b12, b22)) m4 = strassen_multiply(a22, subtract_matrices(b21, b11)) m5 = strassen_multiply(add_matrices(a11, a12), b22) m6 = strassen_multiply(subtract_matrices(a21, a11), add_matrices(b11, b12)) m7 = strassen_multiply(subtract_matrices(a12, a22), add_matrices(b21, b22)) # combine submatrices to construct result c11 = add_matrices(subtract_matrices(add_matrices(m1, m4), m5), m7) c12 = add_matrices(m3, m5) c21 = add_matrices(m2, m4) c22 = add_matrices(subtract_matrices(add_matrices(m1, m3), m2), m6) # construct result matrix from submatrices return merge_matrices(c11, c12, c21, c22) ``` 在此Python代码中,函数`strassen_multiply`接受两个矩阵`a`和`b`作为参数,并返回它们的乘积。首先,如果矩阵是大小为1的矩阵,则直接返回其乘积。否则,我们将矩阵分解为四个子矩阵,对每个子矩阵进行递归调用,并进行一系列矩阵操作来计算结果矩阵。最后,将子矩阵合并为结果矩阵。 总体来说,Strassen矩阵乘法能够在一定程度上优化矩阵乘积的计算时间。但是,由于其需要递归地对矩阵进行分解和重组,因此在某些情况下,普通的矩阵乘法算法Strassen算法更有效率。因此,在实际使用中,我们应该根据具体情况选择合适的矩阵乘法算法以获得最优的性能。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值