题目大意
有一棵有nnn个节点的树,每条边有一个边权www。有mmm个特殊点,将这些点记为集合AAA。
将AAA中的元素随机打乱得到序列aaa,求∑i=2md(ai−1,ai)\sum\limits_{i=2}^md(a_{i-1},a_i)i=2∑md(ai−1,ai)的期望值模998244353998244353998244353后的值,其中d(x,y)d(x,y)d(x,y)表示xxx到yyy的边权和。
有qqq次修改,每次修改会将与xxx相连的边的权值增加kkk。求每次修改后上述式子的期望值。
1≤n≤5×105,m≤n,1≤q≤5×1051\leq n\leq 5\times 10^5,m\leq n,1\leq q\leq 5\times 10^51≤n≤5×105,m≤n,1≤q≤5×105
1≤w,k≤1091\leq w,k\leq 10^91≤w,k≤109
题解
对于每组特殊点x,yx,yx,y,我们考虑有多少种方案会计算到d(x,y)d(x,y)d(x,y)的贡献。在确定x,yx,yx,y在aaa中相邻之后,其他m−2m-2m−2个数有(m−2)!(m-2)!(m−2)!种放法,x,yx,yx,y中较前的数可以放在第一个到第m−1m-1m−1个位置上,确定了前一个数,则后一个数也确定了,而这两个数的顺序可以为x,yx,yx,y或者y,xy,xy,x,所以还要乘222,也就是说有2(m−2)!×(m−1)=2(m−1)!2(m-2)!\times (m-1)=2(m-1)!2(m−2)!×(m−1)=2(m−1)!种方案会计算到d(x,y)d(x,y)d(x,y)的贡献。而题目要求的是期望值,总共有m!m!m!种方案,那么d(x,y)d(x,y)d(x,y)对答案的贡献为2(m−1)!m!×d(x,y)=2m×d(x,y)\dfrac{2(m-1)!}{m!}\times d(x,y)=\dfrac 2m\times d(x,y)m!2(m−1)!×d(x,y)=m2×d(x,y)。
下面,我们要求每条边被多少d(x,y)d(x,y)d(x,y)计算过,这用一个dfsdfsdfs即可算出,记这个值为tditd_itdi。然后,求出所有边iii的wiw_iwi与tditd_itdi之积的和,也就是∑iwi×tdi\sum\limits_iw_i\times td_ii∑wi×tdi,m2×∑iwi×tdi\dfrac m2\times \sum\limits_iw_i\times td_i2m×i∑wi×tdi即为答案。
我们考虑每次修改对答案的贡献。设与iii相连的边的tdtdtd值之和为twitw_itwi,则每次修改会让∑iwi×tdi\sum\limits_iw_i\times td_ii∑wi×tdi增加k×twik\times tw_ik×twi。那么,我们可以O(1)O(1)O(1)修改。因为题目只需要求答案,所以我们不需要真的去修改wiw_iwi。
时间复杂度为O(n+q)O(n+q)O(n+q)。
code
#include<bits/stdc++.h>
using namespace std;
const int N=500000;
const long long mod=998244353;
int n,m,q,z[N+5],siz[N+5];
long long ans,pt,w[N+5],td[N+5],tw[N+5];
vector<pair<int,int>>g[N+5];
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void dfs(int u,int fa){
siz[u]=z[u];
for(auto p:g[u]){
int v=p.first,id=p.second;
if(v==fa) continue;
dfs(v,u);
siz[u]+=siz[v];
td[id]=1ll*(m-siz[v])*siz[v]%mod;
}
}
int main()
{
// freopen("sakuya.in","r",stdin);
// freopen("sakuya.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1,x,y;i<n;i++){
scanf("%d%d%lld",&x,&y,&w[i]);
g[x].push_back({y,i});
g[y].push_back({x,i});
}
for(int i=1,x;i<=m;i++){
scanf("%d",&x);
z[x]=1;
}
dfs(1,0);
for(int i=1;i<n;i++){
ans=(ans+td[i]*w[i])%mod;
}
for(int i=1;i<=n;i++){
for(auto p:g[i]){
tw[i]=(tw[i]+td[p.second])%mod;
}
}
scanf("%d",&q);
long long tq=mi(m,mod-2)*2%mod;
for(int o=1,x,k;o<=q;o++){
scanf("%d%d",&x,&k);
ans=(ans+tw[x]*k)%mod;
pt=ans*tq%mod;
printf("%lld\n",pt);
}
return 0;
}