题意
一个树上每个点有一个颜色,每次操作可以把所有颜色为i的点变成另一个你指定的j,问最少需要多少次操作可以使有一种颜色C的所有点两两之间的路径上的点的颜色均为C
分析
考虑树上的路径问题,使用点分治解决
我们每次考虑把分治中心的点的颜色作为C,那么我们把这个点子树内的所有颜色为C的点放到一个队列中,依次考虑队列中每个人的父亲节点是不是C,如果不是,就需要一次操作把父亲点的颜色换成C。
如果出现了颜色C的点不在当前的分治子树内,那么就不需要计算了,因为一定不优,因为跨过了分支中心的父亲节点或者更高的位置。这个画一画可以理解。
剩下就是注意点分治回收的时候的效率即可
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2e5+5;
int c[maxn];
vector <int> G[maxn];
struct edge
{
int to,nxt;
}e[maxn<<1];
int tot,head[maxn];
void add(int x,int y)
{
e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot;
}
int n,k;
int siz[maxn],rt,gs[maxn],col[maxn];
int vis[maxn],totsiz;
void find(int u,int fa)
{
siz[u]=1; gs[u]=0;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==fa || vis[to]) continue;
find(to,u);
siz[u]+=siz[to];
gs[u]=max(gs[u],siz[to]);
}
gs[u]=max(gs[u],totsiz-siz[u]);
if(gs[rt]>gs[u]) rt=u;
}
int viscol[maxn],fath[maxn],used[maxn];
int res,ans=0x3f3f3f3f;
queue <int> q;
bool push(int x)
{
for(auto u:G[x])
{
if(!used[u]) return 1;
q.push(u);
}
res++;
return 0;
}
void dfs(int u,int fa)
{
fath[u]=fa;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==fa || vis[to]) continue;
dfs(to,u);
}
}
int st[maxn],top;
void reset(int u,int fa)
{
st[++top]=u; used[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==fa || vis[to]) continue;
reset(to,u);
}
}
void calc(int u)
{
res=0;
while(!q.empty()) q.pop();
viscol[col[u]]=1;
if(push(col[u])) return;
dfs(u,u);
while(!q.empty())
{
int u=q.front(); q.pop();
if(!viscol[col[fath[u]]])
{
viscol[col[fath[u]]]=1;
if(push(col[fath[u]])) return;
}
}
ans=min(ans,res);
}
void solve(int u)
{
vis[u]=1;
reset(u,u);
calc(u);
while(top) used[st[top]]=viscol[col[st[top]]]=0,top--;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(vis[to]) continue;
rt=0; totsiz=siz[to];
find(to,u);
solve(rt);
}
}
int main()
{
freopen("color.in","r",stdin);
freopen("color.out","w",stdout);
scanf("%d%d",&n,&k);
int x,y;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&col[i]);
G[col[i]].push_back(i);
}
gs[rt=0]=n; totsiz=n;
find(1,1);
solve(rt);
printf("%d\n",ans-1);
return 0;
}