最近刷了几道斜率优化dp算是对斜率优化有了一定的了解了
现在来小结一下
斜率优化dp的优化能力是将n^2优化成n,n^3优化成n^2
其转移方程一般是dp[k]=min(dp[i]+cost[i+1][k]),斜率优化优化的是排除一些不可能是最优解的解,那么什么情况下不可能是最优解呢
下面看一道题HDU 2829
状态转移方程是dp[i][j]=min(dp[k][j-1]+cost[k+1][i])
那么k2比k1更优的情况是
dp[k1][j-1]+cost[k1+1][i]>dp[k2][j-1]+cost[k2+1][i]
我们发现cost[k+1][i]=cost[1][i]-cost[1][k]-sum[k]*(sum[i]-sum[k])
那么我们带进去化简就可以得到
dp[k1][j-1]-cost[1][k1]+sum[k1]*sum[k1]-sum[k1]*sum[i]>dp[k2][j-1]-cost[1][k2]+sum[k2]*sum[k2]-sum[k2]*sum[i]
那么我们设
y1=dp[k1][j-1]-cost[1][k1]+sum[k1]*sum[k1]
y2=dp[k2][j-1]-cost[1][k2]+sum[k2]*sum[k2]
x1=sum[k1],x2=sum[k2]
所以k2比k1更优的条件是
y2-y1<(x2-x1)*sum[i]
如果等于那就是一样优,那就也是可以删掉k1的那么我们可以在条件上加一个=号
也就是(y2-y1)<=(x2-x1)*sum[i]
假设我有三个数k1,k2,k3, k1<k2<k3
那么k2永远都不会是最优的情况是
(y3-y2)/(x3-x2)<(y2-y1)/(x2-x1)
我们分类讨论一下如果对于sum[i]来说
如果(y3-y2)/(x3-x2)<=sum[i]那么k3比k2优或者一样优,那么都是可以删掉的
如果(y3-y2)/(x3-x2)>sum[i],此时(y2-y1)/(x2-x1)>sum[i]那么k1比k2优
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define maxn 1500
using namespace std;
int sum[maxn],dp[maxn][maxn],cost[maxn];
int que[maxn];
int DP(int n,int m){
for (int k=1;k<=n;k++){
dp[k][k-1]=0;
dp[k][0]=cost[k];
}//初始化完成
int head,tail;
for (int j=1;j<=m;j++){
head=tail=0;
//我们优化的是最后一个遍历所以对于每一个i我们都有一个新的que //对于一个新的j就是说我们分成j段
//我要遍历的第一个k就是dp[j][j-1]
//所以我们把j入队即可 que[tail++]=j; for (int i=j+1;i<=n;i++){ while(head+1<tail){ int q1=que[head],q2=que[head+1]; int y1=dp[q1][j-1]-cost[q1]+sum[q1]*sum[q1]; int y2=dp[q2][j-1]-cost[q2]+sum[q2]*sum[q2]; int x1=sum[q1],x2=sum[q2];
if ((y2-y1)<=sum[i]*(x2-x1)) head++; else break; }
//利用sum[i]从头到尾找最优 //下面解释下为什么这个是最优的,可以看到前面的已经是最优了 //我们发现他的斜率是单调递减的,也就是说此时head到后面任意一点的斜率永远>sum[i],所以head是最优的情况
int k=que[head];
dp[i][j]=dp[k][j-1]+cost[i]-cost[k]-sum[k]*(sum[i]-sum[k]);
while(head+1<tail){
int q1=que[tail-2],q2=que[tail-1],q3=i;
int y3=dp[q3][j-1]-cost[q3]+sum[q3]*sum[q3];
int y2=dp[q2][j-1]-cost[q2]+sum[q2]*sum[q2];
int y1=dp[q1][j-1]-cost[q1]+sum[q1]*sum[q1];
int x3=sum[q3],x2=sum[q2],x1=sum[q1];
if (((y3-y2)*(x2-x1))<=((y2-y1)*(x3-x2))) tail--;
else break;
}
que[tail++]=i; //因为我们要遍历到的下一个节点是dp[i+1][j],他的k的取值范围0-i所以我们要把i入队 //在入队的时候我们就要淘汰y2,这样子不仅淘汰了非最优解还维护了队列斜率的单调递减性质
}
}
return dp[n][m];
}
int main()
{
int m,n;
while(~scanf("%d %d",&n,&m)&&(n&&m)){
sum[0]=0;cost[0]=0;
for (int k=1;k<=n;k++){
int save;
scanf("%d",&save);
sum[k]=sum[k-1]+save;
cost[k]=cost[k-1]+sum[k-1]*save;
}
printf("%d\n",DP(n,m));
}
return 0;
}
1078

被折叠的 条评论
为什么被折叠?



