好长啊题目。。。
大概就是把一棵树分成两棵再拼成一棵后的最大最小直径。应该可以想到树形DP,保存每个节点为根的子树的直径和除去该子树后的直径。为此我们需要维护每个节点向下的前三长的链(每个儿子只记一次),向上最长的链,儿子中前二长的直径,然后可以求出在哪里断开。最后把两棵树的直径找粗来,并起来即可。
说起来好像很简单,然而我做了一周。。真是太弱了。。。
代码比较丑2333
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#define N 500005
using namespace std;
int n,x,y;
int first[N],to[N<<1],next[N<<1],l;
int fa[N],dep[N],Max[N],Min[N],dis[N],X,ba[N];
int dia[N],cha[N][4],Mx[N][3],dia1[N],cha1[N];
void link(int x,int y)
{
to[++l]=y;next[l]=first[x];first[x]=l;
to[++l]=x;next[l]=first[y];first[y]=l;
}
void dfs(int x)
{
dep[x]=dep[fa[x]]+1;
for (int i=first[x];i;i=next[i])
if (to[i]!=fa[x])
{
fa[to[i]]=x;
dfs(to[i]);
int t;
int k=cha[to[i]][0]+1;
for (t=3;t&&cha[x][t-1]<k;t--)cha[x][t]=cha[x][t-1];
cha[x][t]=k;
k=dia[to[i]];
for (t=2;t&&Mx[x][t-1]<k;t--)Mx[x][t]=Mx[x][t-1];
Mx[x][t]=k;
}
dia[x]=max(Mx[x][0],cha[x][0]+cha[x][1]);
}
void dfs1(int x)
{
for (int i=first[x];i;i=next[i])
if (to[i]!=fa[x])
{
int t,t1,k=to[i];
t=(cha[k][0]+1==cha[x][0]);
cha1[k]=max(cha1[x],cha[x][t])+1;
t=(dia[k]==Mx[x][0]);
dia1[k]=max(Mx[x][t],dia1[x]);
if (cha[k][0]+1==cha[x][0]) t=1,t1=2;
else if (cha[k][0]+1==cha[x][1]) t=0,t1=2;
else t=0,t1=1;
dia1[k]=max(dia1[k],max(cha[x][t]+cha[x][t1],cha[x][t]+cha1[x]));
dfs1(k);
}
}
void Dfs(int x,int y)
{
dis[x]=dis[ba[x]]+1;
if (dis[x]>dis[X]) X=x;
for (int i=first[x];i;i=next[i])
if (to[i]!=ba[x]&&to[i]!=y)
{
ba[to[i]]=x;
Dfs(to[i],y);
}
}
int Getfar(int x,int y)
{
memset(dis,-1,sizeof dis);
memset(ba,0,sizeof ba);
dis[x]=0;X=x;Dfs(x,y);
return X;
}
int Getmid(int x,int y,int l)
{
l/=2;
if (dep[x]<dep[y]) x=y;
for (int i=1;i<=l;i++) x=fa[x];
return x;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
scanf("%d%d",&x,&y),link(x,y);
dfs(1);dfs1(1);
int Ans1=1,Ans2=1;
Max[1]=Min[1]=dia[1];
for (int i=2;i<=n;i++)
{
Max[i]=dia[i]+dia1[i]+1;
Min[i]=max(max(dia[i],dia1[i]),(int)(ceil(dia[i]/2.0)+ceil(dia1[i]/2.0)+1));
if (Max[Ans1]<Max[i]) Ans1=i;
if (Min[Ans2]>Min[i]) Ans2=i;
}
if (Ans2==1) printf("%d %d %d %d %d\n",Min[1],x,y,x,y);
else
{
printf("%d %d %d ",Min[Ans2],Ans2,fa[Ans2]);
int t1=Getfar(Ans2,fa[Ans2]),t2=Getfar(t1,fa[Ans2]),tt1=Getfar(fa[Ans2],Ans2),tt2=Getfar(tt1,Ans2);
printf("%d %d\n",Getmid(t1,t2,dia[Ans2]),Getmid(tt1,tt2,dia1[Ans2]));
}
if (Ans1==1) printf("%d %d %d %d %d\n",Max[1],x,y,x,y);
else
{
printf("%d %d %d ",Max[Ans1],Ans1,fa[Ans1]);
printf("%d %d\n",Getfar(Ans1,fa[Ans1]),Getfar(fa[Ans1],Ans1));
}
}
自己的"SPJ"
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 2000005
using namespace std;
int n,Ans,x,y,xx,yy,t,X;
int first[N],next[N],to[N],f[N],fa[N],dis[N],l;
void link(int x,int y)
{
to[++l]=y;next[l]=first[x];first[x]=l;
to[++l]=x;next[l]=first[y];first[y]=l;
}
void dfs(int x)
{
dis[x]=dis[fa[x]]+1;
if (dis[x]>dis[X]) X=x;
for (int i=first[x];i;i=next[i])
if (!f[i]&&to[i]!=fa[x])
{
fa[to[i]]=x;
dfs(to[i]);
}
}
int Get()
{
memset(dis,-1,sizeof dis);
memset(fa,0,sizeof fa);
dis[1]=0;X=1;dfs(1);
memset(dis,-1,sizeof dis);
memset(fa,0,sizeof fa);
dis[X]=0;dfs(X);
return dis[X];
}
int main()
{
freopen("1.in","r",stdin);
scanf("%d",&n);
for (int i=1;i<n;i++)
scanf("%d%d",&x,&y),link(x,y);
fclose(stdin);
freopen("1.out","r",stdin);
scanf("%d%d%d%d%d",&Ans,&x,&y,&xx,&yy);
printf("%d %d %d %d\n",x,y,xx,yy);
for (int i=first[x];i;i=next[i])
if (to[i]==y) f[i]=1;
for (int i=first[y];i;i=next[i])
if (to[i]==x) f[i]=1;
link(xx,yy);
t=Get();
printf("Ans=%d fac=%d\n",Ans,t);
if (Ans!=t) puts("WA");
for (int i=first[x];i;i=next[i])
if (to[i]==y) f[i]=0;
for (int i=first[y];i;i=next[i])
if (to[i]==x) f[i]=0;
f[l]=f[l-1]=1;
scanf("%d%d%d%d%d",&Ans,&x,&y,&xx,&yy);
printf("%d %d %d %d\n",x,y,xx,yy);
for (int i=first[x];i;i=next[i])
if (to[i]==y) f[i]=1;
for (int i=first[y];i;i=next[i])
if (to[i]==x) f[i]=1;
link(xx,yy);
t=Get();
printf("Ans=%d fac=%d\n",Ans,t);
if (Ans!=t) puts("WA");
for (int i=first[x];i;i=next[i])
if (to[i]==y) f[i]=0;
for (int i=first[y];i;i=next[i])
if (to[i]==x) f[i]=0;
f[l]=f[l-1]=1;
}