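/*
 * Problem: given a tree whose edges all have length 1, pick two distinct
 * vertices (u, v) at random. What is the probability that the distance
 * between them is a prime number?
 *
 * Approach: centroid decomposition. At each centroid, count how many vertices
 * lie at each distance from it, then convolve that distance histogram with
 * itself (using FFT) to get the number of vertex pairs for every distance
 * sum. Add the entries at prime indices (looked up in a precomputed prime
 * table) to the answer. Pairs that lie inside the same child subtree are
 * counted the same way from that child and subtracted.
 *
 * Time complexity: O(n log^2 n).
 */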
#include<bits/stdc++.h>
using namespace std;
// Capacity of every vertex-indexed array (vertices are numbered from 1).
const int N=5e4+100;

// One directed edge record of the linked adjacency lists: head[u] is the
// index of the first record leaving u, and next chains to the following
// record of the same vertex, with -1 terminating the list.
struct Edge{
    int to;    // endpoint of this edge
    int next;  // index of the next edge leaving the same vertex, or -1
};
Edge e[N*2];   // both directions of every undirected tree edge

int size[N];   // subtree sizes written by getroot()
int n;         // number of vertices in the tree
int cnt;       // number of primes stored in pri[]
int Count;     // vertex count of the component currently being searched
int root;      // best centroid candidate found by getroot()
int f[N];      // f[u]: largest piece left over if u were removed
int head[N];   // head[u]: first edge index of u's adjacency list
int tot;       // number of edge records used so far
int pri[N];    // pri[1..cnt]: primes found by the sieve in init()
int vis[N];    // sieve marks: non-zero means composite
bool Del[N];   // Del[u]: u has already been taken as a centroid
void init(){
memset(head,-1,sizeof(head));
tot=0,cnt=0;
for(int i=2;i<=n;i++)
if(!vis[i]){
pri[++cnt]=i;
for(int j=i+i;j<=n;j+=i) vis[j]=1;
}
}
// Appends the directed edge from -> to as the new first entry of from's
// adjacency list. An undirected edge needs one call per direction.
void addedge(int from,int to){
    // Edge{...} is standard aggregate initialization. The former (Edge){...}
    // was a C99 compound literal that C++ compilers accept only as an extension.
    e[tot]=Edge{to,head[from]};
    head[from]=tot++;
}
void getroot(int u,int pre){
f[u]=0,size[u]=1;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(v==pre||Del[v]) continue;
getroot(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],Count-size[u]);
if(f[u]<f[root]) root=u;
}
// Running total of ordered vertex pairs whose distance is prime
// (accumulated from cal() results inside work()).
long long ans;
// mx_deep: deepest depth reached by the current dfs() pass.
// Cnt[d]:  number of vertices found at depth d during that pass; cal()
//          zeroes the used entries after each pass.
int mx_deep,Cnt[N];
// Depth census of the live subtree hanging from u. Each vertex reached at
// depth d adds one to Cnt[d], and mx_deep is raised to the largest depth
// seen. The walk never goes back to pre and never enters a vertex marked in
// Del[]. cal() zeroes mx_deep before the walk and clears Cnt after it.
void dfs(int u,int dep,int pre){
    ++Cnt[dep];
    if(dep>mx_deep) mx_deep=dep;
    for(int id=head[u];id!=-1;id=e[id].next){
        const int child=e[id].to;
        if(child!=pre&&!Del[child])
            dfs(child,dep+1,u);
    }
}
// pi, computed as arccos(-1).
const double PI=std::acos(-1.0);

// Minimal complex number used by the FFT: x is the real part, y the
// imaginary part. Only the arithmetic fft() and cal() need is provided.
struct Complex{
    double x,y;
    Complex(double re=0,double im=0):x(re),y(im){}
    // Component-wise difference.
    Complex operator -(const Complex &rhs)const{
        return Complex(x-rhs.x,y-rhs.y);
    }
    // Component-wise sum.
    Complex operator +(const Complex &rhs)const{
        return Complex(x+rhs.x,y+rhs.y);
    }
    // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
    Complex operator *(const Complex &rhs)const{
        const double re=x*rhs.x-y*rhs.y;
        const double im=x*rhs.y+y*rhs.x;
        return Complex(re,im);
    }
};
// Reorders y[0..len-1] into bit-reversed index order, as the in-place
// iterative FFT requires (len is expected to be a power of two).
// j tracks the bit-reversal of i. It is advanced by a "reversed increment":
// clear the leading set bits from the top down, then set the first clear one.
void change(Complex y[],int len){
    const int half=len/2;
    int j=half;
    for(int i=1;i<len-1;++i){
        if(i<j) swap(y[i],y[j]);     // swap each pair only once
        int bit=half;
        while(j>=bit){
            j-=bit;
            bit/=2;
        }
        j+=bit;                      // the loop exits only when j < bit
    }
}
// Iterative radix-2 FFT, applied in place to y[0..len-1].
// len is assumed to be a power of two (change() and the butterfly loops rely
// on it).
//   on ==  1: forward transform, per-stage root of unity exp(-2*pi*i/h);
//   on == -1: inverse transform, after which every element is divided by len.
void fft(Complex y[],int len,int on){
    change(y,len);
    for(int h=2;h<=len;h<<=1){
        Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
        for(int j=0;j<len;j+=h){
            Complex w(1,0);
            for(int k=j;k<j+h/2;k++){
                Complex u=y[k];
                Complex t=w*y[k+h/2];
                y[k]=u+t;
                y[k+h/2]=u-t;
                w=w*wn;
            }
        }
    }
    if(on==-1)
        for(int i=0;i<len;i++){
            // Scale both parts. Scaling only .x left the imaginary part len
            // times too large, so the result was not a true inverse.
            y[i].x/=len;
            y[i].y/=len;
        }
}
// Capacity of the FFT buffer. cal() pads to the next power of two that is
// >= 2*(mx_deep+1). For trees that fit in the N-sized arrays this is at most
// 131072.
const int MAXN=200010;
// FFT work buffer for cal(): it holds the depth histogram, then its
// spectrum, then the self-convolution.
Complex x1[MAXN];

// Counts ordered pairs (a, b) of vertices in u's live component (Del[]
// vertices excluded; b may equal a) whose depths sum to a prime. Depths are
// measured with u placed at depth `dep`.
// work() calls this with dep = 0 at a centroid. It then calls it with
// dep = 1 at each child and subtracts the result, which removes pairs whose
// path does not actually pass through the centroid.
// Method: histogram the depths with dfs(), then square the histogram's
// spectrum, so that x1[k] = sum over i+j=k of Cnt[i]*Cnt[j]. Finally sum the
// entries at prime k. The spectrum is squared directly: the old code
// transformed an identical copy a second time just to multiply by it.
long long cal(int u,int dep){
    mx_deep=0;
    dfs(u,dep,0);
    const int width=mx_deep+1;
    int len=1;
    while(len<width*2) len<<=1;   // room for every depth sum 0..2*mx_deep
    for(int i=0;i<width;i++) x1[i]=Complex(Cnt[i],0);
    for(int i=width;i<len;i++) x1[i]=Complex(0,0);
    fft(x1,len,1);
    for(int i=0;i<len;i++) x1[i]=x1[i]*x1[i];
    fft(x1,len,-1);
    long long tmp=0;
    for(int i=1;i<=cnt&&pri[i]<len;i++)
        tmp+=(long long)(x1[pri[i]].x+0.5);   // round the floating-point count
    for(int i=0;i<=mx_deep;i++) Cnt[i]=0;    // leave Cnt clean for the next pass
    return tmp;
}
// Counts the vertices reachable from u without stepping onto pre or onto a
// vertex already removed (Del[]). This is the exact size of the component
// the next centroid search runs over.
static int component_size(int u,int pre){
    int total=1;
    for(int i=head[u];i!=-1;i=e[i].next){
        const int v=e[i].to;
        if(v!=pre&&!Del[v]) total+=component_size(v,u);
    }
    return total;
}

// Centroid-decomposition driver. It removes centroid u and adds to ans every
// prime-distance pair whose path passes through u, then recurses into each
// remaining component.
// cal(u,0) counts pairs by depth sum from u. That also includes pairs lying
// inside a single child subtree, whose real path avoids u. For each child v,
// cal(v,1) computes exactly those same depth sums (depths still measured
// from u), so it is subtracted.
// The next Count comes from component_size() rather than from size[v].
// size[v] was produced by a search started elsewhere and can include u's
// side of the tree. An unqualified `size` is also ambiguous with std::size
// in C++17 under `using namespace std;`.
void work(int u){
    Del[u]=true;
    ans+=cal(u,0);
    for(int i=head[u];i!=-1;i=e[i].next){
        const int v=e[i].to;
        if(Del[v]) continue;
        ans-=cal(v,1);
        Count=f[0]=component_size(v,0);
        getroot(v,root=0);
        work(root);
    }
}
// Reads n and the n-1 tree edges (1-indexed vertices). It then prints the
// probability that two distinct vertices are at prime distance, computed as
// ans / (n * (n - 1)) over ordered pairs.
int main(){
    int u,v;
    if(scanf("%d",&n)!=1) return 0;
    if(n<2){
        // With fewer than two vertices there are no pairs. Print 0 rather
        // than evaluating 0/0, which would print NaN.
        printf("%.6f\n",0.0);
        return 0;
    }
    init();
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        addedge(u,v);
        addedge(v,u);
    }
    Count=f[0]=n,ans=0;
    memset(Del,false,sizeof(Del));
    getroot(1,root=0);
    work(root);
    printf("%.6f\n",1.0*ans/n/(n-1));
    return 0;
}