题干
给定一棵
n
n
n 个节点的树,节点编号为
1
∼
n
1∼n
1∼n。每个节点都被染成了黑色(用
1
1
1 表示)或白色(用
0
0
0 表示)。从黑色节点无法到达白色节点,反之亦然。因此,两个同色节点相互可达的前提是,两个同色节点之间的路径中不含另一种颜色的节点。
我们希望将树中的所有节点都染成同一种颜色(全黑或全白均可)。为此,你可以采用我们指定的染色操作。每次操作可以选择一个节点
v
v
v,并改变节点
v
v
v 以及其所有可达同色节点的颜色(黑变白、白变黑)。
例如,在下图中,点
1
1
1 和点
2
,
3
,
8
,
9
2,3,8,9
2,3,8,9 之间相互可达,但是点
1
1
1 和点
6
6
6 之间相互不可达(被点
5
5
5 挡住了),因此,如果选择点
1
1
1 进行染色操作,会将点
1
,
2
,
3
,
8
,
9
1,2,3,8,9
1,2,3,8,9 全部染黑。
请你计算,为了达成目标,至少需要进行多少次染色操作。
输入
第一行包含整数
n
n
n。
第二行包含
n
n
n 个整数
c
1
,
c
2
,
⋯
,
c
n
c_1,c_2,\cdots,c_n
c1,c2,⋯,cn,其中
c
i
c_i
ci 为节点
i
i
i 的颜色(
1
1
1 表示黑,
0
0
0 表示白)。
接下来
n
−
1
n−1
n−1 行,每行包含两个整数
u
i
,
v
i
u_i,v_i
ui,vi,表示节点
u
i
u_i
ui 和节点
v
i
v_i
vi 之间存在一条边。
输出
一个整数,表示所需的最少染色操作次数。
思路
题目的考点是并查集+树的直径。
并查集
既然相同颜色的,互相联通的点可以在一次操作内进行染色操作(即全部由黑变白,或者由白变黑),那么不如将他们视作一个点。通过并查集可以实现这一点。
找直径
缩点后形成的树 T T T 中,相邻结点具有不同颜色。考虑树 T T T 的直径 D = max u , v ∈ T d ( u , v ) D=\max_{u,v\in T}d(u,v) D=maxu,v∈Td(u,v),则必有一条长为 D D D 的路 p p p。
- 当 D D D 为奇数时,这条路由 D − 1 2 \frac{D-1}{2} 2D−1 个黑点(或者白点)和 D + 1 2 \frac{D+1}{2} 2D+1 个白点(或者黑点)组成,最少通过 D − 1 2 = ⌊ D / 2 ⌋ \frac{D-1}{2}=\lfloor D/2\rfloor 2D−1=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
- 当 D D D 为偶数时,这条路由 D / 2 D/2 D/2 个黑点和 D / 2 D/2 D/2 个白点(或者黑点)组成,最少通过 D / 2 = ⌊ D / 2 ⌋ D/2=\lfloor D/2\rfloor D/2=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
因此,如果要将树
T
T
T 变为一种颜色,至少要将这条路
p
p
p 变成一种颜色,次数
a
n
s
≥
⌊
D
/
2
⌋
\mathrm{ans}\geq \lfloor D/2\rfloor
ans≥⌊D/2⌋。此外,我们还能知道,从这条路的中心出发,通过
⌊
D
/
2
⌋
\lfloor D/2\rfloor
⌊D/2⌋ 次操作,还可以将
T
−
p
T-p
T−p (即路外其他结点)也转变为一种颜色。因为如果做不到,说明我们在找直径的时候就找错了。
所以答案就是
a
n
s
=
⌊
D
/
2
⌋
\mathrm{ans}=\lfloor D/2\rfloor
ans=⌊D/2⌋。对于一个数
T
T
T 而言,它的直径为
D
=
f
(
T
)
D=f(T)
D=f(T),其中
f
(
T
)
=
max
(
1
+
d
T
1
+
d
T
2
,
f
(
T
1
)
,
f
(
T
2
)
)
f(T)=\max(1+d_{T_1}+d_{T_2},f(T_1),f(T_2))
f(T)=max(1+dT1+dT2,f(T1),f(T2))
d
T
d_T
dT 表示树
T
T
T 的深度,而
T
1
,
T
2
T_1,T_2
T1,T2 表示
T
T
T 最深的两个子树。
Code
# include <iostream>
# include <cstring>
# include <vector>
using namespace std;
int n,dad[200005],dp[200005],uu[200005],vv[200005];
bool c[200005];
vector<int> nex[200005];
int getdad(int node){
if(dad[node] == node)
return node;
return dad[node] = getdad(dad[node]);
}
int getdp(int node,int fa){
int mdp = 0;
for(int &x : nex[node])
if(x != fa)
mdp = max(mdp,getdp(x,node));
return dp[node] = mdp + 1;
}
int maxdist(int node,int fa){
int dp1 = -1,dp2 = -1,ans = 1;
for(int &x : nex[node])
if(x != fa){
ans = max(ans,maxdist(x,node));
if(dp[x] > dp1){
if(dp1 < dp2) dp1 = dp[x];
else dp2 = dp[x];
}
else dp2 = max(dp2,dp[x]);
}
if(dp1 == -1 && dp2 == -1) return ans;
if(dp1 == -1) return max(ans,1 + dp2);
else if(dp2 == -1) return max(ans,1 + dp1);
return max(ans,1 + dp1 + dp2);
}
int main(){
int u,v;
cin >> n;
for(int i = 1;i <= n;i++)
cin >> c[i];
for(int i = 1;i < n;i++){
cin >> u >> v;
if(dad[u] && dad[v]){
if(c[u] == c[v]) dad[getdad(v)] = getdad(u);
}
else if(dad[u]){
dad[v] = c[u] == c[v]?getdad(u):v;
}
else if(dad[v]){
dad[u] = c[u] == c[v]?getdad(v):u;
}
else{
dad[u] = u;
dad[v] = c[u] == c[v]?u:v;
}
uu[i] = u;
vv[i] = v;
}
for(int i = 1;i < n;i++){
u = getdad(uu[i]);
v = getdad(vv[i]);
if(c[u] != c[v]){
nex[u].push_back(v);
nex[v].push_back(u);
}
}
getdp(getdad(1),0);
return cout << maxdist(getdad(1),0) / 2,0;
}