题目
https://vjudge.net/problem/SPOJ-COT2
题意
给你一棵树，每个节点带有一个权值；有若干次询问，每次询问一条路径上有多少种不同的权值。
思路
树上莫队：用欧拉序（括号序）把树上路径转化为区间，区间内出现两次的节点互相抵消；当两个端点互不为祖先时，再单独计入它们的 LCA。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5+100;
// Adjacency lists of the tree; nodes are numbered 1..n.
vector<int> G[maxn];
// st[u] / en[u]: positions where node u is entered / left in the Euler tour
// built by dfs(); xu[p]: the node written at tour position p.
// The tour has exactly 2*n positions, so xu must hold 2*maxn entries
// (sizing it with maxn overflows as soon as n exceeds about maxn/2).
int st[maxn],en[maxn],xu[2*maxn],cnt;
// fa[u][j]: 2^j-th ancestor of u (0 above the root); deep[u]: depth, root has depth 1.
int fa[maxn][20],deep[maxn];
// Number of nodes.
int n;
// Builds the Euler tour (bracket sequence) of the subtree rooted at u, whose
// parent is f: stores u's depth and direct parent, and writes u at its entry
// position st[u] and again at its exit position en[u] (xu[] maps a tour
// position back to its node). cnt is the running tour position.
void dfs(int u,int f)
{
    fa[u][0] = f;
    deep[u] = deep[f] + 1;

    st[u] = ++cnt;
    xu[st[u]] = u;

    for (int child : G[u])
    {
        if (child != f)
            dfs(child, u);
    }

    en[u] = ++cnt;
    xu[en[u]] = u;
}
// Fills the binary-lifting table: the 2^lvl-th ancestor of v is reached by
// jumping 2^(lvl-1) levels twice. Expects fa[v][0] to be set already
// (main calls dfs() before init()). Levels must be filled in increasing order.
void init()
{
    for (int lvl = 1; lvl < 20; ++lvl)
    {
        for (int v = 1; v <= n; ++v)
        {
            const int mid = fa[v][lvl - 1];
            fa[v][lvl] = fa[mid][lvl - 1];
        }
    }
}
// Returns the lowest common ancestor of u and v using the binary-lifting
// table fa[][] and depths deep[].
int lca(int u,int v)
{
    if (deep[u] < deep[v]) swap(u, v);

    // Lift the deeper node u to v's depth, one set bit of the depth gap at a time.
    for (int k = 19, gap = deep[u] - deep[v]; k >= 0; --k)
    {
        if ((gap >> k) & 1) u = fa[u][k];
    }
    if (u == v) return u;

    // Raise both nodes as long as their ancestors differ; they end up as
    // children of the LCA.
    for (int k = 19; k >= 0; --k)
    {
        if (fa[u][k] == fa[v][k]) continue;
        u = fa[u][k];
        v = fa[v][k];
    }
    return fa[u][0];
}
// ans[id]: answer to the id-th query; block / bnum: Mo block size and count;
// bid[p]: block number of Euler-tour position p, filled for p in [1, 2*n],
// hence sized 2*maxn (maxn alone overflows for large n);
// nn: number of distinct weights; ANS: distinct weights in the current window.
int ans[maxn],block,bid[2*maxn],bnum,nn,ANS;
// One offline query for Mo's algorithm over the Euler tour:
//   l, r - inclusive range of tour positions whose nodes are toggled,
//   lca  - extra node toggled only while answering (0 when not needed),
//   id   - original query index; the result goes to ans[id].
struct node
{
    int l,r,lca,id;
    // Sort by the block of l; inside a block, r ascends on odd blocks and
    // descends on even blocks (odd-even ordering to shorten r-pointer travel).
    bool operator < (const node &b) const{
        if (bid[l] != bid[b.l]) return bid[l] < bid[b.l];
        return (bid[l] & 1) ? r < b.r : r > b.r;
    }
}Q[maxn];
// a[i]: weight of node i (replaced by its compressed value in main);
// so[]: sorted copy of the weights used for that compression.
int a[maxn];
int so[maxn];
// vis[x]: 1 while node x is inside the current window (toggled by Add);
// num[c]: how many in-window nodes carry compressed weight c.
int vis[maxn];
int num[maxn];
// Records one more in-window node with weight value x; the first occurrence
// of a value raises the distinct count ANS.
void add(int x)
{
    ++num[x];
    if (num[x] == 1) ++ANS;
}
// Removes one in-window node with weight value x; when the last occurrence
// of that value leaves, the distinct count ANS drops.
void del(int x)
{
    --num[x];
    if (num[x] == 0) --ANS;
}
// Toggles node x: removes its weight from the window if x is currently in,
// otherwise adds it. A node visited twice by the tour therefore cancels out.
void Add(int x)
{
    const int w = a[x];
    if (!vis[x])
        add(w);
    else
        del(w);
    vis[x] ^= 1;
}
// Reads n and m, the n node weights, n-1 undirected edges and m path queries.
// Answers every "how many distinct weights on the u-v path" query offline with
// Mo's algorithm over the Euler tour, then prints the answers in input order.
int main()
{
    int m;
    if(scanf("%d%d",&n,&m) != 2) return 0;   // empty / malformed input: nothing to answer
    // Read the weights and coordinate-compress them to 1..nn so that num[]
    // can be indexed directly by the compressed value.
    for(int i = 1;i <= n;i++)
    {
        scanf("%d",&a[i]);
        so[i] = a[i];
    }
    sort(so+1,so+1+n);
    nn = unique(so+1,so+1+n) - (so+1);
    // Search only the nn distinct values. After unique(), so[nn+1..n] hold
    // leftovers in an unspecified order, so searching all n entries breaks
    // lower_bound's requirement that the range be partitioned.
    for(int i = 1;i <= n;i++)
        a[i] = lower_bound(so+1,so+1+nn,a[i]) - so;
    for(int i = 1;i < n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    // Euler tour rooted at node 1 (node 0 is the virtual parent of the root),
    // then the binary-lifting table used by lca().
    cnt = 0;
    deep[0] = 0;
    dfs(1,0);
    init();
    // Split the 2n tour positions into blocks of size about sqrt(2n).
    block = sqrt(2*n);
    bnum = ceil(double(2*n)/block);
    for(int i = 1;i <= bnum;i++)
        for(int j = block*(i-1)+1;j <= min(2*n,i*block);j++) bid[j] = i;
    // Convert each path query into a tour range plus an optional LCA node.
    for(int i = 1;i <= m;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        if(st[u] > st[v]) swap(u,v);   // make u the node entered first
        int _lca = lca(u,v);
        // u is an ancestor of v: the nodes occurring once in [st[u], st[v]]
        // are exactly the u-v path.
        if(_lca==u) Q[i] = node{st[u],st[v],0,i};
        // Otherwise [en[u], st[v]] yields the path without the LCA, which is
        // toggled separately while answering.
        else Q[i] = node{en[u],st[v],_lca,i};
    }
    sort(Q+1,Q+1+m);
    int l = 1,r = 0;
    ANS = 0;
    memset(vis,0,sizeof(vis));
    memset(num,0,sizeof(num));
    for(int i = 1;i <= m;i++)
    {
        // Slide the window to [Q[i].l, Q[i].r]. Add() toggles a node, so the
        // window state only depends on which positions it covers.
        while(r < Q[i].r) Add(xu[++r]);
        while(r > Q[i].r) Add(xu[r--]);
        while(l < Q[i].l) Add(xu[l++]);
        while(l > Q[i].l) Add(xu[--l]);
        // Include the LCA only for reading the answer, then remove it again.
        if(Q[i].lca) Add(Q[i].lca);
        ans[Q[i].id] = ANS;
        if(Q[i].lca) Add(Q[i].lca);
    }
    for(int i = 1;i <= m;i++)
        printf("%d\n",ans[i]);
    return 0;
}

4335

被折叠的 条评论
为什么被折叠?



