题意
给两棵树,结点数分别为nnn和mmm,对于所有的点对(s,t)(s,t)(s,t)(可以在一棵树内部,也可以分别在两棵树上),求出d(s,t)d(s,t)d(s,t),其中函数ddd表示给这两棵树任意加一条边联通后,这两点的最远距离。也就是求∑1≤s≤n+m∑s<t≤n+md(s,t)\sum_{1\leq s\leq n+m}\sum_{s<t\leq n+m} d(s,t)∑1≤s≤n+m∑s<t≤n+md(s,t)
分析
很明显,对于所有的点对(s,t)(s,t)(s,t),可以分为两种情况
- sss和ttt在同一棵树内
- sss和ttt分别在两棵树上
第一种情况(s和t在同一颗树内)
我们需要求的相当于是一棵树内任意两点的距离和。要处理这个东西,可以这么考虑,对于这棵树上的一条边,我们可以统计树上所有路径中,经过这条边的次数,那么这条边对距离和的贡献就等于w(s,t)∗cntw(s,t)*cntw(s,t)∗cnt,那么怎么求cntcntcnt呢?可以这么考虑,所有的经过这条边的路径数量,相当于在左端点及其左边任选一个点,右端点及其右边任选一个点,这样选出来的路径肯定都有经过这条边。
所以我们只需要开一个数组siz[i]siz[i]siz[i]记录iii的子树的大小,那么对于结点iii和它父亲的这条边,对应的总路径数cnt=w(i,fa[i])∗(siz[i]∗(tot−siz[i]))cnt=w(i,fa[i])*(siz[i]*(tot-siz[i]))cnt=w(i,fa[i])∗(siz[i]∗(tot−siz[i])),这里tottottot代表总的结点数,减去iii的子树大小,就等于在iii上面的结点个数。本题中所有的边权为1,tottottot分别为nnn和mmm。
第二种情况(s和t在不同的树上)
那么显然有一种贪心的想法,我们可以在第一棵树内找到sss距离最远的点s′s's′,在第二棵树内找到ttt距离最远的点t′t't′,那么加的边就可以是(s′,t′)(s',t')(s′,t′),这样对应的d(s,t)=d(s,s′)+1+d(t,t′)d(s,t)=d(s,s')+1+d(t,t')d(s,t)=d(s,s′)+1+d(t,t′)是最大的。
所以接下来我们要做的就是快速地找出每棵树内,每个结点对应的最远距离dis[i]dis[i]dis[i],这里有两种做法:
针对情况二的做法1
赛时我用的是较为麻烦的dpdpdp做法,开两个数组dp1[i]dp1[i]dp1[i]和dp2[i]dp2[i]dp2[i]分别表示从结点iii开始,第一步往子节点和往父节点走,能走的最远距离。所以一个点iii能走到的最远距离,要么是往父节点走,要么是往子节点走,也就是max(dp1[i],dp2[i])max(dp1[i],dp2[i])max(dp1[i],dp2[i]),问题转化为如何求dp1[i]dp1[i]dp1[i]和dp2[i]dp2[i]dp2[i]。
那么dp1dp1dp1数组是比较好得到的,直接简单的dfsdfsdfs一遍,设当前结点为uuu,子节点为vvv,dp1[u]=max(dp1[v]+1)dp1[u]=max(dp1[v]+1)dp1[u]=max(dp1[v]+1)。
难点在于dp2[i]dp2[i]dp2[i],不过也还好,想一下递推关系就可以知道,第一步往父节点走,最远距离等于父节点往其他子节点走的最远距离加上父节点到该子节点的距离,即
dp2[u]=w(u,fa[u])+max(dp1[u′]+w(fa[u],u′))dp2[u]=w(u,fa[u])+max(dp1[u']+w(fa[u],u'))dp2[u]=w(u,fa[u])+max(dp1[u′]+w(fa[u],u′))
其中u′u'u′表示结点uuu的兄弟结点,fa[u]fa[u]fa[u]表示结点uuu的父节点。所以只需要再进行一次dfsdfsdfs即可求出dp2[]dp2[]dp2[]。
针对情况二的做法2
赛后看到题解,发现上面的dpdpdp做法太麻烦了,直接利用一个性质:树上的一个点距离最远的点的距离,等于到树上直径的两个端点中较远的点的距离。直观上感觉确实这样,不过具体证明我也不会。
那么利用这个性质的话,就可以用求树的直径的方法,两遍dfsdfsdfs,第一遍任选一个点出发dfsdfsdfs,记录能到达的最远的点为sss,sss即为直径的一端,从点sss出发,再次dfsdfsdfs,顺便维护树上所有点到sss的距离dis1[i]dis1[i]dis1[i],能到达的最远的点ttt就是直径的另一端,(s,t)(s,t)(s,t)即为直径,再从ttt出发dfsdfsdfs一遍,维护树上所有点到点ttt的距离dis2[i]dis2[i]dis2[i],那么对于所有的点,能到达的最远距离就是
dis[i]=max(dis1[i],dis2[i])dis[i]=max(dis1[i],dis2[i])dis[i]=max(dis1[i],dis2[i])。
虽然这种方法需要dfsdfsdfs的次数比较多,但是思路上还是比较简单的,不用考虑那么多dpdpdp转移,我个人感觉这种方法比前面讲的做法1要简单。
那么如果已经求出来了每个点在树内能到达的最远距离数组dis[i]dis[i]dis[i],下面考虑具体如何计算出最终答案。我是对着第一组样例来手动列一下算式总结。如对于两棵树都只有2个点,一条边,编号分别为1,2和3,4的情况,对于sss和ttt分别在两棵树上时,有
dis[1]+1+dis[3]dis[1]+1+dis[3]dis[1]+1+dis[3]
dis[1]+1+dis[4]dis[1]+1+dis[4]dis[1]+1+dis[4]
dis[2]+1+dis[3]dis[2]+1+dis[3]dis[2]+1+dis[3]
dis[2]+1+dis[4]dis[2]+1+dis[4]dis[2]+1+dis[4]
可以看出,对于手动连的中间那条边,也就是式子中间的1,,总共有n∗mn*mn∗m条。
对于左边的dis[1]dis[1]dis[1],有mmm个点的disdisdis与其相加,对于dis[2]dis[2]dis[2]同理,所以提取公因式,等于m∑i=1ndis[i]m\sum_{i=1}^{n}dis[i]m∑i=1ndis[i]。
同理,对于右边的dis[3]dis[3]dis[3]和dis[4]dis[4]dis[4],求和为n∑i=n+1n+mdis[i]n\sum_{i=n+1}^{n+m}dis[i]n∑i=n+1n+mdis[i],不妨记第一棵nnn个结点的树中,各结点的最远距离和∑i=1ndis[i]\sum_{i=1}^{n}dis[i]∑i=1ndis[i]为sum1sum1sum1,第二棵树对应的节点距离和∑i=n+1n+mdis[i]\sum_{i=n+1}^{n+m}dis[i]∑i=n+1n+mdis[i] 为sum2sum2sum2。
那么∑d(s,t)∣s,t在两棵树上\sum d(s,t)|s,t在两棵树上∑d(s,t)∣s,t在两棵树上就等于n∗m+n∗sum2+m∗sum1n*m+n*sum2+m*sum1n∗m+n∗sum2+m∗sum1。
再把这个答案加上开始求的在一棵树内的距离和,即为最终答案(记得开long long,我习惯宏定义int long long)。
下面分别为两种方法求第二种情况的代码:
第一种
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
vector<int>G[maxn];
int siz[maxn],ans,dp1[maxn],dp2[maxn],fa[maxn];
void dfs(int u,int f,int tot){
siz[u]=1;
for(auto v:G[u]){
if(v==f) continue;
fa[v]=u;
dfs(v,u,tot);
dp1[u]=max(dp1[u],dp1[v]+1);
siz[u]+=siz[v];
}
ans+=siz[u]*(tot-siz[u]);
}
void dfs2(int u,int f){
int t=0;
dp2[u]=dp2[f];
for(auto v:G[f]){
if(v==fa[f])continue;
if(v==u)t=1;
else dp2[u]=max(dp2[u],dp1[v]+1);
}
dp2[u]+=t;
for(auto v:G[u]){
if(v==f) continue;
dfs2(v,u);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n,m;cin>>n>>m;
for(int i=2;i<n+m;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,1,n);
dfs2(1,0);
dfs(n+1,n+1,m);
dfs2(n+1,0);
int sum1=0,sum2=0;
for(int i=1;i<=n;i++)
sum1+=max(dp1[i],dp2[i]);
for(int i=n+1;i<=n+m;i++)
sum2+=max(dp1[i],dp2[i]);
ans+=n*m+n*sum2+m*sum1;
cout<<ans<<endl;
return 0;
}
第二种
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
vector<int>G[maxn];
int dis1[maxn],dis2[maxn],s,t,siz[maxn];
ll ans;
void dfs(int u,int f,int tot){
siz[u]=1;
for(auto v:G[u]){
if(v==f) continue;
dfs(v,u,tot);
siz[u]+=siz[v];
}
ans+=siz[u]*(tot-siz[u]);
}
void dfs1(int u,int f){
for(auto v:G[u]){
if(v==f) continue;
dis1[v]=dis1[u]+1;
if(dis1[v]>dis1[s]) s=v;
dfs1(v,u);
}
}
void dfs2(int u,int f){
for(auto v:G[u]){
if(v==f) continue;
dis2[v]=dis2[u]+1;
if(dis2[v]>dis2[t]) t=v;
dfs2(v,u);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n,m;cin>>n>>m;
for(int i=1;i<=n+m-2;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1,1);
dfs2(s,s);//得到所有点到s的距离
dis1[t]=0;
dfs1(t,t);
s=t=0;
dfs1(n+1,n+1);
dfs2(s,s);
dis1[t]=0;
dfs1(t,t);
dfs(1,1,n);
dfs(n+1,n+1,m);
int sum1=0,sum2=0;
for(int i=1;i<=n;i++)
sum1+=max(dis1[i],dis2[i]);
for(int i=n+1;i<=n+m;i++)
sum2+=max(dis1[i],dis2[i]);
ans+=n*m+n*sum2+m*sum1;
cout<<ans<<endl;
return 0;
}