题意:
给定 $n$ 个连通块,第 $i$ 个连通块的大小为 $a_i$,接下来依次连 $n-1$ 条边,得到的树 $T$ 的价值定义为:
$$val(T)=\left(\prod_{i=1}^n d_i^m\right)\left(\sum_{i=1}^n d_i^m\right)$$
其中,$d_i$ 表示与第 $i$ 个连通块连接的边的条数。请求出所有不同连边方式产生的树的价值之和,对 $998244353$ 取模。
$n\le 30000$,$m\le 30$。
前置技能:求数列 $k$ 次方和。
给定 $k$,对于任意的 $0\le t\le k$,求出 $\sum\limits_{i=1}^n a_i^t$。$k,n\le 10^5$。
考虑答案的生成函数(在模 $x^{k+1}$ 意义下)$F(x)=\sum\limits_{t=0}^{k}x^t\sum\limits_{i=1}^n a_i^t=\sum\limits_{i=1}^n\frac{1}{1-a_ix}$。
直接计算仍然是不行的,注意到(这里 $\ln'$ 表示对 $x$ 求导)$\ln'(1-a_ix)=\frac{-a_i}{1-a_ix}=-\sum\limits_{t=0}^\infty(a_ix)^ta_i$。
因此考虑先计算 $G(x)=-\sum\limits_{t=0}^{k}x^t\sum\limits_{i=1}^n a_i^{t+1}$,则 $F(x)=-xG(x)+n$。化简 $G(x)$:
$$G(x)=\sum_{i=1}^n\ln'(1-a_ix)=\ln'\left(\prod_{i=1}^n(1-a_ix)\right)$$
括号内的东西分治 NTT 即可,然后多项式求 $\ln$ 再求导得到 $G(x)$,代入 $F(x)=-xG(x)+n$ 即可得到 $F(x)$。
过程
对于每个最终方案 $T$,对答案的贡献为 $\prod\limits_{i=1}^n a_i^{d_i}d_i^m\sum\limits_{i=1}^n d_i^m$。
由于出现了度数,我们考虑使用 Prüfer 序列化简算式。下面改用 $d_i$ 表示第 $i$ 个连通块在 Prüfer 序列中出现的次数(即度数减一),总贡献等价于:
$$\begin{aligned}
&(n-2)!\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i+1}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m \\
=\ &(n-2)!\prod_{i=1}^na_i\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m
\end{aligned}$$
前面的 $(n-2)!\prod\limits_{i=1}^n a_i$ 是常量,我们不需要关注。考虑后面的东西,它等价于:
$$\sum_{\sum d_i=n-2}\ \sum_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod_{j=1,j\neq i}^n\frac{a_j^{d_j}}{d_j!}(d_j+1)^m$$
考虑构建上式关于 $\sum d_i$ 的生成函数 $F(x)$。考虑下述的两个多项式:
$$A(x)=\sum_i \frac{x^i(i+1)^m}{i!}$$
$$B(x)=\sum_i \frac{x^i(i+1)^{2m}}{i!}$$
则有:
$$\begin{aligned}
F(x)&=\sum_i B(a_ix)\prod_{j\neq i}A(a_jx)=\sum_i\frac{B(a_ix)}{A(a_ix)}\prod_jA(a_jx) \\
&=\sum_i\frac{B(a_ix)}{A(a_ix)}\exp\sum_j\ln A(a_jx)
\end{aligned}$$
也就是说,求出 $\frac{B(x)}{A(x)}$ 和 $\ln A(x)$ 后,需要把它们的 $x^k$ 项系数分别乘上 $\sum_i a_i^k$,这正是我们前面说过可以在 $O(n\log^2 n)$ 的时间内求出的东西。因此最后的复杂度为 $O(n\log^2 n)$。
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for residues and their intermediate products.
typedef long long ll;

// Prime modulus of all arithmetic; ntt() builds its roots of unity from 3.
const int mod = 998244353;

// Capacity of every global polynomial buffer.
const int maxn = 65540;

// Shared scratch buffers. Zero-initialised as globals and overwritten freely
// by the polynomial routines below, so their contents do not survive a call.
ll ta[maxn];
ll tb[maxn];
ll tc[maxn];
// Computes a^b modulo `mod` by binary exponentiation.
//   a: base; any 64-bit value, reduced into [0, mod) before use.
//   b: exponent; must be non-negative. A negative exponent is not supported:
//      with the usual arithmetic right shift it would never reach zero.
// Returns the residue in [0, mod); modpow(x, 0) is 1.
ll modpow(ll a, int b) {
    // Normalise the base so that a * a cannot overflow 64 bits and a negative
    // base cannot leak a negative residue into the result.
    a %= mod;
    if (a < 0) a += mod;
    ll res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// Reorders a[0..n) in place so that each element moves to the bit-reversed
// image of its index. Intended for power-of-two n, as passed by ntt().
void rader(ll *a, int n) {
    const int top = n >> 1;  // weight of the most significant index bit
    int mirrored = top;      // bit-reversed image of `pos`; reverse(1) == top
    for (int pos = 1; pos < n - 1; ++pos) {
        // Swap each pair once, when visiting its smaller index.
        if (pos < mirrored) std::swap(a[pos], a[mirrored]);
        // Increment `mirrored` in bit-reversed order: clear the run of set
        // bits from the top, then set the first clear bit found.
        int bit = top;
        while (mirrored >= bit) {
            mirrored -= bit;
            bit >>= 1;
        }
        mirrored += bit;
    }
}
// In-place number-theoretic transform of a[0..n) modulo `mod`.
//   a:   coefficients; expected in [0, mod) because the butterfly narrows
//        them to int.
//   n:   transform length; callers in this file pass powers of two.
//   rev: zero selects the forward transform; non-zero selects the inverse,
//        which uses inverse roots and finally scales every entry by n^{-1}.
void ntt(ll *a, int n, int rev) {
    rader(a, n);
    for (int span = 2; span <= n; span <<= 1) {
        const int half = span >> 1;
        // Exponent of 3 giving the span-th root of unity (or its inverse).
        const int exponent = rev ? mod - 1 - (mod - 1) / span : (mod - 1) / span;
        const ll step = modpow(3, exponent);
        for (int base = 0; base < n; base += span) {
            int twiddle = 1;
            for (int pos = base; pos < base + half; ++pos) {
                const int upper = static_cast<int>(a[pos]);
                const int lower = static_cast<int>(a[pos + half] * twiddle % mod);
                a[pos] = (upper + lower) % mod;
                a[pos + half] = (upper - lower + mod) % mod;
                twiddle = static_cast<int>(twiddle * step % mod);
            }
        }
    }
    if (rev) {
        const int scale = static_cast<int>(modpow(n, mod - 2));
        for (int i = 0; i < n; ++i) a[i] = a[i] * scale % mod;
    }
}
// Computes the power-series inverse of `a` modulo x^n by Newton iteration:
// b <- b * (2 - a * b), doubling the precision at every recursion level.
//   a: input series; a[0..n) is read.
//   b: output; b[0..n) receives the inverse and b[n..2n) is left zeroed.
//      Needs room for 2n entries (2 entries when n == 1).
//   n: precision; callers in this file pass powers of two.
// Uses the global buffer ta as scratch.
void get_inv(ll *a, ll *b, int n) {
    if (n == 1) {
        b[0] = modpow(a[0], mod - 2);
        b[1] = 0;
        return;
    }
    get_inv(a, b, n >> 1);

    const int full = n << 1;  // transform length that holds a * b * b exactly
    for (int i = n; i < full; ++i) {
        ta[i] = 0;
        b[i] = 0;
    }
    for (int i = 0; i < n; ++i) ta[i] = a[i];

    ntt(ta, full, 0);
    ntt(b, full, 0);
    for (int i = 0; i < full; ++i) {
        const ll ab = ta[i] * b[i] % mod;
        b[i] = (mod + 2 - ab) * b[i] % mod;
    }
    ntt(b, full, 1);

    // Drop everything at or above x^n.
    for (int i = n; i < full; ++i) b[i] = 0;
}
// Computes the power-series logarithm of `a` modulo x^n as the integral of
// a' / a.
//   a: input series; a[0..n) is read. The result is a logarithm only when
//      a[0] == 1 (not checked here).
//   b: output; b[0..n) receives the result with b[0] == 0, and b[n..2n) is
//      left zeroed. Needs room for 2n entries.
//   n: precision; callers in this file pass powers of two.
// Uses the global buffer tb (and, through get_inv, ta) as scratch.
void get_ln(ll *a, ll *b, int n) {
    get_inv(a, tb, n);  // tb = 1 / a

    // b = a'(x), padded with zeros up to the transform length.
    const int full = n << 1;
    for (int i = 1; i < n; ++i) b[i - 1] = a[i] * i % mod;
    b[n - 1] = 0;
    for (int i = n; i < full; ++i) b[i] = 0;

    // b = a' * (1 / a), truncated to x^n.
    ntt(b, full, 0);
    ntt(tb, full, 0);
    for (int i = 0; i < full; ++i) b[i] = b[i] * tb[i] % mod;
    ntt(b, full, 1);
    for (int i = n; i < full; ++i) b[i] = 0;

    // Integrate term by term: the coefficient of x^(i-1) becomes the
    // coefficient of x^i divided by i. The constant term is 0.
    for (int i = n - 1; i > 0; --i) {
        const ll inv_i = modpow(i, mod - 2);
        b[i] = b[i - 1] * inv_i % mod;
    }
    b[0] = 0;
}
// Computes the power-series exponential of `a` modulo x^n by Newton
// iteration: b <- b * (1 + a - ln b), doubling the precision per level.
//   a: input series; a[0..n) is read. The iteration starts from b[0] = 1,
//      i.e. it treats a[0] as 0.
//   b: output; b[0..n) receives the result and b[n..2n) is left zeroed.
//      Needs room for 2n entries (2 entries when n == 1).
//   n: precision; callers in this file pass powers of two.
// Uses the global buffers tc and ta (and, through get_ln, tb) as scratch.
void get_exp(ll *a, ll *b, int n) {
    if (n == 1) {
        b[0] = 1;
        b[1] = 0;
        return;
    }
    get_exp(a, b, n >> 1);
    get_ln(b, tc, n);  // tc = ln(b) at the new precision

    const int full = n << 1;
    for (int i = n; i < full; ++i) {
        ta[i] = 0;
        b[i] = 0;
    }
    // ta = 1 + a - ln(b); the "+1" only affects the constant term.
    for (int i = 0; i < n; ++i) {
        const int one = (i == 0) ? 1 : 0;
        ta[i] = (mod + one + a[i] - tc[i]) % mod;
    }

    ntt(b, full, 0);
    ntt(ta, full, 0);
    for (int i = 0; i < full; ++i) b[i] = b[i] * ta[i] % mod;
    ntt(b, full, 1);

    // Drop everything at or above x^n.
    for (int i = n; i < full; ++i) b[i] = 0;
}
// fac[i] = i! and rev[i] = (i!)^{-1} modulo mod; both are filled by main().
ll fac[maxn];
ll rev[maxn];

// Working polynomials of main().
ll A[maxn];
ll B[maxn];
ll C[maxn];

// Power sums of the component sizes, filled by main().
ll sum[maxn];

// Component sizes, stored 1-based in sz[1..n].
int sz[maxn];

// Number of components and the exponent m, both read from stdin by main().
int n;
int m;
// Builds the polynomial prod_{i=l..r} (1 - sz[i] * x) modulo `mod` by divide
// and conquer, multiplying the two halves with ntt().
//   l, r: inclusive index range into the global sz[]; requires l <= r.
// Returns the r - l + 2 coefficients, lowest degree first, each in [0, mod).
// Overwrites the global scratch buffers ta and tb.
vector<ll> divide(int l, int r) {
    if (l == r) {
        // Reduce the size before negating so the coefficient stays in
        // [0, mod) even when sz[l] is 0, negative or >= mod. ntt() narrows
        // its inputs to int and assumes canonical residues.
        vector<ll> leaf(2);
        leaf[0] = 1;
        leaf[1] = (mod - sz[l] % mod) % mod;
        return leaf;
    }

    const int mid = (l + r) >> 1;
    const vector<ll> left = divide(l, mid);
    const vector<ll> right = divide(mid + 1, r);

    // The product has exactly r - l + 2 coefficients, so the smallest power
    // of two that holds them is enough to avoid cyclic wrap-around.
    const int terms = r - l + 2;
    int len = 1;
    while (len < terms) len <<= 1;

    std::fill(ta, ta + len, ll(0));
    std::fill(tb, tb + len, ll(0));
    std::copy(left.begin(), left.end(), ta);
    std::copy(right.begin(), right.end(), tb);

    ntt(ta, len, 0);
    ntt(tb, len, 0);
    for (int i = 0; i < len; i++) ta[i] = ta[i] * tb[i] % mod;
    ntt(ta, len, 1);

    return vector<ll>(ta, ta + terms);
}
int main() {
scanf("%d%d", &n, &m);
if (n == 1) return puts("1") * 0;
for (int i = fac[0] = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
rev[n] = modpow(fac[n], mod - 2);
for (int i = n; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
for (int i = 1; i <= n; i++) scanf("%d", sz + i);
vector<ll> vec = divide(1, n);
memset(tc, 0, sizeof(tc));
for (int i = 0; i <= n; i++) tc[i] = vec[i];
int len = 1;
while (len <= n) len <<= 1;
get_ln(tc, sum, len);
for (int i = 1; i <= n; i++) sum[i] = mod - sum[i] * i % mod;
sum[0] = n;
for (int i = 0; i < n; i++) {
A[i] = modpow(i + 1, m) * rev[i] % mod;
B[i] = modpow(i + 1, 2 * m) * rev[i] % mod;
}
get_ln(A, tc, len);
get_inv(A, C, len);
ntt(C, len << 1, 0), ntt(B, len << 1, 0);
for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len << 1, 1);
memset(A, 0, sizeof(A));
memset(C, 0, sizeof(C));
for (int i = 0; i < n; i++) {
B[i] = B[i] * sum[i] % mod;
A[i] = tc[i] * sum[i] % mod;
}
get_exp(A, C, len);
for (int i = n; i < len << 1; i++) B[i] = 0;
ntt(B, len << 1, 0), ntt(C, len << 1, 0);
for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len << 1, 1);
ll res = B[n - 2] * fac[n - 2] % mod;
for (int i = 1; i <= n; i++) res = res * sz[i] % mod;
printf("%lld\n", res);
return 0;
}