题解
我们从叶子节点逐层递归到根节点。
我们可以建动态开点权值线段树,每个结点上建一颗,递归时考虑合并左右子树的信息。
可以得到,当合并左右儿子到某父结点上时,可以这样转移:
设f[i][j]是在i结点上,权值为j(令权值为j只能从左儿子中转移上来)的概率,则f[i][j]=f[lc[i]][j]∗((1−p)∗sum1[rc[i]](所有值大于i的节点上的概率之和)+p∗sum2[rc[i]](所有值小于i的节点上的概率之和))f[i][j]=f[lc[i]][j]∗((1−p)∗sum1[rc[i]](所有值大于i的节点上的概率之和)+p∗sum2[rc[i]](所有值小于i的节点上的概率之和))
具体做法见代码,感觉这个实现很巧妙啊(线段树博大精深(跪),我们记一个区间乘标记就万事大吉了啊。
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#define ls ch[k][0]
#define rs ch[k][1]
#define mid (((l)+(r))>>1)
using namespace std;
typedef long long ll;
const int N=3e5+10,M=5e6+10,mod=998244353;
int n,son[N][2],rt[N],ch[M][2],p,w[N],rk[N],tot,cnt;
int inv=796898467;ll s[M],cg[M];
struct P{
int val,id;
bool operator <(const P&u)const{
return val<u.val;
}
}t[N];
inline int rd()
{
char ch=getchar();int x=0,f=1;
while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^48);ch=getchar();}
return x*f;
}
inline void pushdown(int k)
{
if(cg[k]==1) return;
s[ls]=s[ls]*cg[k]%mod;s[rs]=s[rs]*cg[k]%mod;
cg[ls]=cg[ls]*cg[k]%mod;cg[rs]=cg[rs]*cg[k]%mod;
cg[k]=1;
}
inline void insert(int &k,int l,int r,int pos)
{
if(!k) k=++tot;s[k]=cg[k]=1;
if(l==r) return;
if(pos<=mid) insert(ls,l,mid,pos);
else insert(rs,mid+1,r,pos);
}
inline int merge(int x,int y,ll sx,ll sy)
{
if(!x){s[y]=s[y]*sy%mod;cg[y]=cg[y]*sy%mod;return y;}
if(!y){s[x]=s[x]*sx%mod;cg[x]=cg[x]*sx%mod;return x;}
pushdown(x);pushdown(y);
ll A=s[ch[x][0]],B=s[ch[x][1]],C=s[ch[y][0]],D=s[ch[y][1]];
ch[x][0]=merge(ch[x][0],ch[y][0],(sx+(1-p+mod)*D%mod)%mod,(sy+(1-p+mod)*B%mod)%mod);
ch[x][1]=merge(ch[x][1],ch[y][1],(sx+p*C%mod)%mod,(sy+p*A%mod)%mod);
s[x]=(s[ch[x][0]]+s[ch[x][1]])%mod;
return x;
}
inline int solve(int x)
{
if(!son[x][0]){insert(rt[x],1,cnt,rk[x]);return rt[x];}
int l=solve(son[x][0]);
if(!son[x][1]) return l;
int r=solve(son[x][1]);
p=w[x];
return merge(l,r,0,0);
}
inline ll cal(int k,int l,int r)
{
if(l==r){return 1ll*l*t[l].val%mod*s[k]%mod*s[k]%mod;}
pushdown(k);
return (cal(ls,l,mid)+cal(rs,mid+1,r))%mod;
}
int main(){
int i,j,k;
n=rd();
for(i=1;i<=n;++i){j=rd();son[j][(son[j][0]!=0)]=i;}
for(i=1;i<=n;++i){
w[i]=rd();
if(son[i][0]) w[i]=1ll*w[i]*inv%mod;//w[i]->1ll*w[i]
else{t[++cnt].val=w[i];t[cnt].id=i;}
}
sort(t+1,t+cnt+1);
for(i=1;i<=cnt;++i) rk[t[i].id]=i;
k=solve(1);
printf("%lld\n",cal(k,1,cnt));
}