题面
题解
首先可以看出 “等概率选连通块->连通块内等概率选点” 相当于 “全局等概率选点”。
一开始感觉无从下手,但是题目中还是给了一点提示。
题目让我们输出答案乘 n ! n! n! 后的结果,于是想到枚举一个 1 ∼ n 1\sim n 1∼n 的排列 p i p_i pi 表示依次选择并删除的点的序列。那么对于某一个特定的 p i p_i pi,这种删点方法中所有点被捶的总次数等于 ∑ i = 1 n ( p i 所 在 连 通 块 还 剩 下 的 点 数 ) \sum\limits_{i=1}^n (p_i所在连通块还剩下的点数) i=1∑n(pi所在连通块还剩下的点数)。
转换一下角度,考虑每个点会被捶多少次。于是所有点被捶的总次数又可以表示为: ∑ i = 1 n ∑ j = 1 n [ 删 除 p j 前 p j 与 p i 仍 连 通 ] = ∑ i = 1 n ∑ j = 1 n [ 删 除 j 前 j 与 i 仍 连 通 ] \sum\limits_{i=1}^n\sum\limits_{j=1}^n[删除p_j前p_j与p_i仍连通]=\sum\limits_{i=1}^n\sum\limits_{j=1}^n[删除j前j与i仍连通] i=1∑nj=1∑n[删除pj前pj与pi仍连通]=i=1∑nj=1∑n[删除j前j与i仍连通]。
其中 “删除j前j与i仍连通” 可以巧妙地转化为 “ i i i 到 j j j 路径上所有点在 p p p 序列中出现的位置(相当于被删除的时间)都比 j j j 后”,即如果设 t p i = i t_{p_i}=i tpi=i,就有 t j = min v ∈ p a t h ( i , j ) t v t_j=\min\limits_{v\in path(i,j)}t_v tj=v∈path(i,j)mintv,其中 p a t h ( i , j ) path(i,j) path(i,j) 表示 i i i 到 j j j 的路径。
考虑某个点 i i i 在所有 p p p 的排列中被捶的总次数,这之中 t j = min v ∈ p a t h ( i , j ) t v t_j=\min\limits_{v\in path(i,j)}t_v tj=v∈path(i,j)mintv 的概率是 1 d i s t ( i , j ) \dfrac{1}{dist(i,j)} dist(i,j)1,其中 d i s t ( i , j ) dist(i,j) dist(i,j) 表示 p a t h ( i , j ) path(i,j) path(i,j) 集合的大小,即 i i i 到 j j j 路径上的总点数。
于是我们要求的就是 ∑ i = 1 n ∑ j = 1 n 1 d i s t ( i , j ) \sum\limits_{i=1}^n\sum\limits_{j=1}^n\dfrac{1}{dist(i,j)} i=1∑nj=1∑ndist(i,j)1。
对于每一个 d i s ∈ [ 1 , n ] dis\in [1,n] dis∈[1,n] 求出 d i s t ( i , j ) = d i s dist(i,j)=dis dist(i,j)=dis 的点对 ( i , j ) (i,j) (i,j) 数,使用点分治+FFT即可。
#include<bits/stdc++.h>
#define LN 19
#define N 100010
#define INF 0x7fffffff
using namespace std;
namespace modular
{
const int mod=1000000007;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
const double pi=acos(-1);
typedef vector<int> poly;
struct Complex
{
double x,y;
Complex(){};
Complex(double a,double b){x=a,y=b;}
}F[N<<2],w[LN][N<<2][2];
Complex operator + (Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator - (Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator * (Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int n;
int cnt,head[N],nxt[N<<1],to[N<<1];
int nn,rt,maxn,size[N],fa[N];
int sum[N<<1];
bool vis[N];
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
Complex gn(cos(pi/mid),sin(pi/mid));
Complex ign(cos(pi/mid),-sin(pi/mid));
Complex g(1,0),ig(1,0);
for(int j=0;j<mid;j++,g=g*gn,ig=ig*ign)
w[bit][j][0]=g,w[bit][j][1]=ig;
}
}
void adde(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void getsize(int u,int fa)
{
size[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
getsize(v,u);
size[u]+=size[v];
}
}
void getroot(int u,int fa)
{
int nmax=0;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
getroot(v,u);
nmax=max(nmax,size[v]);
}
nmax=max(nmax,nn-size[u]);
if(nmax<maxn) rt=u,maxn=nmax;
}
int maxdis;
void getdis(int u,int fa,int dis)
{
F[dis].x++;
maxdis=max(maxdis,dis);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]||v==fa) continue;
getdis(v,u,dis+1);
}
}
int rev[N<<2];
void FFT(Complex *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
Complex x=a[i+j],y=w[bit][j][opt]*a[i+mid+j];
a[i+j]=x+y,a[i+mid+j]=x-y;
}
}
}
if(opt)
for(int i=0;i<limit;i++)
a[i].x/=limit;
}
void calc(int u,int dis,int tag)
{
maxdis=0,getdis(u,0,dis);
int limit=1;
while(limit<=(maxdis<<1)) limit<<=1;
FFT(F,limit,1);
for(int i=0;i<limit;i++)
F[i]=F[i]*F[i];
FFT(F,limit,-1);
for(int i=0;i<limit;i++) sum[i+1]+=tag*(int)(F[i].x+0.5);
for(int i=0;i<limit;i++) F[i]=Complex(0,0);
}
void solve(int u)
{
vis[u]=1;
calc(u,0,1);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(vis[v]) continue;
calc(v,1,-1);
getsize(v,0);
nn=size[v],maxn=INF,getroot(v,0);
fa[rt]=u;
solve(rt);
}
}
int main()
{
n=read();
int limit=1;
while(limit<=(n<<1)) limit<<=1;
init(limit);
for(int i=1;i<n;i++)
{
int u=read(),v=read();
adde(u,v),adde(v,u);
}
getsize(1,0);
nn=size[1],maxn=INF,getroot(1,0);
solve(rt);
int fac=1;
for(int i=1;i<=n;i++) fac=mul(fac,i);
int ans=0;
for(int i=1;i<=n;i++)
ans=add(ans,mul(sum[i],poww(i,mod-2)));
printf("%d\n",mul(ans,fac));
return 0;
}
/*
3
1 2
2 3
*/