传送门
解析:
在看完题目三个条件之后,我们得到需要做的事情就是,钦定两个点为特殊点,它们之间的路径为特殊路径,计算向外伸出小于等于 k k k个互不相交的分叉的方案数。
k = 1 k=1 k=1的情况直接输出答案即可。
现在我们不考虑编号大小,最后方案数 / 2 /2 /2即可。
显然我们需要树形DP。
设我们当前处理到 u u u, c [ i ] c[i] c[i]表示以 u u u为根的子树向子树内伸出 i i i个互不相交的分叉的方案数。
设 u u u的儿子们为 s 1 , s 2 , s 3 . . . s_1,s_2,s_3... s1,s2,s3...,则 c c c生成函数为 C ( x ) = ∏ i ( 1 + s i z s i ) C(x)=\prod\limits_{i}(1+siz_{s_i}) C(x)=i∏(1+sizsi)
直接利用分治乘法可以在 O ( n log 2 n ) O(n\log^2n) O(nlog2n)时间内处理完毕。
所以这部分的答案是 [ x k ] C ( x ) [x^k]C(x) [xk]C(x)?并不,我们发现我们可以令一个端点和 u u u重合。
考虑所有有 i i i个不重合的情况,我们得到答案是 ∑ i = 0 k A ( k , i ) c [ i ] \sum\limits_{i=0}^kA(k,i)c[i] i=0∑kA(k,i)c[i]
显然现在我们已经可以处理出选择的两个关键点 u , v u,v u,v没有祖先后代关系的情况了。
设 f [ u ] f[u] f[u]表示 u u u在子树内部伸出 k k k个互不相交的分叉的方案数, s u m [ u ] sum[u] sum[u]表示 u u u子树内的所有点 f f f值之和。
考虑反着做,所有点对两两配对直接得到 s u m [ 1 ] 2 sum[1]^2 sum[1]2,然后我们去掉所有不合法的情况,也就是两点重合 f [ u ] 2 f[u]^2 f[u]2,或者两点有祖先后代关系 2 ∗ f [ u ] ∗ s u m [ u ] 2*f[u]*sum[u] 2∗f[u]∗sum[u],把不合法的情况减去。
现在我们需要计算两点存在祖先后代关系的方案数。
显然值需要考虑第一个关键点在某个子树中选择,把父亲方向当做另外一棵子树,显然就是上面的多项式 C ( x ) ⋅ 1 + ( n − s i z u x ) 1 + s i z s i x C(x)\cdot\frac{1+(n-siz_ux)}{1+siz_{s_i}x} C(x)⋅1+sizsix1+(n−sizux),显然我们可以在 O ( d e g ) O(deg) O(deg)时间内处理出对于每一种 s i z siz siz的答案,而子树中不同的 s i z siz siz最多只有 O ( s i z u ) O(\sqrt {siz_u}) O(sizu)种,所以复杂度实际上只有 O ( n n ) O(n\sqrt n) O(nn)。
虽然说代码看上去有点长,但是几乎都是板子,没什么细节,主要是不好想。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define gc get_char
#define cs const
namespace IO{
inline char get_char(){
static cs int Rlen=1<<20|1;
static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
inline int getint(){
re char c;
while(!isdigit(c=gc()));re int num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
}
using namespace IO;
using std::cout;
using std::cerr;
cs int mod=998244353,inv2=(mod+1)/2;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline void Inc(int &a,int b){(a+=b)>=mod?a-=mod:a;}
inline int dec(int a,int b){return a<b?a-b+mod:a-b;}
inline void Dec(int &a,int b){(a-=b)<0?a+=mod:a;}
inline int mul(int a,int b){return (ll)a*b%mod;}
inline int quickpow(int a,int b,int res=1){
while(b){
if(b&1)res=mul(res,a);
a=mul(a,a);
b>>=1;
}
return res;
}
typedef std::vector<int> Poly;
Poly r;
int invl;
inline void NTT(Poly &A,int len,int typ){
if(typ==-1)std::reverse(A.begin()+1,A.begin()+len);
for(int re i=0;i<len;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
for(int re i=1;i<len;i<<=1){
int wn=quickpow(3,(mod-1)/i/2);
for(int re j=0;j<len;j+=i<<1)
for(int re k=0,x,y,w=1;k<i;++k,w=mul(w,wn)){
x=A[j+k],y=mul(w,A[j+k+i]);
A[j+k]=add(x,y);
A[j+k+i]=dec(x,y);
}
}
if(typ==-1)for(int re i=0;i<len;++i)A[i]=mul(A[i],invl);
}
inline void init_rev(int len){
if(len==r.size())return ;
r.resize(len);
for(int re i=0;i<len;++i)r[i]=r[i>>1]>>1|((i&1)*(len>>1));
invl=quickpow(len,mod-2);
}
inline Poly operator*(Poly a,Poly b){
int deg=a.size()+b.size()-1,len=1;
while(len<deg)len<<=1;
init_rev(len);
a.resize(len),NTT(a,len,1);
b.resize(len),NTT(b,len,1);
for(int re i=0;i<len;++i)a[i]=mul(a[i],b[i]);
NTT(a,len,-1),a.resize(deg);
return a;
}
inline void mul_poly(Poly &a,int b){
a.push_back(0);
for(int re i=a.size()-2;~i;--i)Inc(a[i+1],mul(a[i],b));
}
inline void div_poly(Poly &a,int b){
int inv=quickpow(b,mod-2);
Poly d=a;
for(int re i=a.size()-1;i;--i){
a[i-1]=mul(d[i],inv);
Dec(d[i-1],a[i-1]);
}
a.pop_back();
}
cs int N=1e5+5;
int n,k,ans;
int fac[N],ifac[N],inv[N];
inline void init_inv(){
fac[0]=ifac[0]=inv[0]=fac[1]=ifac[1]=inv[1]=1;
for(int re i=2;i<N;++i){
fac[i]=mul(fac[i-1],i);
inv[i]=mul(inv[mod%i],mod-mod/i);
ifac[i]=mul(ifac[i-1],inv[i]);
}
}
inline int A(int n,int m){return mul(fac[n],ifac[n-m]);}
std::vector<int> G[N];
inline void addedge(int u,int v){
G[u].push_back(v);
G[v].push_back(u);
}
int f[N],sum[N],g[N];
int coef[N],cnt;
inline Poly build(int l,int r){
if(l==r)return Poly({1,coef[l]});
int mid=(l+r)>>1;
return build(l,mid)*build(mid+1,r);
}
inline int calc(cs Poly &a){
int lim=std::min((int)a.size()-1,k),res=0;
for(int re i=0;i<=lim;++i)Inc(res,mul(A(k,i),a[i]));
return res;
}
int siz[N];
void dfs(int u,int fa){
siz[u]=1;
for(re int v:G[u])if(v!=fa){
dfs(v,u);
siz[u]+=siz[v];
Inc(sum[u],sum[v]);
}
cnt=0;
for(re int v:G[u])if(v!=fa)coef[++cnt]=siz[v];
if(cnt){
Poly a=build(1,cnt);
f[u]=calc(a);
cnt=0;
for(re int v:G[u])if(v!=fa)coef[++cnt]=v;
std::sort(coef+1,coef+cnt+1,[&](cs int &u,cs int &v){return siz[u]<siz[v];});
mul_poly(a,n-siz[u]);
for(int re i=1;i<=cnt;++i){
if(siz[coef[i]]==siz[coef[i-1]])g[coef[i]]=g[coef[i-1]];
else {
div_poly(a,siz[coef[i]]);
g[coef[i]]=calc(a);
mul_poly(a,siz[coef[i]]);
}
}
}
else f[u]=1;
Inc(sum[u],f[u]);
for(re int v:G[u])if(v!=fa){
Inc(ans,mul(2,mul(dec(g[v],f[u]),sum[v])));
}
Dec(ans,mul(f[u],f[u]));
}
signed main(){
init_inv();
n=getint(),k=getint();
if(k==1)return cout<<((ll)n*(n-1)/2%mod)<<'\n',0;
for(int re i=1;i<n;++i)addedge(getint(),getint());
dfs(1,0);
Inc(ans,mul(sum[1],sum[1]));
cout<<mul(ans,inv2)<<"\n";
return 0;
}