P5405 [CTS2019]氪金手游 【数学概率+树形dp】
先考虑外向树的情况:
这个的关键是要把求满足拓扑序的概率转化为求 每个点都比它的子树中的所有节点先取到的概率 。单个节点
x
x
x 的概率是独立的,为
w
x
∑
y
∈
s
u
b
t
r
e
e
(
x
)
w
y
\frac{w_x}{\sum_{y\in subtree(x)}w_y}
∑y∈subtree(x)wywx ,答案就是所有情况下的 节点概率之积 的和。
记
f
[
x
]
f[x]
f[x] 表示
x
x
x 子树内满足拓扑序的概率,
s
z
[
x
]
sz[x]
sz[x] 表示
∑
y
∈
s
u
b
t
r
e
e
(
x
)
w
y
\sum_{y\in subtree(x)}w_y
∑y∈subtree(x)wy。
可以发现每个
f
[
x
]
f[x]
f[x] 是只与
x
x
x 的子树有关,
f
[
x
]
=
w
x
s
z
[
x
]
∏
x
→
y
f
[
y
]
f[x]=\frac{w_x}{sz[x]}\prod_{x \rightarrow y} f[y]
f[x]=sz[x]wx∏x→yf[y] 。
然后就加一维
f
[
x
]
[
s
z
[
x
]
]
f[x][sz[x]]
f[x][sz[x]] 就可以直接dp了。
考虑有内向边的情况:
可以用 这条边可以外向也可以内向 的方案数,减去 这条边一定外向 的方案数。
也可以容斥,记
g
[
i
]
g[i]
g[i] 表示至少有
i
i
i 条边不满足条件的方案数。
即
i
i
i 条内向边变外向边,剩下的内向边就是 可以外向也可以内向 。
答案就是:
至少零个条件不满足 − - − 至少一个条件不满足 + + + 至少两个条件不满足 − ⋯ -\cdots −⋯
时间复杂度 O ( n 2 ) O(n^2) O(n2)
#include <bits/stdc++.h>
#define N 1003
using namespace std;
typedef long long ll;
const int mod=998244353;
int head[N],nxt[N<<1],to[N<<1],tag[N<<1];
int sz[N],lst[N];
int a1[N],a2[N],a3[N];
ll F[N][N*3],arr[N*3];
ll inv[N*3],n,_;
ll ksm(ll x,ll y){
ll res=1;
while(y){
if(y&1) res=res*x%mod;
x=x*x%mod; y>>=1;
}
return res;
}
void init(){
inv[0]=inv[1]=1;
for(int i=2;i<=n*3;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
void add(int x,int y,int z){ nxt[++_]=head[x],head[x]=_,to[_]=y,tag[_]=z; }
void dfs(int x){
ll *f=F[x]; f[0]=1;
int y;
for(int __=head[x],y=to[__];__;__=nxt[__],y=to[__]){
if(y==lst[x]) continue;
lst[y]=x; dfs(y); ll *g=F[y];
for(int i=0;i<=(sz[x]+sz[y])*3;i++) arr[i]=0;
if(tag[__]){
for(int i=0;i<=sz[x]*3;i++)
for(int j=0;j<=sz[y]*3;j++)
arr[i+j]=(arr[i+j]+f[i]*g[j]%mod)%mod;
for(int i=0;i<=(sz[x]+sz[y])*3;i++) f[i]=arr[i];
}
else{
ll now=0;
for(int i=0;i<=sz[y]*3;i++) now=(now+g[i])%mod;
for(int i=0;i<=sz[x]*3;i++) arr[i]=(f[i]*now)%mod;
for(int i=0;i<=sz[x]*3;i++)
for(int j=0;j<=sz[y]*3;j++)
arr[i+j]=(arr[i+j]-f[i]*g[j]%mod+mod)%mod;
for(int i=0;i<=(sz[x]+sz[y])*3;i++) f[i]=arr[i];
}
sz[x]+=sz[y];
}
for(int i=0;i<=sz[x]*3+3;i++) arr[i]=0;
for(int i=0;i<=sz[x]*3;i++){
arr[i+1]=(arr[i+1]+f[i]*a1[x]%mod*inv[i+1]%mod)%mod;
arr[i+2]=(arr[i+2]+f[i]*a2[x]%mod*inv[i+2]*2ll%mod)%mod;
arr[i+3]=(arr[i+3]+f[i]*a3[x]%mod*inv[i+3]*3ll%mod)%mod;
}
sz[x]++;
for(int i=0;i<=sz[x]*3;i++) f[i]=arr[i];
// cout<<'\n';
}
int main(){
// freopen("fgo9.in","r",stdin);
cin>>n;
init();
for(int i=1;i<=n;i++){
cin>>a1[i]>>a2[i]>>a3[i];
ll now=ksm(a1[i]+a2[i]+a3[i],mod-2);
a1[i]=a1[i]*now%mod;
a2[i]=a2[i]*now%mod;
a3[i]=a3[i]*now%mod;
// cout<<a1[i]<<' '<<a2[i]<<' '<<a3[i]<<'\n';
}
int u,v;
for(int i=1;i<n;i++) cin>>u>>v,add(u,v,1),add(v,u,0);
dfs(1);
ll ans=0;
for(int i=0;i<=sz[1]*3;i++) (ans+=F[1][i])%=mod;
cout<<ans;
}