[APIO/CTSC 2007]数据备份
f x , 0 = m i n ( f x − 1 , 0 , f x − 1 , 1 ) f_{x,0}=min(f_{x-1,0},f_{x-1,1}) fx,0=min(fx−1,0,fx−1,1)
f x , 1 = f x − 1 , 0 + d i s x + c f_{x,1}=f_{x-1,0}+dis_x+c fx,1=fx−1,0+disx+c
意会即可。
#include<bits/stdc++.h>
using namespace std;
const int N=100005;
int n,k;
int choose[N][2];
long long dist[N],dis[N];
long long dp[N][2];
int check(long long mid) {
for(int i=2;i<=n;i++) {
if(dp[i-1][0]<dp[i-1][1]||dp[i-1][0]==dp[i-1][1]&&choose[i-1][0]<choose[i-1][1]) {
dp[i][0]=dp[i-1][0],choose[i][0]=choose[i-1][0];
}
else dp[i][0]=dp[i-1][1],choose[i][0]=choose[i-1][1];
dp[i][1]=dp[i-1][0]+dis[i]+mid,choose[i][1]=choose[i-1][0]+1;
}
if(dp[n][0]<dp[n][1]||dp[n][0]==dp[n][1]&&choose[n][0]<choose[n][1]) return choose[n][0];
return choose[n][1];
}
int main() {
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) scanf("%lld",&dist[i]),dis[i]=dist[i]-dist[i-1];
long long L=-1e9-10,R=0;
long long midl=0;
while(L<=R) {
long long mid=(L+R)>>1;
if(check(mid)==k) {
printf("%lld",min(dp[n][0],dp[n][1])-k*mid);
return 0;
}
if(check(mid)>k) {
L=mid+1;
}
else {
midl=mid;
R=mid-1;
}
}
check(midl);
printf("%lld",min(dp[n][0],dp[n][1])-k*midl);
}
[八省联考2018]林克卡特树
理解一下这个过程。
本质上就是求树的 k + 1 k+1 k+1 条不相交的链的和的最大值。
d p i , k , 0 / 1 / 2 dp_{i,k,0/1/2} dpi,k,0/1/2 表示以 i i i 为根的子树, k k k 条链 ,在链外/链端/链上的和的最大值。
容易发现, k k k 越大, d p 1 , k , 0 / 1 / 2 dp_{1,k,0/1/2} dp1,k,0/1/2 的增长率越慢(有可能负增长)。
这是一个凸函数。考虑 wqs二分。
具体地,二分权值 c c c ,每增加一条链,就要加上 c c c 的代价。
具体地,可以把 d p i , 2 dp_{i,2} dpi,2 合并到 d p i , 0 dp_{i,0} dpi,0 上,简化 d p dp dp 转移方程。
需要注意的是,我们只在一条独立的链的开始,即 d p i , 2 dp_{i,2} dpi,2 上增加权值 c c c 。
c = [ − ∞ , + ∞ ] c=[-\infty,+\infty] c=[−∞,+∞]
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+5;
inline int read()
{
int X=0; bool flag=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') flag=0; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+ch-'0'; ch=getchar();}
if(flag) return X;
return ~(X-1);
}
struct data{
long long v,c;
data operator +(data d) {
data a;
a.v=v+d.v;
a.c=c+d.c;
return a;
}
bool operator <(data d) {return v<d.v||v==d.v&&c>d.c;}
}dp[N][3];
data Max(data a,data b) {
return a<b?b:a;
}
int n,k;
int head[N*2],nxt[N*2],to[N*2],w[N*2],num;
long long sum,cnt,delta;
void add(int x,int y,int z) {
to[++num]=y,w[num]=z,nxt[num]=head[x],head[x]=num;
}
void dfs(int u,int fath) {
for(int i=head[u];i;i=nxt[i]) {
int v=to[i],z=w[i]; if(v==fath) continue;
dfs(v,u);
dp[u][2]=Max(dp[u][2]+dp[v][0],dp[u][1]+dp[v][1]+(data){z+delta,1});
dp[u][1]=Max(dp[u][0]+dp[v][1]+(data){z,0},dp[u][1]+dp[v][0]);
dp[u][0]=dp[u][0]+dp[v][0];
}dp[u][0]=Max(dp[u][0],dp[u][2]);
}
void check(long long mid) {
delta=mid;
for(int i=1;i<=n;i++) {
dp[i][0]={0,0},dp[i][2]={delta,1},dp[i][1]={0,0};
}
dfs(1,0);
sum=dp[1][0].v,cnt=dp[1][0].c;
}
int main() {
n=read(),k=read()+1;
for(int i=1;i<n;i++) {
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}
long long L=-1e10,R=1e10,midl;
while(L<=R) {
long long mid=(L+R)>>1;
check(mid);
if(cnt<=k) {L=mid+1,midl=mid;}
else R=mid-1;
}
check(midl);
printf("%lld\n",sum-k*midl);
}