有一个庞大的家族,共n人。已知这n个人的祖辈关系正好形成树形结构(即父亲向儿子连边)。
在另一个未知的平行宇宙,这n人的祖辈关系仍然是树形结构,但他们相互之间的关系却完全不同了,原来的祖先可能变成了后代,后代变成的同辈……
两个人的亲密度定义为在这两个平行宇宙有多少人一直是他们的公共祖先。
整个家族的亲密度定义为任意两个人亲密度的总和。
Input
第一行一个数n(1<=n<=100000) 接下来n-1行每行两个数x,y表示在第一个平行宇宙x是y的父亲。 接下来n-1行每行两个数x,y表示在第二个平行宇宙x是y的父亲。
Output
一个数,表示整个家族的亲密度。
Input示例
5 1 3 3 5 5 4 4 2 1 2 1 3 3 4 1 5
Output示例
6
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
主席树+dfs序~
我们考虑点对对某一个点的贡献:如果两树中一个点u的子树中有x个点是公共的,那么这个子树对答案的贡献就是x*(x-1)/2。所以我们只需要统计出所有子树中公共点的个数即可。
我们先将两树分别dfs求出dfs序,那么在两树中每个点都有一个特定区间。假设u在A树中区间[l,r],在B树中区间[L,R],那么x就等于在[l,r]中出现的数在[L,R]中出现的次数。用主席树维护。
注意计算的时候也要开long long!
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
#define ll long long
int n,x,y,root,fi[100001],w[200001],ne[200001],cnt,al[100001],ar[100001],bl[100001],br[100001],in[100001];
int rot[100001],a[100001],ls[2000001],rs[2000001];
ll ans,s[2000001];
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0' || ch>'9') {if(ch=='-') f=-1;ch=getchar();}
while(ch>='0' && ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
void add(int u,int v)
{
w[++cnt]=v;ne[cnt]=fi[u];fi[u]=cnt;
w[++cnt]=u;ne[cnt]=fi[v];fi[v]=cnt;
}
void dfs(int u,int fa,int flag)
{
flag ? bl[u]=++cnt:al[u]=++cnt;
for(int i=fi[u];i;i=ne[i]) if(w[i]!=fa) dfs(w[i],u,flag);
flag ? br[u]=cnt:ar[u]=cnt;
}
void build(int &u,int v,int val,int l,int r)
{
u=++cnt;ls[u]=ls[v];rs[u]=rs[v];s[u]=s[v];
if(l==r)
{
s[u]++;return;
}
int mid=l+r>>1;
if(mid>=val) build(ls[u],ls[v],val,l,mid);
else build(rs[u],rs[v],val,mid+1,r);
s[u]=s[ls[u]]+s[rs[u]];
}
int cal(int u,int v,int x,int y,int l,int r)
{
if(l>=x && r<=y) return s[v]-s[u];
int mid=l+r>>1,now=0;
if(x<=mid) now+=cal(ls[u],ls[v],x,y,l,mid);
if(y>mid) now+=cal(rs[u],rs[v],x,y,mid+1,r);
return now;
}
int main()
{
n=read();
for(int i=1;i<n;i++) x=read(),y=read(),in[y]++,add(x,y);
for(root=1;in[root];root++);
cnt=0;dfs(root,0,0);
memset(fi,0,sizeof(fi));cnt=0;
memset(in,0,sizeof(in));
for(int i=1;i<n;i++) x=read(),y=read(),in[y]++,add(x,y);
for(root=1;in[root];root++);
cnt=0;dfs(root,0,1);cnt=0;
for(int i=1;i<=n;i++) a[bl[i]]=i;
for(int i=1;i<=n;i++) build(rot[i],rot[i-1],al[a[i]],1,n);
for(int i=1;i<=n;i++)
{
ll now=cal(rot[bl[i]-1],rot[br[i]],al[i],ar[i],1,n)-1;
ans+=now*(now-1)/2;
}
printf("%lld\n",ans);
return 0;
}