众所周知,莫队算法不仅仅能解决序列问题,还能解决树上问题。
当一个问题可以离线做时,我们就可以使用树上莫队算法了(带修另说)
基本思路
一道例题:给一棵树,树上节点都有颜色 c o l [ i ] col[i] col[i] ,每次给定询问 ( u , v ) (u,v) (u,v) ,求树上 u u u 到 v v v 不同的颜色数是多少。
或许可以打树链剖分+线段树合并,但是实现起来比较困难,而且代码量大,所以这时候就需要代码较为简单的树上莫队了。
首先我们要将一棵树拍成序列的状态,在这个序列上进行莫队,按照普通莫队的方法求解答案即可。问题在于怎么将一棵树拍成序列的状态。
我们可以借鉴括号序的思路:将每个点看作一个括号,第一次到达节点时将左括号加入序列,遍历完子树后再把这个点的右括号加入,这样要是一个点在序列区间内出现 2 2 2 次,那么表示这个点进行了进栈,出栈的操作,就相当与没被遍历到,这个可以用标记数组记录下来,而序列的长度是 2 n 2n 2n 的,进行莫队操作的时间复杂度是没有问题的。
具体而言,假设我们有一棵树,如图:
那么这棵树被拍成序列是这样的:
1 , 2 , 2 , 3 , 4 , 4 , 5 , 5 , 3 , 1 1 ,2,2,3,4,4,5,5,3,1 1,2,2,3,4,4,5,5,3,1
我们设一个节点在序列中第一次出现的位置为
s
t
[
i
]
st[i]
st[i] ,最后一次出现的位置为
e
n
[
i
]
en[i]
en[i]
问题在于怎么将询问的点
(
u
,
v
)
(u,v)
(u,v) 转化为区间的
(
l
,
r
)
(l,r)
(l,r)
假设
u
u
u 先于
v
v
v 被扫到(即
s
t
[
u
]
<
s
t
[
v
]
st[u]<st[v]
st[u]<st[v])
明显的,我们要先求出
l
c
a
(
u
,
v
)
lca(u,v)
lca(u,v) 。
若
l
c
a
=
u
lca=u
lca=u (即
u
u
u 和
v
v
v 在一条链上)则
l
=
s
t
[
u
]
,
r
=
s
t
[
v
]
l=st[u],r=st[v]
l=st[u],r=st[v]
若
l
c
a
≠
u
lca\not= u
lca=u 那么
l
=
e
n
[
u
]
,
r
=
s
t
[
v
]
l=en[u],r=st[v]
l=en[u],r=st[v] 但是发现好像没有把
l
c
a
lca
lca 记录下来,那我们可以将询问的
l
c
a
lca
lca 记录,在做莫队的时候特判一下就好了。
时间复杂度是 O ( m n ) O(m \sqrt n) O(mn) 的。
代码
#include<cstring>
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
struct ppp{
int l,r,lca,id;
}Q[100100];
struct kkk{
int head,to,nxt;
}edge[100100];
int n,m,col[50500],block,rt,a,b,cnt,f[50050][22],tot,kuai[50050];
int dep[50050],st[50050],en[50050];
int eu[100100],val[100100],ans,an[100100];
bool bz[50050];
bool cmp(ppp x,ppp y){
if(kuai[x.l]<kuai[y.l]) return 1;
if(kuai[x.l]==kuai[y.l]&&x.r<y.r) return 1;
return 0;
}
void build(int u,int v){
edge[cnt].to=v;
edge[cnt].nxt=edge[u].head;
edge[u].head=cnt;
cnt++;
}
void dfs(int x,int father){
tot++,st[x]=tot;
eu[tot]=x;
dep[x]=dep[father]+1;
for(int i=1;(1<<i)<=dep[x];i++)
f[x][i]=f[f[x][i-1]][i-1];
for(int i=edge[x].head;i!=-1;i=edge[i].nxt){
int y=edge[i].to;
if(y==father) continue;
f[y][0]=x;
dfs(y,x);
}
tot++,en[x]=tot;
eu[tot]=x;
}
int lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=20;i>=0;i--)
if(dep[f[u][i]]>=dep[v]) u=f[u][i];
if(u==v) return u;
for(int i=20;i>=0;i--)
if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
void change(int x){
if(x==0) return;
if(bz[x]==0) {
bz[x]=1,val[col[x]]++;
if(val[col[x]]==1) ans++;
}
else {
bz[x]=0,val[col[x]]--;
if(val[col[x]]==0) ans--;
}
}
void slove(){
int l=0,r=0;
for(int i=1;i<=m;i++){
while(l>Q[i].l) l--,change(eu[l]);
while(r<Q[i].r) r++,change(eu[r]);
while(r>Q[i].r) change(eu[r]),r--;
while(l<Q[i].l) change(eu[l]),l++;
if(Q[i].lca) change(Q[i].lca);
an[Q[i].id]=ans;
if(Q[i].lca) change(Q[i].lca);
}
return;
}
int main(){
memset(edge,-1,sizeof(edge));
scanf("%d %d",&n,&m);
block=sqrt(2*n);
for(int i=1;i<=n;i++)
scanf("%d",&col[i]);
for(int i=1;i<=2*n;i++)
kuai[i]=i/block+1;
for(int i=1;i<n;i++){
scanf("%d %d",&a,&b);
build(a,b),build(b,a);
}
dfs(1,0);
for(int i=1;i<=m;i++){
scanf("%d %d",&a,&b);
Q[i].id=i;
if(st[a]>st[b]) swap(a,b);
int L=lca(a,b);
if(L==a) Q[i].l=st[a],Q[i].r=st[b];
else Q[i].l=en[a],Q[i].r=st[b],Q[i].lca=L;
}
sort(Q+1,Q+m+1,cmp);
slove();
for(int i=1;i<=m;i++)
printf("%d\n",an[i]);
return 0;
}