【题目描述】
给定一个点集的两颗生成树T1,T2。对于T1中每条边,问与T2中多少边交换以后满足T1和T2仍然是两棵树。
n<=1e6n<=1e6n<=1e6
【思路】
这道题挺不错的。首先,对于两条边e1∈T1,e2∈T2e1\in T1,e2\in T2e1∈T1,e2∈T2,它们可以交换当且仅当e1的两个端点在T2上的路径经过了e2,且e2两个端点在T1的路径上经过了e1。这是两个限制,我们可以这样处理:保证数据结构中的边满足一个限制,查询满足另一个限制的边的数量。由于把每条边独立出来单独处理是难以在低时间复杂度解决的,所以我们在T1中考虑儿子对父亲的贡献。我们可以发现,在我们需要维护的数据结构中,儿子和父亲的满足一个限制的边集合只有一部分不同。从父亲这里延伸到子树以外的边会被加进来,儿子子树里向上到父亲或父亲的其它子树的边需要被删除。我们可以考虑差分,对于T2的每一条边(u,v),我们可以在T1中的u和v处+1,在lca(u,v)处-2,就可以保证数据结构里的边是合法的。遗憾的是,子树之间的差分会相互影响,所以我们可以考虑线段树合并,使得子树的线段树里只有子树里合法的边。不过还有更简单的方法,我们发现这个需要求的数量满足可减性。我们可以在未考虑子树时求一次答案,在考虑子树的影响后求一次答案。最后,我们考虑一下我们需要维护一个什么样的数据结构。我们保证这个数据结构T2的边都覆盖了正在考虑的T1的边,我们需要查询满足另一个限制的边,即当前考虑的T1的边的两端点在T2上的路径上的边均满足第二个限制。所以我们需要支持在T2上单点修改,链查询。这个可以用树状数组+dfs序实现。简单梳理一下我们需要完成的操作:
- 当我们在差分T2上的边时,我们需要求T1中两点的lca,这个可以用树链剖分实现。
- 差分时,我们可以用一个vector存在当前节点时,T2中某条边在数据结构中的变化量。
- 树状数组+dfs序实现单点加,链求和的这个经典操作需要求T2中两点的lca,这个可以用树链剖分完成。
- dfsT1求答案时,我们应该依次进行以下操作:查询答案,处理子树,处理自己,查询答案并相减。
代码:
( 注:namespace TCP是树链剖分,namespace DO是dfs序+树状数组)
#include<bits/stdc++.h>
#define re register
#define mp make_pair
using namespace std;
const int N=2e6+5;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int red(){
char re ch=nc();int re sum=0;
while(!(ch>='0'&&ch<='9'))ch=nc();
while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
return sum;
}
int n,m,a,b,c,ans[N];
struct node{int u,v;}e1[N],e2[N],E1[N],E2[N];
int f1[N],f2[N],nxp1[N],nxp2[N],cnt1=0,cnt2=0,cn1=0,cn2=0;
inline void add1(int u,int v){
E1[++cn1]=(node){u,v};
e1[++cnt1]=(node){u,v};
nxp1[cnt1]=f1[u];f1[u]=cnt1;
e1[++cnt1]=(node){v,u};
nxp1[cnt1]=f1[v];f1[v]=cnt1;
}
inline void add2(int u,int v){
E2[++cn2]=(node){u,v};
e2[++cnt2]=(node){u,v};
nxp2[cnt2]=f2[u];f2[u]=cnt2;
e2[++cnt2]=(node){v,u};
nxp2[cnt2]=f2[v];f2[v]=cnt2;
}
namespace TCP1{
int dep[N],top[N],fa[N],son[N],siz[N];
void dfs1(int u,int ff){
siz[u]=1;dep[u]=dep[ff]+1;
for(int re i=f1[u];i;i=nxp1[i]){
int v=e1[i].v;
if(dep[v])continue;
dfs1(v,fa[v]=u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}void dfs2(int u){
if(son[u])top[son[u]]=top[u],dfs2(son[u]);
for(int re i=f1[u];i;i=nxp1[i])
if(!top[e1[i].v])dfs2(top[e1[i].v]=e1[i].v);
}inline int lca(int a,int b){
while(top[a]^top[b]){
if(dep[top[a]]<dep[top[b]])swap(a,b);
a=fa[top[a]];
}if(dep[a]<dep[b])return a;
return b;
}
}
namespace TCP2{
int dep[N],top[N],fa[N],son[N],siz[N];
void dfs1(int u,int ff){
siz[u]=1;dep[u]=dep[ff]+1;
for(int re i=f2[u];i;i=nxp2[i]){
int v=e2[i].v;
if(dep[v])continue;
dfs1(v,fa[v]=u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}void dfs2(int u){
if(son[u])top[son[u]]=top[u],dfs2(son[u]);
for(int re i=f2[u];i;i=nxp2[i])
if(!top[e2[i].v])dfs2(top[e2[i].v]=e2[i].v);
}inline int lca(int a,int b){
while(top[a]^top[b]){
if(dep[top[a]]<dep[top[b]])swap(a,b);
a=fa[top[a]];
}if(dep[a]<dep[b])return a;
return b;
}
}
typedef pair<int,int>T;
vector<T>g[N];
inline void insert(int u,int p,int v){g[u].push_back(mp(p,v));}
namespace DO{
int st[N],ed[N],tot=0,c[N],dep[N];
inline void add(int x,int v){while(x<=n)c[x]+=v,x+=x&-x;}
inline int ask(int x){
int ret=0;
while(x)ret+=c[x],x-=x&-x;
return ret;
}
inline void change(int u,int w){add(st[u],w),add(ed[u]+1,-w);}
inline int query(int u,int v){return ask(st[u])+ask(st[v])-2*ask(st[TCP2::lca(u,v)]);}
void dfs(int u,int fa){
st[u]=++tot;dep[u]=dep[fa]+1;
for(int re i=f2[u];i;i=nxp2[i])
if(!st[e2[i].v])dfs(e2[i].v,u);
ed[u]=tot;
}void init(){
dfs(1,0);
for(int re i=1;i<n;i++){
int u=E2[i].u,v=E2[i].v;
if(dep[u]>dep[v])swap(u,v);
int lc=TCP1::lca(u,v);
insert(u,v,1);
insert(v,v,1);
insert(lc,v,-2);
}
}
}
namespace SOLVE{
void dfs(int u,int fa){
int ans1=0;
if(u^1)ans1=DO::query(u,fa);
for(int re i=f1[u];i;i=nxp1[i]){
int v=e1[i].v;
if(v==fa)continue;
dfs(v,u);
}
for(int re i=g[u].size()-1;~i;--i)
DO::change(g[u][i].first,g[u][i].second);
if(u^1)ans[u]=DO::query(u,fa)-ans1;
}
}
int main()
{
int size=40<<20;
__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
n=red();
for(int re i=1;i<n;i++)add1(red(),red());
for(int re i=1;i<n;i++)add2(red(),red());
TCP1::dfs1(1,0);TCP1::dfs2(TCP1::top[1]=1);
TCP2::dfs1(1,0);TCP2::dfs2(TCP2::top[1]=1);
DO::init();SOLVE::dfs(1,0);
for(int re i=1;i<n;i++){
int u=E1[i].u,v=E1[i].v;
if(TCP1::dep[u]<TCP1::dep[v])swap(u,v);
cout<<ans[u]<<" ";
}exit(0);
}