题意:给出一棵n个节点的树,每个节点有一种颜色,用Ci表示,问如果删除一条边(x,y),那么剩余两棵树的颜色交集的大小是多少。
思路:首先对原来的颜色dfs一遍生成有向树得出颜色序列(可参照 挑战程序竞赛 P330),记录每个节点访问时最先出现和最后出现的位置,那么对于删除(x,y)这条边时,假设x是y子节点,很显然,x最先出现和最后出现的位置形成的闭区间则是以x为根节点的子树的所有节点颜色序列,假设这个区间为L,R,要计算两颗子树交集大小,发现交集大小为 = 以x为根节点的子树的所有颜色总数(A) - 以x为根节点的子树只有的颜色总数(B),要计算A的值,等价于在区间[L,R]寻找不同的数的个数,莫队算法直接过,同样B也可以在求A的同时计算出来。
#include<cstdio>
#include<cstring>
#include<vector>
#include<set>
#include<queue>
#include<cmath>
#include<stack>
#include<iostream>
#include<algorithm>
typedef long long ll;
const int INF = 1e9 + 7;
const int maxn = 1e5 + 10;
using namespace std;
struct edge {
int from, to;
void input() { scanf("%d %d", &from, &to); }
} e[maxn];
struct P {
int bol, l, r, num;
P() {}
P(int b, int l, int r, int n) :
bol(b), l(l), r(r), num(n) {}
bool operator < (P p) const {
if(bol != p.bol) return bol < p.bol;
return r < p.r;
}
} qry[maxn];
int n, q, num, k, INIT;
int id[maxn], vs[maxn * 10]; ///每个节点第一次出现的下标, 颜色的遍历序列
int col[maxn], sum[maxn]; ///每个点的颜色,dfs序列点的总数
int blk[maxn]; ///每个端点所在的块
vector<int> G[maxn];
int res[maxn]; ///每个查询的结果
int now[maxn]; ///子树现有的颜色总数
int idx[maxn], dep[maxn]; ///每个点最后出现的下标 深度
void init() {
int b = 1, x = 0;
for(int i = 0; i < maxn; i++) {
G[i].clear();
sum[i] = now[i] = 0;
}
if(INIT) return ;
INIT = 1;
while(x < maxn) {
for(int j = 0; j < 300; j++) {
if(x >= maxn) break;
blk[x] = b; x++;
}
b++;
}
}
void dfs(int v, int fa, int &k, int d) {
id[v] = idx[v] = k; dep[v] = d;
vs[k++] = col[v];
for(int i = 0; i < G[v].size(); i++) {
int to = G[v][i];
if(to == fa) continue;
dfs(to, v, k, d + 1);
idx[v] = k;
vs[k++] = col[v];
}
}
int tal, only;
///一棵子树中 总的颜色数量-只有的颜色数量=交集的数量
void solve(int l, int r, int L, int R, int x) {
while(l > L) {
l--;
int c = vs[l];
now[c]++;
if(now[c] == 1) tal++;
if(now[c] == sum[c]) only++;
}
while(l < L) {
int c = vs[l]; now[c]--;
if(!now[c]) tal--;
if(now[c] == sum[c] - 1) only--;
l++;
}
while(r > R) {
int c = vs[r]; now[c]--;
if(!now[c]) tal--;
if(now[c] == sum[c] - 1) only--;
r--;
}
while(r < R) {
r++;
int c = vs[r];
now[c]++;
if(now[c] == 1) tal++;
if(now[c] == sum[c]) only++;
}
res[x] = tal - only;
}
int main() {
INIT = 0;
while(scanf("%d", &n) != EOF) {
init();
for(int i = 1; i <= n; i++)
scanf("%d", &col[i]);
for(int i = 0; i < n - 1; i++) {
e[i].input();
G[e[i].from].push_back(e[i].to);
G[e[i].to].push_back(e[i].from);
}
k = 0;
dfs(1, 0, k, 1);
for(int i = 0; i < n - 1; i++) {
int u = e[i].from, v = e[i].to, num = i;
if(dep[u] < dep[v]) swap(u, v);
int L = id[u], R = idx[u];
qry[i] = P(blk[L], L, R, i);
}
for(int i = 0; i < k; i++) sum[vs[i]]++;
sort(qry, qry + n - 1);
int c = vs[0]; now[c]++;
int lasl = 0, lasr = 0;
int nowl, nowr;
tal = 1; only = sum[c] == 1 ? 1 : 0;
for(int i = 0; i < n - 1; i++) {
nowl = qry[i].l; nowr = qry[i].r;
int x = qry[i].num;
solve(lasl, lasr, nowl, nowr, x);
lasl = nowl; lasr = nowr;
}
for(int i = 0; i < n - 1; i++) printf("%d\n", res[i]);
}
return 0;
}