Day2:矩阵链乘法
一. 问题背景:
给定一个n个矩阵的序列(矩阵链),我们希望计算它们的乘积
为了计算表达式,我们可以先用括号明确计算次序,然后利用标准的矩阵相乘算法进行计算。由于矩阵乘法满足结合律,因此任何加括号的方法都会得到相同的计算结果。我们称有如下性质的矩阵乘积链为完全括号话:它是单一矩阵,或者是两个完全括号化的矩阵乘积链的积,且已外加括号。例如,如果矩阵链为,则共有 5 种完全括号化的矩阵乘积链:
如果是
的矩阵,
是
的矩阵那么乘积
是
的矩阵。计算
的代价为
,我们使用标量乘法的次数来表示计算代价。以矩阵链
为例,来说明不同的加括号方式会导致不同的计算代价。假设三个矩阵的大小分别为10*100,100*5和5*50。如果按照
的顺序计算,需要做10*100*5+10*5*50=7500次标量乘法。如果按照
的顺序计算,共需100*5*50+10*100*50=75000次标量乘法。因此第一种顺序计算要比第二种顺序计算快10倍。
矩阵链乘法问题:给定n个矩阵的链,矩阵
的规模为
,求完全括号化方案,使得计算乘积
所需的标量乘法次数最少。
二. 解决思路:
1. 最优解结构
为了方便起见,我们用符号表示
乘积的结果矩阵。假设
的最优括号化方案的分割点在
和
之间。那么,继续对“前缀”子链
,“后缀”
进行括号化 (独立求解)。
2. 递归定义
令表示计算矩阵所需标量乘法次数的最小值,原问题的最优解—–计算
所需的最低代价就是
。我们假设
的最优括号化方案的分割点在矩阵
和
之间,其中
。那么,
就等于计算
和
的代价加上两者相乘的代价的最小值:
此递归公式假定最优分割点k是已知的,但实际上我们是不知道的。不过,k只有种可能的取值,即
。由于最优分割点必在其中,我们只需检查所有可能情况,找到最优者即可。
因此,的最优括号化方案的递归求解公式为:
的值给出了子问题最优解的代价,但它并未提供足够的信息来构造最优解。为此,我们用
保存最优括号化方案的分割点位置k,即使得
成立的k值。
3. 计算最优代价
现在,我们可以很容易地基于递归公式写出一个递归算法,但递归算法是指数时间的,并不必检查若有括号化方案的暴力搜索方法更好。注意到,我们需要求解的不同子问题的数目是相对较少的。递归算法会在递归调用树的不同分支中多次遇到同一个子问题。这种子问题重叠的性质是应用动态规划的另一标识(第一个标识是最优子结构)。
①采用自底向上的方法代替递归算法来计算最优代价(参考“钢管切割”问题)。此过程假定矩阵 的规模为
。它的输入是一个序列
,其长度为n+1。用
来保存代价
,用
记录最优值
对应的分割点k。这样就可以利用表s构造最优解。
对于矩阵链最优括号化的子问题,我们认为其规模为链的长度
。因为
个矩阵链相乘的最优计算代价
只依赖于少于
个矩阵链相乘的最优计算代价。因此,算法应该按长度递增的顺序求解矩阵链括号化问题,并按对应的顺序填写表m。
②采用备忘录的方法代替递归算法来计算最优代价。
备忘录方法也用一个表格来保存已解决的子问题的答案。在下次需要解决此问题时,只要简单地查看该子问题的解答,而不必重新计算。但与动态规划不同:备忘录方法的递归方式是自顶向下的,而动态规划算法则是自底向上递归的。备忘录方法的控制结构与直接递归方法的控制结构相同,区别在于备忘录方法为每个解过的子问题建立了备忘录以备需要时查看,避免了相同子问题的重复求解。
注意:备忘录方法为每个子问题建立了一个记录项,初始化时,该记录项存入一个特殊的值,表示该子问题尚未求解。
在求解过程中,对每个待求的子问题,首先查看相应的记录项。若记录项中存储的是初始化时存入的特殊值,则表示该子问题是第一次遇到,则此时计算出该子问题的解,并保存在相应的记录项中。若记录项中存储的已不是初始化时存入的特殊值,则表示该问题已被计算过,其相应的记录项中存储的是该子问题的解答。此时,只要从记录项中取出该子问题的解答即可。
4. 构造最优解
借助Python字典,可以很容易的一次性存储m[n][n]和s[n][n](或可以直接保存最优解形式)的内容。
三. 算法实现:
def mult(chain):
n = len(chain)
# single matrix chain has zero cost
aux = {(i, i): (0,) + chain[i] for i in range(n)} #元组中只包含一个元素时,需要在元素后面添加逗号;元组中的元素值是不允许修改的,但我们可以使用‘+’对元组进行连接组合
# i: length of subchain #aux中存储m[n][n],即全部子问题最优解
for i in range(1, n): #length = i + 1
# j: starting index of subchain
for j in range(0, n - i):
best = float('inf') #表示正无穷
# k: splitting point of subchain
for k in range(j, j + i):
# multiply subchains at splitting point
lcost, lname, lrow, lcol = aux[j, k]
rcost, rname, rrow, rcol = aux[k + 1, j + i]
cost = lcost + rcost + lrow * lcol * rcol
var = '(%s%s)' % (lname, rname)
# pick the best one
if cost < best:
best = cost
aux[j, j + i] = cost, var, lrow, rcol
#print("{}\n".format(aux))
return dict(zip(['cost', 'order', 'rows', 'cols'], aux[0, n - 1]))
result = mult([('A', 10, 20), ('B', 20, 30), ('C', 30, 40),('D',40,50)])
print(result)
注:补充参考:https://blog.youkuaiyun.com/c18219227162/article/details/50412333