题意
维护树上两点路径上任意两点的距离的期望值
题解
搬运一波PoPoQQQ大神的博客
http://blog.youkuaiyun.com/popoqqq/article/details/40823659
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<iostream>
#define F(i) (T[i].fa)
#define L(i) (T[i].s[0])
#define R(i) (T[i].s[1])
#define S(i) (T[i].sum)
#define Ls(i) (T[i].ls)
#define Rs(i) (T[i].rs)
#define ss(i) (T[i].ss)
#define sz(i) (T[i].siz)
#define Loc(i) (R(F(i))==i)
#define For(i,j,k) for(int i=(j);i<=(int)k;i++)
#define Forr(i,j,k) for(int i=(j);i>=(int)k;i--)
#define Set(a,b) memset(a,b,sizeof(a))
#define ll long long
using namespace std;
const int N=50500;
int Next[N*2],Begin[N],to[N*2],e;
inline void add(int x,int y){
to[++e]=y,Next[e]=Begin[x],Begin[x]=e;
}
inline void read(int &x){
x=0;char c=getchar();int f(0);
while(c<'0'||c>'9')f|=(c=='-'),c=getchar();
while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
if(f)x=-x;
}
ll gcd(ll a ,ll b){
return b?gcd(b,a%b):a;
}
struct node{
int s[2],fa,siz,rev;
ll ls,rs,sum,val,ss,mk;
node(){
s[1]=s[0]=fa=ls=rs=sum=val=ss=mk=0,siz=1;
}
};
struct LCT{
node T[N];
LCT(){T[0].siz=0;}
inline void pushup(int x){
sz(x)=sz(L(x))+sz(R(x))+1;
S(x)=S(L(x))+S(R(x))+T[x].val;
Ls(x)=Ls(L(x))+(sz(L(x))+1)*T[x].val+Ls(R(x))+(sz(L(x))+1)*S(R(x));
Rs(x)=Rs(R(x))+(sz(R(x))+1)*T[x].val+Rs(L(x))+(sz(R(x))+1)*S(L(x));
ss(x)=ss(L(x))+(sz(R(x))+1)*Ls(L(x))+ss(R(x))+(sz(L(x))+1)*Rs(R(x))+(sz(R(x))+1)*(sz(L(x))+1)*T[x].val;
}
inline void Swap(int x){
swap(L(x),R(x)),swap(Ls(x),Rs(x)),T[x].rev^=1;
}
inline void add(int x,ll v){
T[x].mk+=v;
T[x].val+=v;
T[x].sum+=sz(x)*v;
Ls(x)+=v*sz(x)*(sz(x)+1)/2;
Rs(x)+=v*sz(x)*(sz(x)+1)/2;
ss(x)+=v*sz(x)*(sz(x)+1)*(sz(x)+2)/6;
}
inline void pushdown(int x){
if(T[x].rev){
Swap(L(x)),Swap(R(x));
T[x].rev=0;
}
if(T[x].mk){
if(L(x))add(L(x),T[x].mk);
if(R(x))add(R(x),T[x].mk);
T[x].mk=0;
}
}
inline bool isrt(int x){
return R(F(x))!=x&&L(F(x))!=x;
}
inline void Pushdown(int x){
if(!isrt(x))Pushdown(F(x));
pushdown(x);
}
inline void Rotate(int x){
int A=F(x),B=F(A),l=Loc(x),r=l^1,d=Loc(A);
if(!isrt(A))T[B].s[d]=x;F(x)=B;
F(A)=x,F(T[x].s[r])=A,T[A].s[l]=T[x].s[r],T[x].s[r]=A;
pushup(A),pushup(x);
}
inline void splay(int x){
Pushdown(x);
while(!isrt(x)){
if(!isrt(F(x)))Rotate(x);
Rotate(x);
}
}
inline void access(int x){
for(int i=0;x;i=x,x=F(x))
splay(x),R(x)=i,pushup(x);
}
inline int findrt(int x){
access(x),splay(x);
while(L(x))x=L(x);
return x;
}
inline void reverse(int x){
access(x),splay(x),Swap(x);
}
inline void split(int x,int y){
reverse(x),access(y),splay(y);
}
inline void link(int x,int y){
if(findrt(x)==findrt(y))return ;
reverse(x),F(x)=y;
}
inline void cut(int x,int y){
if(x==y||findrt(x)!=findrt(y))return ;
split(x,y);
if(L(y)==x)F(x)=0,L(y)=0,pushup(y);
}
inline void modify(int x,int y){
int v;read(v);
if(findrt(x)!=findrt(y))return ;
split(x,y);add(y,v);
}
inline void query(int x,int y){
if(findrt(x)!=findrt(y)){puts("-1");return ;}
split(x,y);
ll fz=ss(y),fm=sz(y)*(sz(y)+1)/2,d=gcd(fz,fm);
printf("%lld/%lld\n",fz/d,fm/d);
}
}t;
void dfs(int u,int fa){
for(int i=Begin[u];i;i=Next[i]){
int v=to[i];
if(v==fa)continue;
t.T[v].fa=u;
dfs(v,u);
}
}
int main(){
int n,m,u,v,p;
read(n),read(m);
For(i,1,n){
read(v);
t.add(i,v);t.T[i].mk=0;
}
For(i,1,n-1){
read(u),read(v);
add(u,v),add(v,u);
}
dfs(1,0);
while(m--){
read(p),read(u),read(v);
if(p==1)t.cut(u,v);
else if(p==2)t.link(u,v);
else if(p==3)t.modify(u,v);
else if(p==4)t.query(u,v);
}
return 0;
}