Description
众所周知,DH是一位人生赢家,他不仅能虐暴全场,而且还正在走向人生巅峰;
在巅峰之路上,他碰到了这一题:
给出一棵n个节点的树,我们每次随机染黑一个叶子节点(可以重复染黑),操作无限次后,这棵树的所有叶子节点必然全部会被染成黑色。
定义R为这棵树不经过黑点的直径,求使R第一次变小期望的步数。
Data Constraint
对于15%的数据,满足n<=10;
对于30%的数据,满足n<=1000;
另外有20%数据,满足树为菊花图;
另外有15%数据,满足每个点度数不超过3;
对于100%的数据,满足n<=5*10^5。
题解
考虑直径的长度,如果是奇数就是过定边,偶数过定点。
那么什么时候直径变化呢,
就是只剩下一个集合的时候。
于是枚举一个集合,
枚举它被染黑的个数,
染黑就可以得到概率了。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const int N=500003,mo=998244353;
int n,nxt[N*2],to[N*2],lst[N],x,y,d[N],m,sum,tot;
int mx,id,fa[N],w[N],t;
ll jc[N],ny[N],s[N],ans;
char ch;
void read(int&n)
{
for(ch=getchar();ch<'0' || ch>'9';ch=getchar());
for(n=0;'0'<=ch && ch<='9';ch=getchar())n=(n<<1)+(n<<3)+ch-48;
}
void write(int x){if(x>9)write(x/10);putchar(x%10+48);}
int max(int x,int y){return x>y?x:y;}
void ins(int x,int y)
{
nxt[++tot]=lst[x];
to[tot]=y;
lst[x]=tot;
}
void dfs(int x,int len)
{
if(len>mx)mx=len,id=x;
for(int i=lst[x];i;i=nxt[i])
if(to[i]^fa[x])fa[to[i]]=x,dfs(to[i],len+1);
}
int work(int x,int fa,int dep)
{
if(dep==1)return 1;
int s=0;
for(int i=lst[x];i;i=nxt[i])
if(to[i]^fa)s=s+work(to[i],x,dep-1);
return s;
}
ll ksm(ll x,int y)
{
ll s=1;
for(;y;y>>=1,x=x*x%mo)
if(y&1)s=s*x%mo;
return s;
}
ll C(int x,int y)
{
return jc[y]*ny[x]%mo*ny[y-x]%mo;
}
ll get(int v)
{
ll S=0;
for(int i=0;i<v;i++)
S=(S+C(i,v)*jc[sum-v+i-1]%mo*jc[v-i]%mo*(sum-v)%mo*s[sum-v+i])%mo;
return S;
}
int main()
{
freopen("winer.in","r",stdin);
freopen("winer.out","w",stdout);
read(n);ny[0]=jc[0]=1;
for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mo;
ny[n]=ksm(jc[n],mo-2);
for(int i=n;i;i--)ny[i-1]=ny[i]*i%mo;
for(int i=1;i<n;i++)
read(x),read(y),ins(x,y),ins(y,x),d[x]++,d[y]++;
for(int i=1;i<=n;i++)if(d[i]==1)m++;
mx=0;dfs(1,0);
memset(fa,0,sizeof(fa));
mx=0;dfs(id,0);
for(int i=1;i<=mx/2;i++)id=fa[id];
if(mx&1)
{
w[t=work(id,fa[id],mx/2+1)]++;
sum=sum+t;
w[t=work(fa[id],id,mx/2+1)]++;
sum=sum+t;
}
else
{
for(int i=lst[id];i;i=nxt[i])
w[t=work(to[i],id,mx>>1)]++,sum=sum+t;
}
for(int i=1;i<=sum;i++)
s[i]=(s[i-1]+ksm(sum-i+1,mo-2)*m)%mo;
for(int i=0;i<=n;i++)if(w[i])ans=(ans+get(i)*w[i])%mo;
printf("%lld",ans*ksm(jc[sum],mo-2)%mo);
}