题链:https://www.luogu.com.cn/problem/SP10707
思路:树上莫队模板题。
这里需要了解欧拉序(强推博客),利用欧拉序,就可以把树上的路径转化为连续区间的问题。
这里还会用到LCA。对于树上u到v的路径,设first[u]<first[v],
如果lca(u,v)==u,对应欧拉序的连续区间就为[ first[u] , first[v] ] ;
否则, 对应欧拉序的连续区间就为[ last[u] , first[v] ] , 不过还要加上 lca(u,v) 。
要注意的是区间中出现两次的点并不在u到v的路径中,所以我们需要一个vis数组,每次进行异或操作。如果vis[i]==1,要进行删除操作;vis[i]==0,进行添加操作。
一定记得有的数组需要开2*n并且别忘了加lca的贡献 !!!!!!!!
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 4e4+10;
const int M = 1e5+10;
struct node{
int to,nex;
}g[N<<1];
int head[N],cnt=0;
struct Node{
int l,r,id,lca;
}q[M];
int n,m,a[N];
int base,be[N<<1];
unordered_map<int,int> mp;
int Id=0,limit=0;
void add(int u,int v){
g[cnt]=node{v,head[u]};
head[u]=cnt++;
}
int fir[N],la[N],dfn[N<<1],tot=0;
int dep[N],f[N][30];
void dfs(int u,int fa){
fir[u]=++tot;
dfn[tot]=u;
for(int i=1;i<=20;++i)
f[u][i]=f[f[u][i-1]][i-1];
for(int i=head[u];~i;i=g[i].nex){
int v=g[i].to;
if(dep[v]||v==fa) continue;
dep[v]=dep[u]+1;
f[v][0]=u;
dfs(v,u);
}
la[u]=++tot;
dfn[tot]=u;
}
int getLca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
int dis=dep[x]-dep[y];
for(int i=20;i>=0;--i)
if(dis&(1<<i))
x=f[x][i];
if(x==y) return x;
for(int i=20;i>=0;--i)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
bool cmp(Node a,Node b){
return (be[a.l]^be[b.l]) ? be[a.l]<be[b.l] : (be[a.l]&1) ? a.r<b.r : a.r>b.r;
}
int vis[N],num[N],sum=0,ans[M];
void add(int x){
if(!num[a[x]]) ++sum;
++num[a[x]];
}
void del(int x){
--num[a[x]];
if(!num[a[x]]) --sum;
}
void change(int x){
x=dfn[x];
vis[x] ? del(x) : add(x);
vis[x]^=1;
}
int read() {
int res = 0;
char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) res = (res << 1) + (res << 3) + c - 48, c = getchar();
return res;
}
int main(void){
//scanf("%d%d",&n,&m);
n=read(),m=read();
base=ceil(sqrt(2.0*n));
limit=ceil(log2(1.0*n));
cnt=0;
for(int i=1;i<=n;++i){
//scanf("%d",&a[i]);
a[i]=read();
if(!mp[a[i]]) mp[a[i]]=++Id;
a[i]=mp[a[i]];
head[i]=-1;
be[i]=i/base;
be[i+n]=(i+n)/base;
}
for(int i=1;i<n;++i){
int u,v;
//scanf("%d%d",&u,&v);
u=read(),v=read();
add(u,v);
add(v,u);
}
dep[1]=1;
dfs(1,0);
for(int i=1;i<=m;++i){
int u,v;
//scanf("%d%d",&u,&v);
u=read(),v=read();
if(fir[u]>fir[v]) swap(u,v);
int lca=getLca(u,v);
if(lca==u)
q[i].l=fir[u];
else
q[i].l=la[u],q[i].lca=lca;;
q[i].r=fir[v];
q[i].id=i;
}
sort(q+1,q+1+m,cmp);
int l=1,r=0;
for(int i=1;i<=m;i++){
int ql=q[i].l,qr=q[i].r,lca=q[i].lca;
while(l<ql) change(l++);
while(l>ql) change(--l);
while(r<qr) change(++r);
while(r>qr) change(r--);
ans[q[i].id]=sum;
if(lca && !num[a[lca]]) ++ans[q[i].id];
}
for(int i=1;i<=m;++i)
printf("%d\n",ans[i]);
return 0;
}