题面
题意
给出一棵一共有k种颜色的树,每个点要么没有颜色,要么是k种颜色中的一种颜色,每种颜色至少出现一次,现在要你删去k-1条边,使每个连通块中有且仅有一种颜色,问有几种删法。
做法
首先判断无解的情况,可以发现如果一个有颜色的点为根的子树不包括所有这个颜色的点,则这个点与其父亲之间的边不可以被删去,我们可以据此将其父亲也染为与它颜色相同的颜色。
因此,我们可以进行bfs,如果这一层中有多个颜色t,则将这层中所有颜色为t的点的父亲的颜色也染为t(若其父亲已经有颜色了,且不为t,则无解),如此每种颜色的点都连成了一个连通块,然后我们可以不断将没有颜色的叶子节点去掉,使图更加精简。
接下来我们进行dp,dp[i][0]表示此时以最高点为i的连通块中没有颜色,dp[i][1]表示此时以最高点为i的连通块中有某个颜色,可以发现点u与其儿子v是否连边可以由上述状态决定。
下面考虑状态转移:
1.若点u没有颜色,则:
d
p
[
u
]
[
0
]
=
∏
i
∈
s
o
n
(
u
)
d
p
[
i
]
[
0
]
+
d
p
[
i
]
[
1
]
dp[u][0]=\prod_{i\in son(u)}{dp[i][0]+dp[i][1]}
dp[u][0]=∏i∈son(u)dp[i][0]+dp[i][1]
它可以由所有状态转移过来。
d
p
[
u
]
[
1
]
=
∑
i
∈
s
o
n
(
u
)
d
p
[
i
]
[
1
]
∗
∏
j
∈
s
o
n
(
u
)
,
j
̸
=
i
d
p
[
j
]
[
0
]
+
d
p
[
j
]
[
1
]
dp[u][1]=\sum_{i\in son(u)}{dp[i][1]*\prod_{j\in son(u),j \not=i}{dp[j][0]+dp[j][1]}}
dp[u][1]=∑i∈son(u)dp[i][1]∗∏j∈son(u),j̸=idp[j][0]+dp[j][1]
枚举它的颜色由哪个儿子转移而来
2.若点u由颜色,则:
d
p
[
u
]
[
0
]
=
0
dp[u][0]=0
dp[u][0]=0它所在的连通块不可能没有颜色
d
p
[
u
]
[
1
]
=
∏
i
∈
s
o
n
(
u
)
d
p
[
i
]
[
0
]
+
d
p
[
i
]
[
1
]
dp[u][1]=\prod_{i\in son(u)}{dp[i][0]+dp[i][1]}
dp[u][1]=∏i∈son(u)dp[i][0]+dp[i][1]
同样可以由所有颜色转移而来
代码
#include<bits/stdc++.h>
#define ll long long
#define N 300100
#define M 998244353
using namespace std;
ll n,m,bb,tt,ans=1,num[N],fa[N],bfn[N],cnt[N],ds[N],dp[N][2];
bool gg[N];
vector<ll>to[N];
queue<ll>que;
inline ll po(ll u,ll v)
{
ll res=1;
for(;v;)
{
if(v&1) res=res*u%M;
u=u*u%M;
v>>=1;
}
return res;
}
void dfs(ll now,ll last)
{
ll i,t;
if(num[now]) dp[now][1]=1;
else dp[now][0]=1;
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
if(t==last) continue;
dfs(t,now);
if(num[now]) dp[now][1]=dp[now][1]*(dp[t][0]+dp[t][1])%M;
else
{
dp[now][0]=dp[now][0]*(dp[t][0]+dp[t][1])%M;
dp[now][1]=(dp[now][1]+dp[t][1]*po(dp[t][0]+dp[t][1],M-2)%M)%M;
}
}
if(!num[now]) dp[now][1]=dp[now][1]*dp[now][0]%M;
}
int main()
{
ll i,j,p,q,t;
cin>>n>>m;
for(i=1;i<=n;i++) scanf("%lld",&num[i]),cnt[num[i]]++;
for(i=1;i<n;i++)
{
scanf("%lld%lld",&p,&q);
to[p].push_back(q);
to[q].push_back(p);
ds[p]++,ds[q]++;
}
que.push(1);
for(;!que.empty();)
{
q=que.front();
que.pop();
bfn[++tt]=q;
for(i=0;i<to[q].size();i++)
{
p=to[q][i];
if(p==fa[q]) continue;
fa[p]=q;
que.push(p);
}
}
for(j=n;j>=1;j--)
{
i=bfn[j];
if(!num[i]) continue;
if(cnt[num[i]]<=1 || num[fa[i]]==num[i])
{
cnt[num[i]]--;
continue;
}
if(num[fa[i]])
{
puts("0");
return 0;
}
num[fa[i]]=num[i];
}
for(i=1;i<=n;i++)
{
if(ds[i]==1 && !num[i])
{
que.push(i);
}
}
for(;!que.empty();)
{
q=que.front();
que.pop();
gg[q]=1;
for(i=0;i<to[q].size();i++)
{
p=to[q][i];
if(num[p]) continue;
ds[p]--;
if(ds[p]==1) que.push(p);
}
}
for(i=1;i<=n;i++) if(num[i]) break;
dfs(i,-1);
cout<<dp[i][1];
}