2021牛客暑期多校训练营#9:G-Glass Balls
原题链接:https://ac.nowcoder.com/acm/contest/11260/G
题目大意
有一棵
n
n
n个节点的树,根为
1
1
1,每个节点上有个玻璃球。
对于
v
v
v的任意儿子
u
u
u,
u
u
u的高度比
v
v
v高
1
1
1个单位,因此
u
u
u上的玻璃球可以沿
(
u
,
v
)
(u,v)
(u,v)边滚到
v
v
v。球沿一条边滚动需要花费
1
1
1秒。
所有球同时开始滚动。有一些节点是可存储节点,其中根节点必定是可存储节点,其他节点有 p p p的概率是可存储节点。如果球滚到可存储节点就会立刻破裂,然后从树上取下。如果球初始就在可存储节点上,会立刻破裂。
两个球如果同时滚到同一个节点就会发生碰撞,不论节点是否是可存储节点。如果发生碰撞不仅两个球会碎,整个系统也会崩坏,所有滚动都会停止。
如果发生碰撞,则分数为0,否则分数为
∑
i
=
1
n
f
(
i
)
\sum_{i=1}^{n}f(i)
∑i=1nf(i),其中
f
(
u
)
f(u)
f(u)表示初始在
u
u
u上的球所滚过的边数。
求分数的期望。
解题思路
因为任何节点最多同时存在一个玻璃球,所以任意节点的子节点都最多只能留一个不成为可存储节点。
一个节点及其子节点合法的概率如下图描述:
设
s
i
z
e
i
size_i
sizei为i节点的子节点个数,
P
P
P为合法的概率,则合法的概率为:
P
=
∏
s
i
z
e
i
(
1
−
p
)
p
s
i
z
e
i
−
1
+
p
s
i
z
e
i
P=\prod size_i(1-p)p^{size_i-1}+p^{size_i}
P=∏sizei(1−p)psizei−1+psizei
设
d
p
x
dp_x
dpx表示在合法情况下
x
x
x上的球的期望,,设
y
y
y为
x
x
x的父节点,整理可得:
d
p
1
=
0
;
d
p
x
=
(
1
+
d
p
y
)
(
1
−
p
)
p
s
i
z
e
y
−
1
s
i
z
e
y
(
1
−
p
)
p
s
i
z
e
y
−
1
+
p
s
i
z
e
y
\begin{aligned} dp_1&=0;\\ dp_x&=(1+dp_y)\frac{(1-p)p^{size_y-1}}{size_y(1-p)p^{size_y-1}+p^{size_y}}\\ \end{aligned}
dp1dpx=0;=(1+dpy)sizey(1−p)psizey−1+psizey(1−p)psizey−1
从上往下遍历,求出每一个
d
p
i
dp_i
dpi的值,最终答案为:
a
n
s
=
P
∑
d
p
x
ans=P\sum dp_x
ans=P∑dpx
代码实现
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int mod=998244353,N=5e5+7;
int n,m,dp[N],dfn[N],cnt,p,P=1,ans;
vector<int>ve[N];
int pm(int x,int p)//快速幂
{
int res=1;
while(p>0)
{
if(p&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
p>>=1;
}
return res;
}
void rd(int &x){int res=0;char c=getchar(),last=' ';while(c<'0'||c>'9')last=c,c=getchar();while(c<='9'&&c>='0')res=res*10-'0'+c,c=getchar();x=last=='-'?-res:res;}//据某人说快读不压成一行就没有灵性
void dfs(int x,int fa){
dp[x]=1ll*(1+dp[fa])%mod*(1-p+mod)%mod*pm(p,ve[fa].size()-1)%mod;
dp[x]=1ll*dp[x]*pm((ve[fa].size()*(1-p+mod)%mod*pm(p,ve[fa].size()-1)%mod+pm(p,ve[fa].size()))%mod,mod-2)%mod;//求dp
for(auto i:ve[x])dfs(i,x);
}
int main()
{
rd(n),rd(p);
for(int i=2;i<=n;i++)
{
int x;
rd(x);
ve[x].pb(i);
}
for(int i=1;i<=n;i++)
{
int s=ve[i].size();
if(s>0)
P=1ll*P*(1ll*s*(1-p+mod)%mod*pm(p,s-1)%mod+pm(p,s))%mod;//求P的值
}
dp[1]=0;
for(auto i:ve[1])
dfs(i,1);
for(int i=1;i<=n;i++)
ans=(1ll*ans+dp[i])%mod;//求dp总和的值
cout<<1ll*ans*P%mod;//求最终答案
}