可以把树上的一条链变成一位序列 然后用普通莫队求解
需要求出树欧拉序(和dfs序不同)
这颗树欧拉序为:1 4 8 8 4 3 7 7 6 6 5 5 3 2 2 1
根据欧拉序求出每个点第一次出现的位置和最后一次出现的位置
first[1]=1 first[2]=14
last[1]=16 last[2]=15
对于每次询问(l,r):
假设first[l]<first[r]
如果l是r的祖先节点: l到r经过的点就是first[l]到first[r]序列中只出现一次的点。
如果l不是r的祖先节点: l到r经过的点就是last[l]到first[r]序列中只出现一次的点再加上lca(l,r)(l,r的最近公共祖先)。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
const int N = 40010, M = N * 2;
vector<int>v;
int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][16];
int q[N];
int tot;
int li[N*2];
int first[N];
int last[N];
int cnt[N];
int w[N];
int book[N];
int ans[100010];
int len;
struct Query{
int id,l,r,st;
}query[100010];
int get(int x)
{
return x/len;
}
bool cmp(Query a,Query b)
{
int x=get(a.l),y=get(b.l);
if(x!=y) return x<y;
return a.r<b.r;
}
void Add(int x,int &res)
{
book[x]^=1;
if(book[x])
{
cnt[w[x]]++;
if(cnt[w[x]]==1) res++;
}
else
{
cnt[w[x]]--;
if(cnt[w[x]]==0) res--;
}
}
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs(int x,int fa)
{
li[++tot]=x;first[x]=tot;
for(int i=h[x];i!=-1;i=ne[i])
{
int y=e[i];
if(y==fa) continue;
dfs(y,x);
}
li[++tot]=x;last[x]=tot;
}
void bfs(int root)
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
int hh = 0, tt = 0;
q[0] = root;
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t;
for (int k = 1; k <= 15; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = 15; k >= 0; k -- )
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = 15; k >= 0; k -- )
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
int main()
{
scanf("%d%d", &n,&m);
int root = 1;
len=sqrt(n);
for(int i=1;i<=n;i++)
{
scanf("%d",&w[i]);
v.push_back(w[i]);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
for(int i=1;i<=n;i++)
{
w[i]=lower_bound(v.begin(),v.end(),w[i])-v.begin();
}
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs(root);
dfs(root,-1);
for(int i=0;i<m;i++)
{
int l,r;
scanf("%d%d",&l,&r);
if(first[l]>first[r]) swap(l,r);
if(lca(l,r)==l)
{
query[i].l=first[l];
query[i].r=first[r];
}
else
{
query[i].l=last[l];
query[i].r=first[r];
query[i].st=lca(l,r);
}
query[i].id=i;
}
sort(query,query+m,cmp);
int res=0;
for(int k=0,i=0,j=1;k<m;k++)
{
int l=query[k].l,r=query[k].r,st=query[k].st,id=query[k].id;
while(j<l) Add(li[j++],res);
while(j>l) Add(li[--j],res);
while(i<r) Add(li[++i],res);
while(i>r) Add(li[i--],res);
if(st)
{
//res=max(res,cnt[w[st]]++);
Add(st,res);
}
ans[id]=res;
if(st)
{
Add(st,res);
}
}
for(int i=0;i<m;i++)
{
printf("%d\n",ans[i]);
}
return 0;
}