Description
给出一棵nnn个节点的树,再给出mmm只僵尸的位置xix_ixi以及能力hih_ihi,第iii条边可以等概率建立起高度为[li,ri][l_i,r_i][li,ri]的围墙,第iii只僵尸可以越过任何小于hih_ihi的围墙,问至少存在一个位置安全(即没有僵尸可以到达这个位置)的概率
Input
第一行一整数TTT表示用例组数,每组用例首先输入两个整数n,mn,mn,m表示点数和僵尸数量,之后n−1n-1n−1行每行输入四个整数u,v,li,riu,v,l_i,r_iu,v,li,ri表示第iii条树边为u↔vu\leftrightarrow vu↔v,可以建立围墙高度范围为[li,ri][l_i,r_i][li,ri],最后mmm行每行两个整数xi,hix_i,h_ixi,hi表示第iii个点处有一个能力为hih_ihi的僵尸
(1≤T≤5,1≤n,m≤2000,1≤li≤ri≤109,1≤hi≤109)(1\le T\le 5,1\le n,m\le 2000,1\le l_i\le r_i\le 10^9,1\le h_i\le 10^9)(1≤T≤5,1≤n,m≤2000,1≤li≤ri≤109,1≤hi≤109)
Output
输出至少存在一个位置安全的概率,结果模998244353998244353998244353
Sample Input
2
4 2
1 2 1 2
2 3 1 2
1 4 1 2
1 2
3 2
5 2
1 2 1 10
2 3 2 9
1 4 3 12
2 5 4 6
1 7
5 5
Sample Output
374341633
888437475
Solution
考虑所有位置都不安全的方案数,把僵尸按能力升序排,以dp[u][i]dp[u][i]dp[u][i]表示以uuu为根的子树全部不安全,且uuu子树中可以到达uuu的最强僵尸编号为iii,以此考虑uuu的子树,对于当前考虑的儿子vvv,假设u,vu,vu,v边被iii僵尸通过的方案数为xix_ixi,不被通过的方案数为yiy_iyi,那么有三种情况:
1.vvv被iii干掉了,dp[u][i]+=dp[u][i]⋅dp[v][i]⋅xidp[u][i]+=dp[u][i]\cdot dp[v][i]\cdot x_idp[u][i]+=dp[u][i]⋅dp[v][i]⋅xi
2.vvv被其子树内弱于iii的僵尸干掉,此时显然iii僵尸不会在vvv子树中,为使干掉vvv的最强僵尸不超过iii,u,vu,vu,v之间需要阻碍iii僵尸通过,故有dp[u][i]+=dp[u][i]⋅yi⋅∑j<idp[v][j]dp[u][i]+=dp[u][i]\cdot y_i\cdot \sum\limits_{j<i}dp[v][j]dp[u][i]+=dp[u][i]⋅yi⋅j<i∑dp[v][j]
3.vvv被其子树内强于iii的僵尸干掉,此时不能让这些强于iii的僵尸通过u,vu,vu,v之间的边去干掉uuu,故u,vu,vu,v之间需要阻碍这些更强的僵尸通过,故有dp[u][i]+=dp[u][i]⋅∑j≥idp[v][j]⋅yjdp[u][i]+=dp[u][i]\cdot \sum\limits_{j\ge i}dp[v][j]\cdot y_jdp[u][i]+=dp[u][i]⋅j≥i∑dp[v][j]⋅yj
前缀和优化一下,第二部分从弱僵尸到强僵尸考虑,第三部分从强僵尸到弱僵尸考虑即可,答案即为
1−∑i=1mdp[1][i]∏i=1n−1(ri−li+1)
1-\frac{\sum\limits_{i=1}^mdp[1][i]}{\prod\limits_{i=1}^{n-1}(r_i-l_i+1)}
1−i=1∏n−1(ri−li+1)i=1∑mdp[1][i]
时间复杂度O(nm)O(nm)O(nm)
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
typedef pair<int,int>P;
#define maxn 2005
#define mod 998244353
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int x,int y)
{
int ans=1;
while(y)
{
if(y&1)ans=mul(ans,x);
x=mul(x,x);
y>>=1;
}
return ans;
}
#define id second
#define val first
P a[maxn];
vector<P>g[maxn];
int T,n,m,l[maxn],r[maxn],dp[maxn][maxn],vis[maxn][maxn],temp[maxn];
void dfs(int u,int fa)
{
int pos=1;
for(int i=1;i<=m;i++)
if(a[i].id==u)
{
vis[u][i]=1;
pos=i;
}
else vis[u][i]=0;
for(int i=pos;i<=m;i++)dp[u][i]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i].first,t=g[u][i].second;
if(v==fa)continue;
dfs(v,u);
for(int j=1;j<=m;j++)temp[j]=dp[u][j],dp[u][j]=0;
int res=0;
for(int j=1;j<=m;j++)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
dp[u][j]=add(dp[u][j],mul(x,mul(temp[j],dp[v][j])));
if(vis[v][j])res=add(res,dp[v][j]);
else dp[u][j]=add(dp[u][j],mul(mul(y,res),temp[j]));
}
res=0;
for(int j=m;j>=1;j--)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
if(vis[v][j])res=add(res,mul(dp[v][j],y));
else dp[u][j]=add(dp[u][j],mul(res,temp[j]));
}
for(int j=1;j<=m;j++)vis[u][j]|=vis[v][j];
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)g[i].clear();
memset(dp,0,sizeof(dp));
int ans=1;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d%d%d",&u,&v,&l[i],&r[i]);
ans=mul(ans,r[i]-l[i]+1);
g[u].push_back(P(v,i)),g[v].push_back(P(u,i));
}
for(int i=1;i<=m;i++)scanf("%d%d",&a[i].id,&a[i].val);
sort(a+1,a+m+1);
dfs(1,0);
int res=0;
for(int i=1;i<=m;i++)res=add(res,dp[1][i]);
res=add(ans,mod-res);
printf("%d\n",mul(res,Pow(ans,mod-2)));
}
return 0;
}