这显然是一道dp+斜率优化题
开始大力推公式
$dp[i]=\max_{j=0}^{i-1}\left\{dp[j]+f(sum[i]-sum[j])\right\}$，其中 $f(x)=ax^2+bx+c$，$sum$ 为输入的前缀和。
展开得：$dp[i]=\max_{j=0}^{i-1}\left\{dp[j]+a\,sum[i]^2-2a\,sum[i]\,sum[j]+a\,sum[j]^2+b\,sum[i]-b\,sum[j]+c\right\}$
考虑两个决策 $j$ 和 $k$，若决策 $j$ 优于决策 $k$，则应满足如下条件：$dp[j]+a\,sum[i]^2-2a\,sum[i]\,sum[j]+a\,sum[j]^2+b\,sum[i]-b\,sum[j]+c>dp[k]+a\,sum[i]^2-2a\,sum[i]\,sum[k]+a\,sum[k]^2+b\,sum[i]-b\,sum[k]+c$
移项，合并同类项：$2a\,sum[i]\,(sum[k]-sum[j])>dp[k]-dp[j]+a\,sum[k]^2-a\,sum[j]^2+b\,sum[j]-b\,sum[k]$
设 $j<k$。由于 $sum$ 严格单增，$sum[k]-sum[j]>0$，把左边的 $(sum[k]-sum[j])$ 除到右边，不等号方向不变：
$$2a\,sum[i]>\frac{\left(dp[k]+a\,sum[k]^2-b\,sum[k]\right)-\left(dp[j]+a\,sum[j]^2-b\,sum[j]\right)}{sum[k]-sum[j]}$$
对于决策 $j$，可以把 $\left(sum[j],\ dp[j]+a\,sum[j]^2-b\,sum[j]\right)$ 看做平面上的一个点。
如果点 $j$、$k$（$j<k$）连线的斜率小于 $2a\,sum[i]$，决策 $j$ 就比决策 $k$ 优。
然后我们发现 $a$ 是负的，$sum[i]$ 是单增的，所以 $2a\,sum[i]$ 是单减的。
所以我们在单调队列里面维护的实际上是一个上凸壳（相邻两点连线的斜率单调递减）：队首两点的斜率一旦大于 $2a\,sum[i]$，队首就可以永久弹出。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
// Shorthand type and container-helper macros used throughout the file.
#define LL long long
#define LB long double
// NOTE: these rename every later identifier `x` / `y` to `first` / `second`
// (e.g. the members of struct node and the parameters of calc).
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
// Lowest set bit of an integer. The argument and the whole expansion are
// parenthesized so that LOWBIT(i)+1, 2*LOWBIT(i) or LOWBIT(i+j) expand as
// intended (the unparenthesized form bound to neighbouring operators).
#define LOWBIT(x) ((x) & (-(x)))
using namespace std;
// Miscellaneous constants; none of them is referenced in the code visible here.
const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Any characters before the first digit or '-' are skipped; the character that
// terminates the number is consumed. Returns 0 if end-of-input is reached before
// a digit or '-' is seen (the previous version looped forever in that case).
// `ch` is an int, not a char, so EOF stays distinguishable from byte 0xFF and
// isdigit() is never handed a negative non-EOF value (undefined behaviour).
inline int getint()
{
    int ch = getchar();
    while (ch != EOF && ch != '-' && !isdigit(ch)) ch = getchar();
    if (ch == EOF) return 0;

    bool negative = false;
    int res = 0;
    if (ch == '-') negative = true;
    else res = ch - '0';

    while ((ch = getchar()) != EOF && isdigit(ch)) res = res * 10 + (ch - '0');
    return negative ? -res : res;
}
// Number of input values (read first in main).
int n;
// Coefficients of the per-segment cost f(t) = a*t*t + b*t + c (read in main).
LL a,b,c;
// aa[i]: i-th input value (1-based); sum[i] = aa[1] + ... + aa[i], sum[0] = 0.
// NOTE(review): sized for n up to about 1e6 - presumably the problem's bound; confirm.
LL aa[1000048],sum[1000048];
// dp[i]: optimal value for the first i values (filled in main), dp[0] = 0.
LL dp[1000048];
// Presumably a hull point (x, y) tagged with the decision index `ind` it came from.
// NOTE(review): not referenced anywhere in the visible code. Because of the
// `#define x first` / `#define y second` macros, its members are actually named
// `first` and `second` after preprocessing.
struct node
{
LL x,y;
int ind;
};
inline double calc(int x,int y)
{
return double(dp[x]+a*sum[x]*sum[x]+b*sum[y]-(dp[y]+a*sum[y]*sum[y]+b*sum[x]))/(sum[x]-sum[y]);
}
// Monotone queue of candidate decisions used by main: q[head..tail] (inclusive)
// holds decision indices whose hull points are kept with non-increasing slopes.
int head,tail;int q[1000048];
int main ()
{
int i;
n=getint();
a=getint();b=getint();c=getint();
for (i=1;i<=n;i++) aa[i]=getint(),sum[i]=sum[i-1]+aa[i];
head=tail=1;q[head]=0;
for (i=1;i<=n;i++)
{
while (head<tail && calc(q[head],q[head+1])>(long long)2*a*sum[i]) head++;
int pos=q[head];
dp[i]=dp[pos]+a*(sum[i]-sum[pos])*(sum[i]-sum[pos])+b*(sum[i]-sum[pos])+c;
q[++tail]=i;
while (head+1<tail && calc(q[tail-2],q[tail-1])<calc(q[tail-1],q[tail])) q[tail-1]=q[tail],tail--;
}
printf("%lld\n",dp[n]);
return 0;
}