题目大意
将一棵树分为三部分,使最大部分与最小部分的差最小。
正解
O(n^2)做法,枚举两条边,并用dfs序判断是否有祖先关系。
O(nlogn)做法,考虑用权值线段树来维护,记住是绝对值,要用前驱后继查询。
仍分为两种情况:
1.有祖先关系。统计答案时取最接近(n+size)/2的。dfs时将size丢进权值线段树线段树,遍历后将其从树中删除。
2.无祖先关系。统计答案时取最接近(n-size)/2的。遍历后将size丢进权值线段树线段树,无需删除。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
int n,m=200000,sx,sy,bb[500005],size[200005],h[200005],fa[200005],dep[200005],cnt=0,tot=0,ans=2000000007;
struct node
{
int next,to,from;
}e[400005];
struct segment_tree
{
int val;
}a[10000005];
void read(int &x){
char c=getchar();
for(;c<33;c=getchar());
for(x=0;47<c&&c<58;x=(x<<3)+(x<<1)+c-48,c=getchar());
}
inline void add(int x,int y)
{
e[++cnt].to=y;
e[cnt].from=x;
e[cnt].next=h[x];
h[x]=cnt;
}
inline void dfs(int x)
{
size[x]=1;
for(int i=h[x];i;i=e[i].next)
{
int y=e[i].to;
if(y!=fa[x])
{
fa[y]=x;
dfs(y);
size[x]+=size[y];
}
}
}
inline void add(int l,int r,int k,int x,int val)
{
if(l==r)
{
a[k].val+=val;
}
else
{
int mid=(l+r)>>1;
if(x<=mid)add(l,mid,k<<1,x,val);
else add(mid+1,r,(k<<1)+1,x,val);
a[k].val=a[k<<1].val+a[(k<<1)+1].val;
}
}
inline int ub(int l,int r,int k,int x)
{
if(l==r)
{
return l;
}
else
{
int mid=(l+r)>>1,mx=0;
if((r<=x||x>mid)&&a[(k<<1)+1].val)
{
mx=ub(mid+1,r,(k<<1)+1,x);
}
if(a[k<<1].val&&!mx)mx=ub(l,mid,k<<1,x);
return mx;
}
return -2000000007;
}
inline int lb(int l,int r,int k,int x)
{
if(l==r)
{
return l;
}
else
{
int mid=(l+r)>>1,mx=0;
if((l>=x||x<=mid)&&a[k<<1].val)
{
mx=lb(l,mid,k<<1,x);
}
if(a[(k<<1)+1].val&&!mx)mx=lb(mid+1,r,(k<<1)+1,x);
return mx;
}
return 2000000007;
}
inline void dfs1(int x)
{
sx=ub(1,m,1,(n+size[x])>>1);
sy=lb(1,m,1,(n+size[x])>>1);
int c1=size[x],c2=sx-size[x],c3=n-sx;
ans=min(ans,max(abs(c1-c2),max(abs(c1-c3),abs(c2-c3))));
c1=size[x],c2=sy-size[x],c3=n-sy;
ans=min(ans,max(abs(c1-c2),max(abs(c1-c3),abs(c2-c3))));
add(1,m,1,size[x],1);
for(int i=h[x];i;i=e[i].next)
{
int y=e[i].to;
if(y!=fa[x])
{
dfs1(y);
}
}
add(1,m,1,size[x],-1);
}
inline void dfs2(int x)
{
sx=ub(1,m,1,(n-size[x])>>1);
sy=lb(1,m,1,(n-size[x])>>1);
int c3=n-size[x]-sx;
ans=min(ans,max(abs(size[x]-sx),max(abs(size[x]-c3),abs(sx-c3))));
c3=n-size[x]-sy;
ans=min(ans,max(abs(size[x]-sy),max(abs(size[x]-c3),abs(sy-c3))));
for(int i=h[x];i;i=e[i].next)
{
if(e[i].to!=fa[x])
{
dfs2(e[i].to);
}
}
add(1,m,1,size[x],1);
}
int main()
{
freopen("chilli.in","r",stdin);
freopen("chilli.out","w",stdout);
read(n);
for(int i=1,x,y;i<n;i++)
{
read(x);read(y);
add(x,y);add(y,x);
}
dfs(1);
dfs1(1);
dfs2(1);
printf("%d\n",ans);
}