矩阵乘法的Strassen算法

本文介绍了矩阵乘法的基本算法及其实现,并提出了一种基于分治的优化算法。进一步介绍了Strassen算法,通过减少递归调用次数降低时间复杂度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

A=(aij) B=(bij) nn 方阵,则对于i,j=1,2,…,n;定义乘积C = A * B中的元素 cij 为:

cij=k=1naikbkj

我们需要计算 n2 个矩阵元素,每个元素是n个值的和。下面过程接收n*n矩阵A和B,返回它们的乘积 n* n矩阵C.假设每个矩阵都有一个属性rows,给出矩阵的行数。
SQUARE-MATRIX-MULTIPLY(A, B)

n = A.rows
let C be a new n*n matrix
for i = 1 to n
    for j = 1 to n
        c(ij) = 0
        for k = 1 to n
            c(ij) = c(ij)+a(ik)*b(kj)
return C

C语言代码实现如下:

void Mul(int **matrixA,int **matrixB, int **matrixC, int n)
{
    for (int i=0; i < n; ++i){
        for (int j=0; j < n; ++j){
            matrixC[i][j] = 0;
            for (int k=0; k < n; ++k)
                matrixC[i][j] += matrixA[i][k] + matrixB[k][j];
        }
    }
}

由于三重for循环的每一重都恰好执行n步,而第7步每次执行都花费常量时间,因此过程SQUARE-MATRIX-MULTIPLY(A, B)花费 Θ(n3)

试着用一个简单的分治算法优化刚刚的矩阵算法, C = A * B,假定三个矩阵均为 n * n 矩阵,其中 n 为 2 的幂。我们做这个假定因为在每个分解步骤中, n * n 矩阵都被划分为 4 个 n/2 * n/2 的子矩阵,如果假定 n 是 2 的幂, 则 只要 n2 即可保证子矩阵规模 n/2 为整数。
假定将 A、B 和 C 均分解为 4 个 n/2 * n/2 的子矩阵:

A = [A11A21A12A22] , B = [B11B21B12B22] , C = [C11C21C12C22] (公式4.9)

因此可以将公式C = A * B 改写为

[A11A21A12A22] * [B11B21B12B22] = [C11C21C12C22]

等价于如下 4 个公式:

C11=A11B11+A12B21
C12=A11B12+A12B22
C21=A21B11+A22B21
C22=A21B12+A22B22

直接的递归分治算法:

SQUARE-MATRIX-MULTIPLY-RECURSIVE(A, B)

n = A.rows
let C be a new n * n matrix
if n == 1
    c11 = a11 * b11
else partion A, B, and C as in equations (4, 9)
    C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B11)
    + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B21)
    C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B12)
    + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B22)
    C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B11)
    + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B21)
    C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B12)
    + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B22)
return C

由于每次递归调用完成两个 n/2 * n/2 矩阵的乘法,因此花费时间为T(n/2), 8 次递归调用总时间为8T(n/2). 我们还需要计算第6-9行的4次矩阵加法。每个矩阵包含 n2/4 个元素,因此,每次矩阵加法花费 Θ(n2) 时间其他的时间开销为 Θ(1) 故矩阵加法时间之和:
T(n)=Θ(1)+8T(n/2)+Θ(n2)=8T(n/2)+Θ(n2)

T(n)={Θ(1)8T(n/2)+Θ(n2)==n=1n>1

利用主方法求解上式得:
T(n)=Θ(n3)
( T(n)=aT(n/b)+f(n)Θ(nlogba)>f(n),Θ(nlog28)=Θ(n3)>Θ(n2) )
可见用了一般的分治算法并没有使算法的 T(n) 减少

所以才有了天才般的Strassen方法,Strassen算法的核心思想是令递归树稍微不那么茂盛一点儿,即只进行7次而不是8次递n/2 * n/2的矩阵乘法。它包含4个步骤:
1. 按公式(4.9)将输入矩阵A、B和输出矩阵C分解为n/2 * n/2 的子矩阵。采用下标计算方法,此步骤花费 Θ(1) 时间,与SQUARE-MATRIX-MULTIPLY-RECURSIVE相同
2. 创建10个n/2 * n/2的矩阵 S1,S2,...,S10 每个矩阵保存步骤 1 中创建的两个矩阵的和或差。花费时间为 Θ(n2)
3. 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归的计算7个矩阵积 P1,P2,..P7 。每个矩阵 Pi 都是n/2 * n/2 的。
4. 通过 Pi 矩阵的不同组合进行加减运算,计算出结果矩阵C的子矩阵 C11,C12,C21,C22 .花费时间 Θ(n2) .

故我们得到Strassen算法的运行时间递归式:

T(n)={Θ(1)7T(n/2)+Θ(n2)==n=1n>1

利用主方法可以求解出递归式的解为 T(n)=Θ(nlg7) .
创建的十个矩阵:
S1=B12B22
S2=A11+A12
S3=A21+A22
S4=B21B11
S5=A11+A22
S6=B11+B22
S7=A12A22
S8=B21+B22
S9=A11A21
S10=B11+B12
由于必须进行10次n/2 * n/2的矩阵加减法,因此,该步骤花费 Θ(n2) 时间。
如下所示:
P1=A11S1=A11B12A11B22
P2=B22S2=A11B22+A12B22
P3=B11S3=A21B11+A22B11
P4=A22S4=A22B21A22B11
P5=S5S6=A11B12+A11B22+A22B11+A22B22
P6=S7S8=A12B21+A12B22A22B21A22B22
P7=S9S10=A11B11+A11B12A21B11A21B12
首先
C11=P5+P4P2+P6 (自己可以验证)
C12=P1+P2
C21=P3+P4
C22=P5+P1P3P7

总结:矩阵乘法一般意义上还是选择:朴素的方法(暴力解法),只有当矩阵阶数很大(稠密)时,才会选择Strassen算法.
详细的比较看这里:http://www.mamicode.com/info-detail-673908.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值