我们记dis1[i]为i到根的距离,dis2[i]为根到i的距离,s[i]为i所在集合的大小
发现对于i,他对答案的贡献就是(s[i]−1)(dis1[i]+dis2[i])(s[i]−1)(dis1[i]+dis2[i])
因此处理出dis1,dis2后,令a[i]=dis1[i]+dis2[i],我们可以将问题转化成给你n个数,将他们分成恰好s个集合,每个集合S的代价是|S|∑i∈Sa[i]|S|∑i∈Sa[i],求最小的代价和
然后有个结论,就是最终分成的集合一定在a排序后是一段一段的,同一个集合的数肯定是相邻的一段
证明也不难,设a[i]<a[j]a[i]<a[j],a[i],a[j]a[i],a[j]所在的集合大小分别是b[i],b[j]b[i],b[j],则一定有b[i]>=b[j]b[i]>=b[j],否则交换他们所在的集合,答案会变得更优
那么将a升序排序后,在i处写下b[i],会发现b[i]不升,显然只有b[i]相同的在一个集合,于是同一个集合的数就一定是相邻一段
那么考虑dp,令f[i][j]f[i][j]表示前ii个数,分成段的最小代价,直接转移,转移的复杂度是O(n)O(n)的
然后这个转移其实是个二维前缀和的形式
每个点形如(i,∑ij=1a[j])(i,∑j=1ia[j])
转移P−>QP−>Q的代价是那个矩形的面积也即(Qx−Px)(Qy−Py)(Qx−Px)(Qy−Py)
然后可以发现,当P1对Q1的贡献劣于P2对Q2的贡献时P1对Q1的贡献劣于P2对Q2的贡献时,以后P1P1都不可能比P2P2更优了
于是这个东西就有决策单调性,我们可以维护一个单调队列,维护相邻两个元素的t[i]t[i]表示i+1在贡献到t[i]开始比i优,令t[i]单调,t[i]可以二分求,于是转移就可以做到均摊O(logn)O(logn)
现在复杂度已经优化到了O(n2logn)O(n2logn),已经是官方题解的复杂度了,但是这个东西其实还能优化(Orz liaoliao)
我们把f[i][j]f[i][j]写成关于j的函数f[i](j)f[i](j)的形式,会发现f[n](s)f[n](s)是个下凸的函数,于是可以用dpf[i](j)+kxf[i](j)+kx的极值点那个东西,把他的复杂度优化到O(nlog2n)O(nlog2n)
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
#define ld long double
#define inf 1e15
using namespace std;
const int maxn = 50005;
const int maxm = 110000;
const ld eps = 1e-4;
int n,m,root,B,S;
int b[maxn]; ll bpre[maxn];
int ei[maxm][3];
struct Task1
{
struct edge{int y,c,nex;}a[maxm]; int len,fir[maxn];
inline void ins(const int x,const int y,const int c){a[++len]=(edge){y,c,fir[x]};fir[x]=len;}
struct node
{
int x,i;
friend inline bool operator <(const node x,const node y){ return x.x>y.x; }
};
priority_queue<node>q;
int dis1[maxn],dis2[maxn];
void Dij()
{
len=0; for(int i=1;i<=n;i++) fir[i]=0;
for(int i=1;i<=m;i++) ins(ei[i][0],ei[i][1],ei[i][2]);
for(int i=1;i<=n;i++) dis1[i]=1e9;
dis1[root]=0; q.push((node){dis1[root],root});
while(!q.empty())
{
const node tmp=q.top(); q.pop();
int x=tmp.i; if(dis1[x]!=tmp.x) continue;
for(int k=fir[x],y=a[k].y;k;k=a[k].nex,y=a[k].y) if(dis1[y]>dis1[x]+a[k].c)
{
dis1[y]=dis1[x]+a[k].c;
q.push((node){dis1[y],y});
}
}
len=0; for(int i=1;i<=n;i++) fir[i]=0;
for(int i=1;i<=m;i++) ins(ei[i][1],ei[i][0],ei[i][2]);
for(int i=1;i<=n;i++) dis2[i]=1e9;
dis2[root]=0; q.push((node){dis2[root],root});
while(!q.empty())
{
const node tmp=q.top(); q.pop();
int x=tmp.i; if(dis2[x]!=tmp.x) continue;
for(int k=fir[x],y=a[k].y;k;k=a[k].nex,y=a[k].y) if(dis2[y]>dis2[x]+a[k].c)
{
dis2[y]=dis2[x]+a[k].c;
q.push((node){dis2[y],y});
}
}
}
void Generate(int b[])
{
Dij();
for(int i=1;i<=B;i++) b[i]=dis1[i]+dis2[i];
n=B;
sort(b+1,b+n+1);
bpre[0]=0;
for(int i=1;i<=n;i++) bpre[i]=bpre[i-1]+b[i];
}
}T1;
void Read()
{
scanf("%d%d%d%d",&n,&B,&S,&m); root=B+1;
for(int i=1;i<=m;i++) scanf("%d%d%d",&ei[i][0],&ei[i][1],&ei[i][2]);
}
struct Task2
{
struct Ci
{
int i,x; ld val;
}f[maxn];
int Get_Time(Ci x,Ci y)
{
int i=x.i,j=y.i;
int l=j-1,r=n;
while(l<=r)
{
int mid=l+r>>1;
ld c1=x.val+(mid-i)*(bpre[mid]-bpre[i]),c2=y.val+(mid-j)*(bpre[mid]-bpre[j]);
if(c1>c2) r=mid-1;
else l=mid+1;
}
return r+1;
}
int qi[maxn],tq[maxn],head,tail; ld ki;
int dp()
{
f[0]=(Ci){0,0,0.0};
qi[head=tail=1]=0;
for(int i=1;i<=n;i++)
{
while(head<tail&&tq[head]<=i) head++;
Ci tmp=f[qi[head]];
f[i].i=i,f[i].x=tmp.x+1,f[i].val=tmp.val+ki+(ld)(i-tmp.i)*(bpre[i]-bpre[tmp.i]);
int tk;
while((tk=Get_Time(f[qi[tail]],f[i]))<=tq[tail-1]&&head<tail) tail--;
tq[tail]=tk; qi[++tail]=i;
}
return f[n].x;
}
void Solve()
{
ld u=0;for(int i=1;i<=n;i++) u+=(ld)b[i];
ld L=-u*n,R=u*n;
while(R-L>eps)
{
ld mid=(L+R)/2.0; ki=mid;
int loc=dp();
if(loc<=S) R=mid;
else L=mid;
}
ld ans=f[n].val-R*S-(ld)bpre[n]+0.0001;
printf("%.0Lf\n",ans);
}
}T2;
int main()
{
Read();
T1.Generate(b);
T2.Solve();
return 0;
}