http://www.elijahqi.win/archives/3566
对于本题我们显然有个暴力的做法 可以搞到40分 那就是每一层的时候暴力枚举权值 然后看右儿子中比我大的有多少个 算出概率和 然后*我这个点的概率即可
题目中有个重要条件 即每个点权值 均不同 那么就应该考虑线段树合并了 这题在线段树合并的时候怎么办
设greatr[i]表示右子树中比i大的概率是duos greatl[i]同理 那么不妨线段树合并的时候从大往小合并这样就可以一路累加下来了
那么最后的 设P表示这个非叶子点在题目中的定义
那么有如下转移
设x为左子树权值为x的点 他成为当前节点的值的概率是greatr[x]*(1-P)+(1-greatr[x])*p再乘左子树这个点出现的概率即可
#include<queue>
#include<cstdio>
#include<cctype>
#include<algorithm>
#define fi first
#define se second
#define ll long long
#define pa pair<int,int>
#define mp(x,y) make_pair(x,y)
using namespace std;
inline char gc(){
static char now[1<<16],*S,*T;
if (T==S){T=(S=now)+fread(now,1,1<<16,stdin);if (T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(!isdigit(ch)) {if (ch=='-') f=-1;ch=gc();}
while(isdigit(ch)) x=x*10+ch-'0',ch=gc();
return x*f;
}
const int N=3e5+10;
const int mod=998244353;
inline int ksm(ll b,int t){static ll tmp;
for (tmp=1;t;b=b*b%mod,t>>=1) if(t&1) tmp=tmp*b%mod;return tmp;
}
struct node{
int left,right,v,tag;
}tree[N*20];
pa q[N];int c[N][2],num,d[N],n,top,rt[N],ans,greatl,greatr,a[N],p[N];
inline bool cmp(const pa &a,const pa &b){return a.fi<b.fi;}
inline void insert1(int &x,int l,int r,int p){
x=++num;tree[x].tag=1;tree[x].v=1;
if (l==r) return;int mid=l+r>>1;
if (p<=mid) insert1(tree[x].left,l,mid,p);
else insert1(tree[x].right,mid+1,r,p);
}
inline void pushdown(int x){
if (tree[x].tag==1) return;
int l=tree[x].left,r=tree[x].right,tag=tree[x].tag;
if (l) tree[l].tag=(ll)tree[l].tag*tag%mod,
tree[l].v=(ll)tree[l].v*tag%mod;
if (r) tree[r].tag=(ll)tree[r].tag*tag%mod,
tree[r].v=(ll)tree[r].v*tag%mod;tree[x].tag=1;
}
inline int inc(int x,int v){return x+v>=mod?x+v-mod:x+v;}
inline int dec(int x,int v){return x-v<0?x-v+mod:x-v;}
inline int merge(int rt1,int rt2,int p){
if (!rt1&&!rt2) return 0;
if (rt1&&!rt2){
greatl=inc(greatl,tree[rt1].v);
int tmp=dec(inc(greatr,p),2LL*greatr*p%mod);
tree[rt1].tag=(ll)tree[rt1].tag*tmp%mod;
tree[rt1].v=(ll)tree[rt1].v*tmp%mod;
return rt1;
}
if (!rt1&&rt2){
greatr=inc(greatr,tree[rt2].v);
int tmp=dec(inc(greatl,p),2LL*greatl*p%mod);
tree[rt2].tag=(ll)tree[rt2].tag*tmp%mod;
tree[rt2].v=(ll)tree[rt2].v*tmp%mod;
return rt2;
}pushdown(rt1);pushdown(rt2);
tree[rt1].right=merge(tree[rt1].right,tree[rt2].right,p);
tree[rt1].left=merge(tree[rt1].left,tree[rt2].left,p);
int l=tree[rt1].left,r=tree[rt1].right;
tree[rt1].v=inc(tree[l].v,tree[r].v);return rt1;
}
inline void dfs(int x){
if (!d[x]) return;
if (d[x]==1) {dfs(c[x][0]);rt[x]=rt[c[x][0]];return;}
if (d[x]==2) {
dfs(c[x][0]);dfs(c[x][1]);
greatl=0;greatr=0;
rt[x]=merge(rt[c[x][0]],rt[c[x][1]],p[x]);
}
}
inline int sqr(ll x){return x*x%mod;}
inline void get_ans(int x,int l,int r){
if (l==r){
ans=inc(ans,(ll)l*q[l].fi%mod*sqr(tree[x].v)%mod);return;
}int mid=l+r>>1;pushdown(x);
get_ans(tree[x].left,l,mid);get_ans(tree[x].right,mid+1,r);
}
int main(){
freopen("loj2537.in","r",stdin);
n=read();
for (int i=1;i<=n;++i){
int f=read();
if(i==1) continue;++d[f];
if (!c[f][0]) c[f][0]=i;
else c[f][1]=i;
}int inv=ksm(10000,mod-2);
for (int i=1;i<=n;++i){a[i]=read();
if (!d[i]) q[++top]=mp(a[i],i);
else p[i]=(ll)a[i]*inv%mod;
}sort(q+1,q+top+1,cmp);
for (int i=1;i<=top;++i) insert1(rt[q[i].se],1,top,i);
dfs(1);get_ans(rt[1],1,top);
printf("%d\n",ans);
return 0;
}