Description
给定一个n个节点的无根树,每个点有点权,要求将这棵树分成若干条路径(每个点在且仅在一条路径中),使得每条路径的点权和非负。
求方案数 模1000000007
n≤100000,|点权|≤10000
Solution
随便弄个根
设F[i]表示i这个点为根的子树已经全部覆盖完的方案数
那么现在要找到一条路径来覆盖i这个点
考虑两个路径的两个端点x,y
如果这两个端点不在i的同一个儿子内,或者其中一个是i,那么就可以覆盖i
设val[i]表示i这个节点到根路径上的点权和
那么路径刚才的路径x,y合法,就必须满足val[x]+val[y]−val[i]−val[fa[i]]≥0
这条路径对f[i]的贡献是∏fa[p]∈path(x,y)f[p]
可以将路径拆成X到i,Y到i两部分
类似线段树合并的,对每个点维护一个权值线段树,下标为子树内的val,值就是贡献
扫i的所有儿子,用启发式合并,直接继承最大儿子,其他的儿子暴力统计,注意扫到一个儿子还要乘上其他儿子的F之积,用前缀积和后缀积维护即可。
Code
#include <cstdio>
#include <iostream>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 500005
#define M 12000005
#define R 400000007
#define LL long long
#define mo 1000000007
using namespace std;
int t[M][2],n1,rt[N],fs[N],dt[2*N],nt[2*N],n,m,mw[N],d[N],dfw[N],dfn[N],ft[N];
LL pre[N],s1[M],lst[N],f[N],fq[N],sv[N],lz[M],su[N],val[N],sz[N];
void link(int x,int y)
{
nt[++m]=fs[x];
dt[fs[x]=m]=y;
}
void dfs(int k,int fa)
{
val[k]+=val[fa];
dfw[++dfw[0]]=k;
dfn[k]=dfw[0];
sz[k]=1;
ft[k]=fa;
int mx=0;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa)
{
dfs(p,k);
sz[k]+=sz[p];
if(sz[p]>mx) mx=sz[p],mw[k]=p;
}
}
}
void up(int k)
{
s1[k]=(s1[t[k][0]]+s1[t[k][1]])%mo;
}
void update(LL v,int k)
{
lz[k]=(lz[k]<0)?v:lz[k]*v%mo;
lz[k]%=mo;
(s1[k]*=v)%=mo;
}
void down(int k)
{
if(lz[k]==-1) return;
update(lz[k],t[k][0]),update(lz[k],t[k][1]);
lz[k]=-1;
}
LL fd(int k,LL l,LL r,LL x,LL y)
{
x=max(x,l),y=min(y,r);
if(x>y||!k||s1[k]==0) return 0;
if(x==l&&y==r) return s1[k];
LL mid=(l+r)/2;
if(!t[k][0]) t[k][0]=++n1,t[k][1]=++n1;
down(k);
return (fd(t[k][0],l,mid,x,y)+fd(t[k][1],mid+1,r,x,y))%mo;
}
void ins(int k,LL l,LL r,LL x,LL v)
{
if(l==r&&x==l) s1[k]+=v;
else
{
LL mid=(l+r)/2;
if(!t[k][0]) t[k][0]=++n1,t[k][1]=++n1;
down(k);
if(x<=mid) ins(t[k][0],l,mid,x,v);
else ins(t[k][1],mid+1,r,x,v);
up(k);
}
}
void dp(int k)
{
for(int i=fs[k];i;i=nt[i])
{
if(dt[i]!=ft[k]) dp(dt[i]);
}
rt[k]=rt[mw[k]];
if(!mw[k]) rt[k]=++n1;
d[0]=0;
pre[0]=1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=ft[k]&&p!=mw[k]) d[++d[0]]=p,pre[d[0]]=pre[d[0]-1]*f[p]%mo;
}
lst[d[0]+1]=1;
f[k]=pre[d[0]]*fd(rt[k],0,R+R,val[ft[k]]+R,R+R)%mo;
fod(i,d[0],1)
{
int p=d[i];
fq[p]=pre[i-1]*lst[i+1]%mo*f[mw[k]]%mo;
(f[k]+=fd(rt[p],0,R+R,val[ft[k]]+R,R+R)*fq[p]%mo)%=mo;
lst[i]=lst[i+1]*f[p]%mo;
}
fq[mw[k]]=pre[d[0]];
fod(i,d[0],1)
{
int p=d[i];
fo(j,1,sz[p])
{
int q=dfw[dfn[p]+j-1];
if(ft[q]!=k) sv[q]=sv[ft[q]]*fq[q]%mo;
else sv[q]=1;
(f[k]+=fd(rt[k],0,R+R,val[k]+val[ft[k]]-val[q]+R,R+R)*sv[q]%mo*su[q]%mo*pre[i-1]%mo)%=mo;
}
update(f[p],rt[k]);
fo(j,1,sz[p])
{
int q=dfw[dfn[p]+j-1];
ins(rt[k],0,R+R,val[q]+R,sv[q]*su[q]%mo*lst[i+1]%mo*f[mw[k]]%mo);
}
}
f[0]=1;
su[k]=pre[d[0]]*f[mw[k]]%mo;
if(val[k]-val[ft[k]]>=0) (f[k]+=su[k])%=mo;
ins(rt[k],0,R+R,val[k]+R,su[k]);
}
int main()
{
cin>>n;
memset(lz,255,sizeof(lz));
fo(i,1,n) scanf("%lld",&val[i]);
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
dfs(1,0);
dp(1);
printf("%lld\n",f[1]);
}