动态规划-矩阵链乘法

本文介绍了如何使用动态规划解决矩阵链乘法问题,以找到计算代价最小的矩阵乘积顺序。通过刻画最优子结构特征、递归定义最优解以及自底向上的计算方法,展示了动态规划在降低矩阵乘法计算代价中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

问题

给定一个n个矩阵的序列(即矩阵链,n = < A1 , A2 , …, An >),并计算n个矩阵序列的乘积(S = A1 A2... An )。

由于矩阵的乘法是满足结合律的,所以可以通过任意的添加括号明确计算的顺序。不同位置的括号对矩阵乘法的计算代价会产生巨大影响,以矩阵链n = < A1 , A2 , A3 >为例,其中包含了两种不同的加括号方式:

  • ( A1 A2 ) A3
  • A1 ( A2 A3 )

在比较这两种方式的计算代价之前,首先看一下计算代价的定义,假设存在A和B两个矩阵,A = p X q,B = q X r,则矩阵AB的乘积执行了pqr次乘法计算(乘积矩阵上的每一个数需要计算q次,一共有pr个数需要计算),即AB的乘积的计算代价为pqr。

假设 A1 = 10 X 100, A2 = 100 X 5, A3 = 5 X 50,则第一种方式的计算代价为10 * 100 * 5 + 10 * 5 * 50 = 7500,第二种方式的计算代价为100 * 5 * 50 + 10 * 100 * 50 = 75000。两种计算方式的代价相差了10倍。

因此,在计算一个矩阵序列的乘积时,通常有一个最优的矩阵乘积的计算顺序使计算代价最低。矩阵链乘法就是通过动态规划的思想找到计算代价最小的矩阵乘积顺序。

暴力

如果不是通过动态规划的思想找出最优解,而是通过暴力解决问题,那么首先考虑的一个问题是,一个包含n个矩阵的矩阵序列包含多少个括号化方案?

详细证明略过(其实是不会。。),n个矩阵的矩阵序列的括号化方案符合卡塔兰数序列。卡塔兰数的计算公式为: hn = 1n+1 (2nn) = 1n+1 Cn2n = 2n!(n+1)!n! ,而n个矩阵的矩阵序列的括号化方案为h(n - 1)。

当n = 3时,h(2) = 2,即三个矩阵相乘时,具有两种括号化方案,与上述例子符合。

其中,卡塔兰数的增加趋势趋近于: hn ~ Ω ( 4n / n3/2 )。由于括号化方案的数量呈指数关系,因此暴力策略并不适合解决矩阵链乘的最优解问题。

动态规划

刻画最优子结构特征

假设 Ai..j (i <= j)表示 Ai Ai+1 Aj 矩阵序列的乘积,那么如果对该序列进行括号化,那么必须在k点(i <= k < j)将矩阵序列划分开,即 Ai..j 划分成 Ai..k Ak+1...j 。因此, Ai..j 的计算代价由 Ai..k 的计算代价, Ak+1..j 的计算代价,以及 Ai..k Ak+1..j 相乘的代价组成。

为了确保 Ai..j 得到最优解,那么就意味着 Ai..k Ak+1..j 也必须是最优解。

递归定义最优解

假设通过min_product[i, j]表示计算 Ai..j 矩阵链的最小乘积,分两种情况讨论:

  • 当i = j的时候,矩阵链中只含有一个矩阵,所需要的乘积运算次数为0,因此min_product[i, j] = 0。
  • 当i < j的时候,矩阵链中包含多个矩阵,需要定义一个k点将该矩阵链划分,假设将 Ai..j 矩阵链划分为 Ai..k Ak+1..j ,则 Ai..j 的计算代价为 Ai..k 的计算代价加上 Ak+1..j 的计算代价,并加上两者相乘的计算代价。
    • Ai..k 的计算代价为min_product[i, k]
    • Ak+1,j 的计算代价为min_product[k+1, j]
    • 假设 Ai = pi1 X pi Ak = pk1 X pk Aj = pj1 X pj ,则两者相乘的代价为 pi1 pk pj
    • 因此min_prodcut[i, j] = min_product[i, k] + min_product[k + 1, j] + pi1 pk pj 。(实际计算时i <= k < j,min_product[i, j]取其中的最小值)

计算最优值(自底向上)

实际计算的时候可以通过三层嵌套循环来实现:

  • 第一层循环以矩阵链长为变量
  • 第二层循环以min_product的行数为变量
  • 第三层循环以划分点k为变量

代码如下:

import sys


try:
    max_number = sys.maxsize
except AttributeError:
    max_number = sys.maxint


class MatrixChain(object):

    def __init__(self, matrix_chain):
        """
        以矩阵链[5, 10, 3, 12, 5, 50, 6]为例,该矩阵链包含了6个矩阵,分别为5 X 10, 10 X 3, 3 X 12, 12 X 5, 5 X 50, 50 X 6。因此可以看出,矩阵数 = 矩阵链的长度 - 1。为了使矩阵链的索引从1开始,在list开头插入None。
        """

        self.data = matrix_chain
        self.matrix_count = len(matrix_chain) - 1
        self.data.insert(0, None)

        # self.min_product存储计算代价,self.k存储最优划分点,分别默认初始化为0和None。
        # 由于range函数是左闭右开区间,因此self.matrix_count需要加1
        self.min_product = []
        self.k = []
        for _ in range(self.matrix_count + 1):
            self.min_product.append([0 for _ in range(self.matrix_count + 1)])
            self.k.append([None for _ in range(self.matrix_count + 1)])

        self.__matrix_chain_order()

        print("Input List is %s" % self.data)
        print("Min Product:")
        self.print_min_product()

        print("K: ")
        self.print_optimal_parens(1, self.matrix_count)
        print()

    def __matrix_chain_order(self):
        for chain_length in range(2, self.matrix_count + 1):
            for row in range(1, self.matrix_count - chain_length + 1 + 1):
                col = row + chain_length - 1
                self.min_product[row][col] = max_number
                for k in range(row, col):
                    result = self.min_product[row][k] + self.min_product[k + 1][col] + \
                             self.data[row] * self.data[k + 1] * self.data[col + 1]
                    if result < self.min_product[row][col]:
                        self.min_product[row][col] = result
                        self.k[row][col] = k

    def print_min_product(self):
        for row in range(1, self.matrix_count + 1):
            for col in range(1, self.matrix_count + 1):
                print(self.min_product[row][col], end="\t")
            print()

    def print_optimal_parens(self, start, end):
        if start == end:
            print("A%s" % start, end='')
        else:
            print("(", end="")
            self.print_optimal_parens(start, self.k[start][end])
            self.print_optimal_parens(self.k[start][end] + 1, end)
            print(")", end="")

如果清楚该算法的思想,那么实现的难点主要在于边界值的处理,为了便于处理,假设所有数组的索引都从1开始。

按照之前说的三层嵌套循环分别处理对应的变量:

  • 链长:在初始化的时候已经处理过链长为1的情况了,当矩阵链中只含有一个矩阵的时候计算代价为0,因此链长的范围为[2, 矩阵的个数]。
  • 行:在确定了链长之后,可以开始对每一行进行处理,每一行只需要处理一次,即只要求处理 Ai..i+chainlength1 即可,这里需要减1是因为链长比矩阵个数多1的缘故。随着链长的增加,一些行数包含的矩阵个数小于链长 - 1,则对这些行数不需要再进行计算。
  • 划分点k:按照公式求出最小值即可。

构造最优解

通过递归的方法,按照self.k中记录的最佳划分点进行括号化方案的构造。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值