题目大意
有一棵 n n n个节点的以 1 1 1号节点为根的有根数。
每个节点都有一个颜色,第 i i i个节点的衍颜色为 c i c_i ci。
如果一种颜色在以 z z z为根的子树内出现次数最多,称其在以 x x x为根的子树中占主导地位。显然,同一子树中可能有多种颜色占主导地位。
对于每一个 i ∈ [ 1 , n ] i\in[1,n] i∈[1,n],求以 i i i为根的子树中,占主导地位的颜色的编号和。
1 ≤ n ≤ 1 0 5 , 1 ≤ c i ≤ n 1\leq n\leq 10^5,1\leq c_i\leq n 1≤n≤105,1≤ci≤n
题解
这道题可以用启发式合并(dsu on tree),但这里介绍用 s e t set set的做法。
首先, d f s dfs dfs一次整棵树,求出每一个节点的重儿子(我们将每个点的各个儿子中子树的节点数量最多的儿子称为重儿子,重儿子连成的链称为重链)。
然后,再 d f s dfs dfs一次,用 s e t set set来维护每种颜色和这种颜色在这个节点的子树中出现的次数。这个次数用含有两个元素的结构体来维护,这两个元素分别表示颜色和数量, s e t set set以颜色编号从小到大排序。
修改时,将对应衍颜色取出并在 s e t set set中删除,修改后再加入队列,时间复杂度为 O ( log n ) O(\log n) O(logn)。
对于每个点,用其重儿子求出的 s e t set set来进行当前的修改,将每个轻儿子(不是重儿子的儿子)的各个节点的颜色一次加入当前的 s e t set set。因为每个节点到根节点的路上最多只有 log n \log n logn个节点和 log n \log n logn条重链,所以每个点最多会被加入 s e t set set的次数为 log n \log n logn,修改操作为 O ( log n ) O(\log n) O(logn)。总共有 n n n个点,所以总时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
code
#include<bits/stdc++.h>
using namespace std;
int n,c[100005],siz[100005],son[100005],mx[100005];
int x,y,tot=0,d[200005],l[200005],r[200005];
long long v[100005];
struct node{
int x,v;
bool operator<(const node ax)const{
return x<ax.x;
}
};
set<node>s[100005];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs1(int u,int fa){
siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs1(d[i],u);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]) son[u]=d[i];
}
}
void dfs2(int u,int fa){
set<node>::iterator it,vt;
if(son[u]){
dfs2(son[u],u);
swap(s[u],s[son[u]]);
mx[u]=mx[son[u]];
v[u]=v[son[u]];
}
for(int i=r[u];i;i=l[i]){
if(d[i]==fa||d[i]==son[u]) continue;
dfs2(d[i],u);
for(it=s[d[i]].begin();it!=s[d[i]].end();++it){
vt=s[u].lower_bound((node){(*it).x,0});
node w=*it;
if((*vt).x==(*it).x){
w.v+=(*vt).v;
s[u].erase(vt);
}
if(w.v>mx[u]){
mx[u]=w.v;
v[u]=w.x;
}
else if(w.v==mx[u]) v[u]+=w.x;
s[u].insert(w);
}
}
vt=s[u].lower_bound((node){c[u],0});
node w=(node){c[u],1};
if((*vt).x==c[u]){
w.v+=(*vt).v;
s[u].erase(vt);
}
if(w.v>mx[u]){
mx[u]=w.v;
v[u]=w.x;
}
else if(w.v==mx[u]) v[u]+=w.x;
s[u].insert(w);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&c[i]);
s[i].insert((node){n+1,0});
}
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(1,0);
dfs2(1,0);
for(int i=1;i<=n;i++){
printf("%lld ",v[i]);
}
return 0;
}