清华集训2017 生成树计数

探讨了在给定连通块的情况下,通过特定算法计算所有可能生成的树的价值总和,利用生成函数和多项式运算优化计算过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

题意:

给定nnn个连通块,每个连通块的大小为aia_iai,接下来依次连n−1n-1n1条边,得到的树TTT的价值定义为:
val(T)=(∏i=1ndim)(∑i=1ndim)val(T)=\left(\prod_{i=1}^nd_i^m\right)\left(\sum_{i=1}^nd_i^m\right)val(T)=(i=1ndim)(i=1ndim)
其中,did_idi表示与第iii个连通块连接的边的条数。请求出所有不同连边方式产生的树的价值和膜998244353998244353998244353.
n≤30,000,m≤30n\le30,000,m\le30n30,000,m30

前置技能:求数列kkk次方和。

给定kkk,对于任意的0≤t≤k0\le t\le k0tk,求出∑i=1nait\sum\limits_{i=1}^na_i^ti=1naitk,n≤105k,n\le 10^5k,n105
考虑答案的生成函数F(x)=∑t=0kxt∑i=1nait=∑i=1n11−xaiF(x)=\sum\limits_{t=0}^kx^t\sum\limits_{i=1}^na_i^t=\sum\limits_{i=1}^n\frac{1}{1-xa_i}F(x)=t=0kxti=1nait=i=1n1xai1.
直接计算仍然是不行的,注意到ln⁡′(1−aix)=−ai1−aix=−∑t=0∞(aix)tai\ln'(1-a_ix)=\frac{-a_i}{1-a_ix}=-\sum\limits_{t=0}^\infty(a_ix)^ta_iln(1aix)=1aixai=t=0(aix)tai
因此考虑先计算G(x)=−∑t=0kxt∑i=1nait+1G(x)=-\sum\limits_{t=0}^kx^t\sum\limits_{i=1}^na_i^{t+1}G(x)=t=0kxti=1nait+1,则F(x)=−xG(x)+nF(x)=-xG(x)+nF(x)=xG(x)+n。化简G(x)G(x)G(x)
G(x)=∑i=1nln⁡′(1−aix)=ln⁡′(∏i=1n(1−aix))G(x)=\sum\limits_{i=1}^n\ln'(1-a_ix)=\ln'\left(\prod_{i=1}^n(1-a_ix)\right)G(x)=i=1nln(1aix)=ln(i=1n(1aix))
括号内的东西分治NTT即可,然后多项式求ln再求导,即可得到F(x)F(x)F(x)

过程

对于每个终方案TTT,对答案的贡献为∏i=1naididim∑i=1ndim\prod\limits_{i=1}^na_i^{d_i}d_i^m\sum\limits_{i=1}^nd_i^mi=1naididimi=1ndim
由于出现了度数,我们考虑使用prufer序列化简算式。总贡献等价于:
(n−2)!∑∑di=n−2∏i=1naidi+1di!(di+1)m∑i=1n(di+1)m=(n−2)!∏i=1nai∑∑di=n−2∏i=1naididi!(di+1)m∑i=1n(di+1)m(n-2)!\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i+1}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m \\ =(n-2)!\prod_{i=1}^na_i\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m(n2)!di=n2i=1ndi!aidi+1(di+1)mi=1n(di+1)m=(n2)!i=1naidi=n2i=1ndi!aidi(di+1)mi=1n(di+1)m
前面的(n−2)!∏i=1nai(n-2)!\prod\limits_{i=1}^na_i(n2)!i=1nai是常量,我们不需要关注。考虑后面的东西,它等价于:
∑i=1naididi!(di+1)2m∏j=1,j≠inajdjdj!(dj+1)m\sum_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod_{j=1,j\neq i}^n\frac{a_j^{d_j}}{d_j!}(d_j+1)^mi=1ndi!aidi(di+1)2mj=1,j̸=indj!ajdj(dj+1)m
考虑构建上式关于∑di\sum d_idi的生成函数F(x)F(x)F(x)。考虑下述的两个多项式:
A(x)=∑ixi(i+1)mi!A(x)=\sum_i \frac{x^i(i+1)^m}{i!}A(x)=ii!xi(i+1)m
B(x)=∑ixi(i+1)2mi!B(x)=\sum_i \frac{x^i(i+1)^{2m}}{i!}B(x)=ii!xi(i+1)2m
则有:
F(x)=∑iB(aix)∏j≠iA(ajx)=∑iB(aix)A(aix)∏jA(ajx)=∑iB(aix)A(aix)exp⁡∑jln⁡A(ajx)F(x)=\sum_i B(a_ix)\prod_{j\neq i}A(a_jx)=\sum_i\frac{B(a_ix)}{A(a_ix)}\prod_jA(a_jx) \\ =\sum_i\frac{B(a_ix)}{A(a_ix)}\exp\sum_j\ln A(a_jx)F(x)=iB(aix)j̸=iA(ajx)=iA(aix)B(aix)jA(ajx)=iA(aix)B(aix)expjlnA(ajx)
也就是说求出B(x)A(x)和ln⁡A(x)\frac{B(x)}{A(x)}和\ln A(x)A(x)B(x)lnA(x)后,需要对于每一项乘上∑aik\sum a_i^kaik,这正是我们前面说过可以在O(nlog2n)O(nlog^2n)O(nlog2n)的时间内求出的东西。因此最后的复杂度为O(nlog2n)O(nlog^2n)O(nlog2n)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 65540, mod = 998244353;
ll ta[maxn], tb[maxn], tc[maxn];
ll modpow(ll a, int b) {
	ll res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
	}
	return res;
}
void rader(ll *a, int n) {
	for (int i = 1, j = n >> 1; i < n - 1; i++) {
		if (i < j) swap(a[i], a[j]);
		int k = n >> 1;
		for (; j >= k; k >>= 1) j -= k;
		if (j < k) j += k;
	}
}
void ntt(ll *a, int n, int rev) {
	rader(a, n);
	for (int h = 2; h <= n; h <<= 1) {
		ll wn = modpow(3, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
		int hh = h >> 1;
		for (int i = 0; i < n; i += h)
		for (int j = i, w = 1; j < i + hh; j++, w = w * wn % mod) {
			const int x = a[j], y = a[j + hh] * w % mod;
			a[j] = (x + y) % mod;
			a[j + hh] = (x - y + mod) % mod;
		}
	}
	if (rev) {
		int inv = modpow(n, mod - 2);
		for (int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
	}
}
void get_inv(ll *a, ll *b, int n) {
	if (n == 1) { b[0] = modpow(a[0], mod - 2); b[1] = 0; return; }
	get_inv(a, b, n >> 1);
	int m = n << 1;
	for (int i = n; i < m; i++) ta[i] = b[i] = 0;
	for (int i = 0; i < n; i++) ta[i] = a[i];
	ntt(ta, m, 0), ntt(b, m, 0);
	for (int i = 0; i < m; i++) b[i] = (mod + 2 - ta[i] * b[i] % mod) * b[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
}
void get_ln(ll *a, ll *b, int n) {
	get_inv(a, tb, n);
	for (int i = 1; i < n; i++) b[i - 1] = a[i] * i % mod;
	b[n - 1] = 0;
	int m = n << 1;
	for (int i = n; i < m; i++) b[i] = 0;
	ntt(b, m, 0), ntt(tb, m, 0);
	for (int i = 0; i < m; i++) b[i] = b[i] * tb[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
	for (int i = n - 1; i > 0; i--) b[i] = b[i - 1] * modpow(i, mod - 2) % mod;
	b[0] = 0;
}
void get_exp(ll *a, ll *b, int n) {
	if (n == 1) { b[0] = 1, b[1] = 0; return; }
	get_exp(a, b, n >> 1);
	get_ln(b, tc, n);
	int m = n << 1;
	for (int i = n; i < m; i++) b[i] = ta[i] = 0;
	for (int i = 0; i < n; i++) ta[i] = (mod + !i + a[i] - tc[i]) % mod;
	ntt(b, m, 0), ntt(ta, m, 0);
	for (int i = 0; i < m; i++) b[i] = b[i] * ta[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
}
ll fac[maxn], rev[maxn], A[maxn], B[maxn], C[maxn], sum[maxn];
int sz[maxn], n, m;
vector<ll> divide(int l, int r) {
	if (l == r) { vector<ll> vec; vec.push_back(1); vec.push_back(mod - sz[l]); return vec; }
	int mid = (l + r) >> 1, len = 1;
	vector<ll> vl = divide(l, mid), vr = divide(mid + 1, r);
	while (len <= r - l) len <<= 1; len <<= 1;
	for (int i = 0; i < len; i++) {
		ta[i] = tb[i] = 0;
		if (i <= mid - l + 1) ta[i] = vl[i];
		if (i <= r - mid) tb[i] = vr[i];
	}
	ntt(ta, len, 0), ntt(tb, len, 0);
	for (int i = 0; i < len; i++) ta[i] = ta[i] * tb[i] % mod;
	ntt(ta, len, 1);
	vector<ll> res;
	for (int i = 0; i <= r - l + 1; i++) res.push_back(ta[i]);
	return res;
}
int main() {
	scanf("%d%d", &n, &m);
	if (n == 1) return puts("1") * 0;
	for (int i = fac[0] = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
	rev[n] = modpow(fac[n], mod - 2);
	for (int i = n; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
	for (int i = 1; i <= n; i++) scanf("%d", sz + i);
	vector<ll> vec = divide(1, n);
	memset(tc, 0, sizeof(tc));
	for (int i = 0; i <= n; i++) tc[i] = vec[i];
	int len = 1;
	while (len <= n) len <<= 1;
	get_ln(tc, sum, len);
	for (int i = 1; i <= n; i++) sum[i] = mod - sum[i] * i % mod;
	sum[0] = n;
	for (int i = 0; i < n; i++) {
		A[i] = modpow(i + 1, m) * rev[i] % mod;
		B[i] = modpow(i + 1, 2 * m) * rev[i] % mod;
	}
	get_ln(A, tc, len);
	get_inv(A, C, len);
	ntt(C, len << 1, 0), ntt(B, len << 1, 0);
	for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
	ntt(B, len << 1, 1);
	memset(A, 0, sizeof(A));
	memset(C, 0, sizeof(C));
	for (int i = 0; i < n; i++) {
		B[i] = B[i] * sum[i] % mod;
		A[i] = tc[i] * sum[i] % mod;
	}
	get_exp(A, C, len);
	for (int i = n; i < len << 1; i++) B[i] = 0;
	ntt(B, len << 1, 0), ntt(C, len << 1, 0);
	for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
	ntt(B, len << 1, 1);
	ll res = B[n - 2] * fac[n - 2] % mod;
	for (int i = 1; i <= n; i++) res = res * sz[i] % mod;
	printf("%lld\n", res);
	return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值