文章目录
【模板】多项式乘法(FFT)
https://www.luogu.com.cn/problem/P3803
//luogu P3803 输入两个多项式的系数（按次数从低到高）
#include<bits/stdc++.h>
// Capacity of the coefficient and bit-reversal buffers; must exceed the padded FFT length.
const int N = 4e6 + 10;
// pi = arccos(-1); used to build the complex roots of unity.
const double pi = acos(-1.0);
using namespace std;
using ll = long long;  // 64-bit integer shorthand
// Minimal complex number used by the FFT. The arithmetic operators are
// const-qualified so they also work on const operands; each returns a new value.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
    Complex operator+(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator-(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i*a.y) = (x*a.x - y*a.y) + i(x*a.y + y*a.x)
    Complex operator*(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
// Bit-reversal permutation for the current transform length; filled in by the
// caller (r[i] = index i with its l low bits reversed) before fft() is called.
int r[N];
// Global coefficient buffers, transformed in place by fft().
Complex a[N], b[N];
// In-place iterative radix-2 FFT of a[0..n-1].
//   n    : transform length; assumed to be a power of two matching r[].
//   type : +1 for the forward transform, -1 for the inverse transform.
// The inverse is NOT scaled here: the caller divides each result by n.
void fft(Complex* a, int n, int type) {
// Permute into bit-reversed order; the i < r[i] test swaps each pair only once.
for (int i = 0; i < n; i++) {
if (i < r[i])
swap(a[i], a[r[i]]);
}
// mid: half-length of the sub-transforms merged in this pass.
for (int mid = 1; mid < n; mid <<= 1) {
// Principal (2*mid)-th root of unity; `type` flips the sign of its angle for the inverse.
Complex Wn = Complex(cos(pi / mid), type * sin(pi / mid));
for (int R = mid << 1, j = 0; j < n; j += R) {
Complex w(1, 0);
// Butterflies; w is advanced by repeated multiplication (rounding error grows with mid).
for (int k = 0; k < mid; k++, w = w * Wn) {
Complex x = a[j + k], y = w * a[j + mid + k];
a[j + k] = x + y;
a[j + mid + k] = x - y;
}
}
}
}
int main() {
int n, m, x;
cin >> n >> m;
for (int i = 0; i <= n; i++) cin >> x, a[i].x = x, a[i].y = 0;
for (int i = 0; i <= m; i++) cin >> x, b[i].x = x, b[i].y = 0;
int len = 1, l = 0;
while (len <= (n + m)) len *= 2, l++;
for (int i = 0; i < len; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i <= len; i++) {
a[i] = a[i] * b[i];
}
fft(a, len, -1);
for (int i = 0; i <= (n + m); i++) printf("%d ", int(a[i].x / (len)+0.5));
return 0;
}
【模板】A*B Problem 升级版(FFT)
https://www.luogu.com.cn/problem/P1919
给你两个正整数 $a,b$，求 $a \times b$。

$1 \le a,b \le 10^{1000000}$
#include<bits/stdc++.h>
// Capacity of the coefficient and bit-reversal buffers; must exceed the padded FFT length.
const int N = 4e6 + 10;
// pi = arccos(-1); used to build the complex roots of unity.
const double pi = acos(-1.0);
using namespace std;
using ll = long long;  // 64-bit integer shorthand
// Minimal complex number used by the FFT. The arithmetic operators are
// const-qualified so they also work on const operands; each returns a new value.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
    Complex operator+(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator-(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i*a.y) = (x*a.x - y*a.y) + i(x*a.y + y*a.x)
    Complex operator*(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
// Bit-reversal permutation for the current transform length; filled in by the
// caller (r[i] = index i with its l low bits reversed) before fft() is called.
int r[N];
// Global coefficient buffers, transformed in place by fft().
Complex a[N], b[N];
// In-place iterative radix-2 FFT of a[0..n-1].
//   n    : transform length; assumed to be a power of two matching r[].
//   type : +1 for the forward transform, -1 for the inverse transform.
// The inverse is NOT scaled here: the caller divides each result by n.
void fft(Complex* a, int n, int type) {
// Permute into bit-reversed order; the i < r[i] test swaps each pair only once.
for (int i = 0; i < n; i++) {
if (i < r[i])
swap(a[i], a[r[i]]);
}
// mid: half-length of the sub-transforms merged in this pass.
for (int mid = 1; mid < n; mid <<= 1) {
// Principal (2*mid)-th root of unity; `type` flips the sign of its angle for the inverse.
Complex Wn = Complex(cos(pi / mid), type * sin(pi / mid));
for (int R = mid << 1, j = 0; j < n; j += R) {
Complex w(1, 0);
// Butterflies; w is advanced by repeated multiplication (rounding error grows with mid).
for (int k = 0; k < mid; k++, w = w * Wn) {
Complex x = a[j + k], y = w * a[j + mid + k];
a[j + k] = x + y;
a[j + mid + k] = x - y;
}
}
}
}
int ans[2000005];
int main() {
string s, t;
cin >> s >> t;
int n = s.size() - 1, m = t.size() - 1;
a[0].x = 0, b[0].x = 0, a[0].y = 0, b[0].y = 0;
for (int i = 0; i <= n; i++) a[i].x = s[n - i] - '0', a[i].y = 0;
for (int i = 0; i <= m; i++) b[i].x = t[m - i] - '0', b[i].y = 0;
int len = 1, l = 0;
while (len <= (n + m)) len *= 2, l++;
for (int i = 0; i < len; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
fft(a, len, 1);
fft(b, len, 1);
for (int i = 0; i <= len; i++) {
a[i] = a[i] * b[i];
}
fft(a, len, -1);
int carry = 0;
for (int i = 0; i <= (n + m); i++) {
carry /= 10;
carry += int(a[i].x / len + 0.5);
ans[i] = carry % 10;
}
int id = n + m;
while (carry) {
carry /= 10;
ans[++id] = carry;
}
while (ans[id] == 0)id--;
for (int i = id; i >= 0; i--)cout << ans[i];
return 0;
}
3-idiots(FFT)
题意
给定n个数字,求任取其中3个数,能组成三角形的概率
$3 \le n \le 10^5$
思路
处理出 $sum[i]$：两个数之和为 $i$ 的方案数
记 $d[i]$ 为值为 $i$ 的数字的个数，则 $sum[k]=\sum_{i=1}^{k-1} d[i]\cdot d[k-i]$，即 $d$ 与自身的卷积就是 $sum$
用 FFT 可求得 $sum[i]$
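重复部分：
卷积中每个数与自身的配对也被计入（$2x=x+x$），且每对不同的数被算了 2 次（$x+y=y+x$），
因此先对每个 $x_i$ 执行 sum[x[i]*2]--，再令 sum[i]/=2。
三个数不能组成三角形，当且仅当较小两数之和不超过最大数。记 $cnt[i]$ 为值 $\ge i$ 的数字个数，则不合法的方案数为 $\sum_i sum[i]\cdot cnt[i]$；用总方案数 $\binom{n}{3}$ 减去它，再除以 $\binom{n}{3}$，即为所求概率。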
代码
#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
// Capacity of the count / bit-reversal buffers; must exceed the padded FFT length.
const int N = 4e5 + 10;
// pi = arccos(-1); used to build the complex roots of unity.
const double pi = acos(-1.0);
using namespace std;
using ll = long long;  // 64-bit integer shorthand
// Minimal complex number used by the FFT. The arithmetic operators are
// const-qualified so they also work on const operands; each returns a new value.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
    Complex operator+(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator-(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i*a.y) = (x*a.x - y*a.y) + i(x*a.y + y*a.x)
    Complex operator*(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
// Bit-reversal permutation for the current transform length; filled in by the
// caller (r[i] = index i with its l low bits reversed) before fft() is called.
int r[N];
// Global coefficient buffers, transformed in place by fft().
Complex a[N], b[N];
// In-place iterative radix-2 FFT of a[0..n-1].
//   n    : transform length; assumed to be a power of two matching r[].
//   type : +1 for the forward transform, -1 for the inverse transform.
// The inverse is NOT scaled here: the caller divides each result by n.
void fft(Complex* a, int n, int type) {
// Permute into bit-reversed order; the i < r[i] test swaps each pair only once.
for (int i = 0; i < n; i++) {
if (i < r[i])
swap(a[i], a[r[i]]);
}
// mid: half-length of the sub-transforms merged in this pass.
for (int mid = 1; mid < n; mid <<= 1) {
// Principal (2*mid)-th root of unity; `type` flips the sign of its angle for the inverse.
Complex Wn = Complex(cos(pi / mid), type * sin(pi / mid));
for (int R = mid << 1, j = 0; j < n; j += R) {
Complex w(1, 0);
// Butterflies; w is advanced by repeated multiplication (rounding error grows with mid).
for (int k = 0; k < mid; k++, w = w * Wn) {
Complex x = a[j + k], y = w * a[j + mid + k];
a[j + k] = x + y;
a[j + mid + k] = x - y;
}
}
}
}
ll x[N];    // x[i]: value of the i-th number (1-based)
ll sum[N];  // sum[s]: number of unordered pairs of distinct numbers whose values add to s
ll cnt[N];  // cnt[v]: how many numbers have value >= v
// Handles one test case: reads tot numbers and prints the probability that three
// of them, chosen at random, form a triangle.
void solve() {
    int tot; cin >> tot;
    ll n = 0;  // maximum value
    // a[v].x counts how many numbers equal v.
    for (int i = 1; i <= tot; i++) cin >> x[i], a[x[i]].x++, a[x[i]].y = 0, n = max(n, x[i]);
    for (int i = n; i >= 0; i--) cnt[i] = cnt[i + 1] + static_cast<ll>(a[i].x);
    int len = 1, l = 0;
    while (len <= (n + n)) len *= 2, l++;
    // r[0] is always 0; starting at 1 avoids shifting by l - 1 == -1 when len == 1.
    for (int i = 1; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    fft(a, len, 1);
    // Squaring the spectrum convolves the count array with itself
    // (index len is outside the transform, so stop at len - 1).
    for (int i = 0; i < len; i++) {
        a[i] = a[i] * a[i];
    }
    fft(a, len, -1);
    // Pair counts reach tot^2 (about 1e10 for tot = 1e5), which does not fit in int:
    // round straight into ll.
    for (int i = 0; i <= (n + n); i++) sum[i] = static_cast<ll>(a[i].x / len + 0.5);
    // Drop each number paired with itself, then halve the ordered-pair count.
    for (int i = 1; i <= tot; i++) sum[x[i] * 2]--;
    for (int i = 0; i <= 2 * n; i++) sum[i] /= 2;
    ll ans = 1ll * tot * (tot - 1) * (tot - 2) / 6;  // C(tot, 3)
    ll res = ans;
    // A triple fails iff its two smaller values sum to at most the largest one.
    for (int i = 1; i <= n; i++) {
        res -= sum[i] * cnt[i];
    }
    printf("%.7lf\n", 1.0 * res / ans);
    // Clear the shared buffers for the next test case.
    for (int i = 0; i <= len; i++) a[i].x = 0, a[i].y = 0, cnt[i] = 0, sum[i] = 0;
    return;
}
// Entry point: enable fast I/O, read the number of test cases, run solve() for each.
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t;
    for (cin >> t; t--; )
        solve();
}
【模板】NTT
https://www.luogu.com.cn/problem/P3803
代码
// NTT模板 模数只能是998244353
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 4000005;          // buffer capacity (> padded transform length)
const int G = 3, mod = 998244353;  // primitive root and NTT-friendly prime
int n, m;                          // degrees of the two input polynomials
int L, limit = 1;                  // log2 of the transform length, and the length itself
int R[maxn];                       // bit-reversal permutation for length `limit`
int A[maxn], B[maxn];              // coefficient buffers, transformed in place
// Binary exponentiation: returns a^b mod `mod` for b >= 0.
// Products are widened to 64 bits before reduction.
int qpow(int a, int b) {
    int result = 1;
    while (b) {
        if (b & 1) result = 1ll * result * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return result;
}
// In-place iterative NTT of a[0..limit-1] modulo `mod` (uses globals limit and R).
// f == 1: forward transform; any other f: inverse transform including the 1/limit scaling.
void NTT(int* a, int f) {
    // Bit-reversal permutation.
    for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
    // i: half-length of the butterflies merged in this pass.
    for (int i = 1; i < limit; i <<= 1) {
        int gn = qpow(G, (mod - 1) / (i << 1));  // primitive (2i)-th root of unity
        for (int j = 0; j < limit; j += (i << 1)) {
            int g = 1;
            for (int k = 0; k < i; k++, g = 1ll * g * gn % mod) {
                int x = a[j + k], y = 1ll * g * a[j + k + i] % mod;
                a[j + k] = (x + y) % mod; a[j + k + i] = (x - y + mod) % mod;
            }
        }
    }
    if (f == 1) return;
    // Inverse: the forward transform with entries 1..limit-1 reversed, then scaled by 1/limit.
    int inv = qpow(limit, mod - 2);
    reverse(a + 1, a + limit);
    // Scale the whole transform (previously only indices 0..n+m), so the routine
    // no longer depends on the globals n and m.
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv % mod;
}
// Multiplies the polynomials stored in a and b, whose product has degree `deg`,
// leaving the product in a[0..deg]. Expects both buffers to be zero beyond their
// own degree up to the padded length; b is left in transformed form.
void poly_mul(int* a, int* b, int deg) {
    L = 0, limit = 1;
    while (limit <= deg) {
        limit <<= 1;
        L++;
    }
    // R[0] is always 0; starting at 1 avoids shifting by L - 1 == -1 when limit == 1.
    for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Transform the buffers that were passed in, not the globals A and B.
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
}
int main() {
cin >> n >> m;
for (int i = 0; i <= n; i++) cin >> A[i];
for (int i = 0; i <= m; i++) cin >> B[i];
poly_mul(A, B, n + m);
for (int i = 0; i <= n + m; i++) cout << A[i] << ' ';
cout << endl;
return 0;
}
//模板2
#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define ll long long
// Every `int` below is really 64-bit, so products of two residues never overflow.
#define int ll
using namespace std;
const int maxn = 3e6 + 5;   // capacity of the bit-reversal table r[]
const int mod = 998244353;  // NTT-friendly prime (primitive root 3)
// Binary exponentiation: a^b mod `mod` for b >= 0. Relies on `#define int ll`
// earlier in this program, so res * a and a * a are 64-bit products.
int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
int r[maxn];  // bit-reversal permutation for the current length; filled in by poly::mul
// In-place iterative NTT of a[0..lim-1] modulo `mod` with primitive root 3.
// b == 0: forward transform; b != 0: inverse transform including the 1/lim scaling.
// Expects a.size() >= lim.
void ntt(vector<int>& a, int lim, int b = 0)
{
    int i, j, k, l, * p, * q, u, v, w;
    for (i = 0; i < lim; ++i)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    // i: half-length of the butterflies, l = 2i: block length, u: primitive l-th root.
    for (i = 1; i < lim; i = l)
        for (j = 0, l = i * 2, u = qpow(3, (mod - 1) / l); j < lim; j += l)
            for (k = 0, v = 1, p = &a[j], q = &p[i]; k < i; ++k, v = v * 1ll * u % mod)
            {
                w = v * 1ll * q[k] % mod;
                // Lower output p[k] - w and upper output p[k] + w, both reduced into [0, mod).
                q[k] = p[k] < w ? p[k] - w + mod : p[k] - w;
                p[k] += w;
                if (p[k] >= mod) p[k] -= mod;
            }
    if (b) {
        // Use iterators: &a[lim] subscripts one past the end of the vector (and &a[1]
        // does so when lim == 1), which is undefined behaviour for operator[].
        reverse(a.begin() + 1, a.begin() + lim);
        for (i = 0, u = qpow(lim, mod - 2); i < lim; ++i)
            a[i] = a[i] * 1ll * u % mod;
    }
}
// Polynomial of degree l with coefficients a[0..l] modulo `mod`.
struct poly
{
    int l;          // degree
    vector<int> a;  // coefficients, lowest degree first
    poly() : l(0) {}
    //poly(int x){for(l=x;~x;--x)a.push_back(1);}
    //bool operator<(poly x)const{return l>x.l;}
    // *this = *this * x. x is taken by value because its buffer gets resized and transformed.
    void mul(poly x) {
        l += x.l;
        int lim = 1;
        while (lim <= l) lim *= 2;
        // Bit-reversal table for length lim; r[0] is never written and stays 0.
        for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
        a.resize(lim);
        x.a.resize(lim);
        ntt(a, lim);
        ntt(x.a, lim);
        for (int i = 0; i < lim; ++i) a[i] = a[i] * 1ll * x.a[i] % mod;
        ntt(a, lim, 1);
    }
};
void solve() {
int n, m; cin >> n >> m;
poly p1, p2;
p1.l = n, p2.l = m;
int x;
for (int i = 0; i <= n; i++) {
cin >> x; p1.a.push_back(x);
}
for (int i = 0; i <= m; i++) {
cin >> x; p2.a.push_back(x);
}
p1.mul(p2);
for (int i = 0; i <= n + m; i++)cout << p1.a[i] << ' ';
}
// Program entry: fast I/O, then a single test case.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    solve();
}
/*
*/
【模板】分治FFT(NTT)
https://www.luogu.com.cn/problem/P4721
代码
//给定序列g[1..n-1],求序列f[0...n-1]
//f[i]=\sum^i_{j=1}f[i-j]g[j] 边界f[0]=1
//答案对998244353取模
#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
// Every `int` below is really 64-bit, so products of two residues never overflow.
#define int ll
using namespace std;
const int maxn = 4e6 + 5;          // buffer capacity
const int G = 3, mod = 998244353;  // primitive root and NTT-friendly prime
int n;                             // length of the sequences f and g
int L, limit = 1, R[maxn];         // log2(limit), transform length, bit-reversal table
int A[maxn], B[maxn];              // NTT work buffers
int f[maxn], g[maxn];              // f: answer being built; g: given coefficients g[1..n-1]
// Binary exponentiation: a^b mod `mod` for b >= 0. Relies on `#define int ll`
// earlier in this program, so res * a and a * a are 64-bit products.
int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// In-place iterative NTT of a[0..limit-1] modulo `mod` (uses globals limit and R).
// f == 1: forward transform; any other f: inverse transform including division by limit.
void NTT(int* a, int f) {
    for (int i = 0; i < limit; i++)
        if (i < R[i]) swap(a[i], a[R[i]]);
    for (int half = 1; half < limit; half <<= 1) {
        const int step = half << 1;
        const int root = qpow(G, (mod - 1) / step);  // primitive step-th root of unity
        for (int start = 0; start < limit; start += step) {
            int tw = 1;  // current twiddle factor root^k
            for (int k = 0; k < half; k++) {
                int lo = a[start + k];
                int hi = 1ll * tw * a[start + k + half] % mod;
                a[start + k] = (lo + hi) % mod;
                a[start + k + half] = (lo - hi + mod) % mod;
                tw = 1ll * tw * root % mod;
            }
        }
    }
    if (f == 1) return;
    const int inv_len = qpow(limit, mod - 2);
    reverse(a + 1, a + limit);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv_len % mod;
}
// Multiplies the polynomials stored in a and b, whose product has degree `deg`,
// leaving the product in a[0..deg]. Expects both buffers to be zero beyond their
// own degree up to the padded length; b is left in transformed form.
void poly_mul(int* a, int* b, int deg) {
    L = 0, limit = 1;
    while (limit <= deg) {
        limit <<= 1;
        L++;
    }
    // R[0] is always 0; starting at 1 avoids shifting by L - 1 == -1 when limit == 1.
    for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Transform the buffers that were passed in, not the globals A and B.
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
}
void solve(int l,int r) { //分治
if (l == r) {
//递推式子在这写
if (l) {
}
return;
}
int mid = (l + r) >> 1;
solve(l, mid);
L = 0, limit = 1;
while (limit <= (mid - l + r - l))L++, limit <<= 1;
for (int i = 0; i < limit; i++)A[i] = B[i] = 0;
for (int i = l; i <= mid; i++)A[i - l] = f[i];
for (int i = 1; i <= r - l; i++)B[i] = g[i];
poly_mul(A, B, mid - l + r - l);
for (int i = mid + 1; i <= r; i++)f[i] = (f[i] + A[i - l]) % mod;
solve(mid + 1, r);
}
// Reads n and g[1..n-1], computes f[0..n-1] with f[0] = 1, and prints f.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n;
    for (int i = 1; i < n; i++) cin >> g[i];
    f[0] = 1;
    solve(0, n - 1);
    for (int i = 0; i < n; i++) cout << f[i] << ' ';
}
2022牛客多校 Falfa with Substring (容斥+NTT)
https://ac.nowcoder.com/acm/contest/33187/E
题意
$F_{n,k}$：长度为 $n$ 的字符串（字符集为 26 个字母）中 "bit" 出现次数恰为 $k$ 的串的数量
计算 $F_{n,0},F_{n,1},\dots,F_{n,n}$，答案对 998244353 取模
$1\le n\le 10^6$
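思路
记 $G_i=\binom{n-2i}{i}\cdot 26^{\,n-3i}$：先钦定 $i$ 个 "bit" 的位置（把每个 "bit" 看成一个整体，共 $n-2i$ 个单元，从中选 $i$ 个），其余 $n-3i$ 个位置任意填。于是 $G_i=\sum_{k\ge i}\binom{k}{i}F_{n,k}$，由二项式反演得
$$F_{n,k}=\sum_{i\ge k}(-1)^{i-k}\binom{i}{k}G_i$$
即 $k!\,F_{n,k}=\sum_{i}\left(i!\,G_i\right)\cdot\frac{(-1)^{i-k}}{(i-k)!}$。令 $A_i=i!\,G_i$，$B_j=\frac{(-1)^{n-j}}{(n-j)!}$，则 $k!\,F_{n,k}=(A*B)_{n+k}$，做一次 NTT 即可。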
代码
#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
// Every `int` below is really 64-bit, so products of two residues never overflow.
#define int ll
using namespace std;
const int maxn = 4e6 + 5;          // buffer capacity
const int G = 3, mod = 998244353;  // primitive root and NTT-friendly prime
int n;                             // string length
int L, limit = 1, R[maxn];         // log2(limit), transform length, bit-reversal table
int A[maxn], B[maxn];              // NTT work buffers
int F[maxn];
int fac[maxn], inv[maxn];          // factorials and inverse factorials
// Binary exponentiation: a^b mod `mod` for b >= 0. Relies on `#define int ll`
// earlier in this program, so res * a and a * a are 64-bit products.
int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// In-place iterative NTT of a[0..limit-1] modulo `mod` (uses globals limit and R).
// f == 1: forward transform; any other f: inverse transform including the 1/limit scaling.
void NTT(int* a, int f) {
    for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
    // i: half-length of the butterflies merged in this pass.
    for (int i = 1; i < limit; i <<= 1) {
        int gn = qpow(G, (mod - 1) / (i << 1));  // primitive (2i)-th root of unity
        for (int j = 0; j < limit; j += (i << 1)) {
            int g = 1;
            for (int k = 0; k < i; k++, g = 1ll * g * gn % mod) {
                int x = a[j + k], y = 1ll * g * a[j + k + i] % mod;
                a[j + k] = (x + y) % mod; a[j + k + i] = (x - y + mod) % mod;
            }
        }
    }
    if (f == 1) return;
    // Named inv_len so it no longer shadows the global inverse-factorial array inv[].
    int inv_len = qpow(limit, mod - 2);
    reverse(a + 1, a + limit);
    // Scale the whole transform rather than only indices 0..2n, so the routine
    // no longer depends on the global n.
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv_len % mod;
}
// Multiplies the polynomials stored in a and b, whose product has degree `deg`,
// leaving the product in a[0..deg]. Expects both buffers to be zero beyond their
// own degree up to the padded length; b is left in transformed form.
void poly_mul(int* a, int* b, int deg) {
    L = 0, limit = 1;
    while (limit <= deg) {
        limit <<= 1;
        L++;
    }
    // R[0] is always 0; starting at 1 avoids shifting by L - 1 == -1 when limit == 1.
    for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Transform the buffers that were passed in, not the globals A and B.
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
}
// Binomial coefficient C(n, m) = n! / (m! (n - m)!) mod `mod`, or 0 when m > n.
// Note the argument order: the lower index m comes first. Uses fac[] and the
// inverse factorials stored in inv[].
int C(int m, int n)
{
    return m > n ? 0 : fac[n] * inv[m] % mod * inv[n - m] % mod;
}
// Reads n and prints F(n,0..n): the number of length-n strings containing
// exactly k occurrences of "bit", modulo 998244353 (binomial inversion + NTT).
signed main() {
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n;
    // Factorials, then inverse factorials via a single modular exponentiation and
    // (i-1)!^{-1} = i!^{-1} * i, instead of n + 1 separate exponentiations.
    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
    inv[n] = qpow(fac[n], mod - 2);
    for (int i = n; i >= 1; i--) inv[i - 1] = inv[i] * i % mod;
    // A[i] = i! * C(n - 2i, i) * 26^(n - 3i): i! times the count of strings with i marked "bit"s.
    for (int i = 0; 3 * i <= n; i++) A[i] = fac[i] * C(i, n - 2 * i) % mod * qpow(26, n - 3 * i) % mod;
    // B[j] = (-1)^(n-j) / (n-j)!, so that (A*B)[n+k] = k! * F(n,k).
    for (int i = 0; i <= n; i++) B[i] = ((((n - i) % 2 == 0) ? 1 : -1) * inv[n - i] % mod + mod) % mod;
    poly_mul(A, B, n + n);
    for (int i = 0; i <= n; i++) cout << A[n + i] * inv[i] % mod << ' ';
    cout << '\n';
    return 0;
}
2022杭电多校3 Equipment Upgrade(期望dp+分治NTT)
https://acm.hdu.edu.cn/showproblem.php?pid=7162
题意
要将武器从0级升到n级
当前等级为 $i\ (0\le i<n)$ 时，升级花费 $c_i$，升到 $i+1$ 级的概率为 $p_i$，降到 $i-j$ 级的概率为 $(1-p_i)\frac{w_j}{\sum_{k=1}^{i}w_k}$
给定c、p、w数组
求升级到n的期望花费
思路
记 $dp[i]$ 为从 0 级升到 $i$ 级的期望花费，$prew[i]$ 为 $w$ 的前缀和
（从 $i$ 级升到 $i+1$ 级：先付出 $c[i]$；若失败降到 $i-j$ 级，则还需花费 $dp[i+1]-dp[i-j]$ 重新升到 $i+1$ 级）
$$dp[i+1]=dp[i]+c[i]+(1-p[i])\frac{\sum_{j=1}^{i}\left(dp[i+1]-dp[i-j]\right)w[j]}{\sum_{k=1}^{i}w_k}$$
化简（注意 $\sum_{j=1}^{i}w[j]=prew[i]$，含 $dp[i+1]$ 的项合并到左边）得到：
$$p[i]\cdot dp[i+1]=dp[i]+c[i]-\frac{1-p[i]}{prew[i]}\sum_{j=1}^{i}dp[i-j]\,w[j]$$
两边除以 $p[i]$ 即得 $dp[i+1]$。
其中 $\sum_{j=1}^{i}dp[i-j]\,w[j]$ 是卷积形式
分治NTT加速计算即可
代码
#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
// Every `int` below is really 64-bit, so products of two residues never overflow.
#define int ll
using namespace std;
const int maxn = 4e6 + 5;          // buffer capacity
const int G = 3, mod = 998244353;  // primitive root and NTT-friendly prime
int n;                             // target level
int L, limit = 1, R[maxn];         // log2(limit), transform length, bit-reversal table
int A[maxn], B[maxn];              // NTT work buffers
int f[maxn];                       // f[i] = sum_{j=1}^{i} res[i-j] * w[j], accumulated by solve()
int w[maxn], prew[maxn];           // weights w[1..n-1] and their prefix sums
int p[maxn], c[maxn];              // per-level success probability (as a residue) and cost
int res[maxn];                     // res[i]: expected cost of reaching level i from level 0
// Binary exponentiation: a^b mod `mod` for b >= 0. Relies on `#define int ll`
// earlier in this program, so res * a and a * a are 64-bit products.
int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// Modular inverse by Fermat's little theorem (mod is prime): x^(mod-2) mod `mod`.
int inv(int x) {
    const int exponent = mod - 2;
    return qpow(x, exponent);
}
// In-place iterative NTT of a[0..limit-1] modulo `mod` (uses globals limit and R).
// f == 1: forward transform; any other f: inverse transform including division by limit.
void NTT(int* a, int f) {
    for (int i = 0; i < limit; i++)
        if (i < R[i]) swap(a[i], a[R[i]]);
    for (int half = 1; half < limit; half <<= 1) {
        const int step = half << 1;
        const int root = qpow(G, (mod - 1) / step);  // primitive step-th root of unity
        for (int start = 0; start < limit; start += step) {
            int tw = 1;  // current twiddle factor root^k
            for (int k = 0; k < half; k++) {
                int lo = a[start + k];
                int hi = 1ll * tw * a[start + k + half] % mod;
                a[start + k] = (lo + hi) % mod;
                a[start + k + half] = (lo - hi + mod) % mod;
                tw = 1ll * tw * root % mod;
            }
        }
    }
    if (f == 1) return;
    // inv_len: 1/limit (named so it does not shadow the inv() function).
    const int inv_len = qpow(limit, mod - 2);
    reverse(a + 1, a + limit);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv_len % mod;
}
// Multiplies the polynomials stored in a and b, whose product has degree `deg`,
// leaving the product in a[0..deg]. Expects both buffers to be zero beyond their
// own degree up to the padded length; b is left in transformed form.
void poly_mul(int* a, int* b, int deg) {
    L = 0, limit = 1;
    while (limit <= deg) {
        limit <<= 1;
        L++;
    }
    // R[0] is always 0; starting at 1 avoids shifting by L - 1 == -1 when limit == 1.
    for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Transform the buffers that were passed in, not the globals A and B.
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
}
void solve(int l, int r) { //分治
if (l == r) {
//递推式子在这写
if (l) {
res[l + 1] = ((res[l] + c[l] - ((1 - p[l]) % mod + mod) * inv(prew[l]) % mod * f[l] % mod) % mod + mod) % mod * inv(p[l]) % mod;
}
return;
}
int mid = (l + r) >> 1;
solve(l, mid);
L = 0, limit = 1;
while (limit <= (mid - l + r - l))L++, limit <<= 1;
for (int i = 0; i < limit; i++)A[i] = B[i] = 0;
for (int i = l; i <= mid; i++)A[i - l] = res[i];
for (int i = 1; i <= r - l; i++)B[i] = w[i];
poly_mul(A, B, mid - l + r - l);
for (int i = mid + 1; i <= r; i++)f[i] = (f[i] + A[i - l]) % mod;
solve(mid + 1, r);
}
// Multi-test driver: for each case reads n, then (p_i, c_i) for levels 0..n-1
// (p_i in percent) and w_1..w_{n-1}; prints the expected cost to reach level n.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    const int inv100 = inv(100);  // 1/100 mod `mod`, used to turn percentages into probabilities
    int t;
    cin >> t;
    while (t--) {
        cin >> n;
        for (int i = 0; i < n; i++) {
            cin >> p[i] >> c[i];
            p[i] = p[i] * inv100 % mod;
        }
        for (int i = 1; i < n; i++) cin >> w[i];
        for (int i = 1; i < n; i++) prew[i] = (prew[i - 1] + w[i]) % mod;
        for (int i = 0; i <= n; i++) f[i] = 0;
        res[0] = 0;
        res[1] = c[0];
        solve(0, n - 1);
        cout << res[n] << '\n';
    }
}
2022广东省赛M (生成函数+分治NTT)
https://pintia.cn/problem-sets/1534086341544497152/problems/1534088931057451020
题意
思路
由均值不等式得：
$$1=\sum_{i=1}^{k}\frac{x_i^2}{b_i^2}\ge k\sqrt[k]{\frac{x_1^2x_2^2\cdots x_k^2}{b_1^2b_2^2\cdots b_k^2}}$$
化简得到：
$$\frac{\prod_{i=1}^{k}b_i}{k^{k/2}}\ge\prod_{i=1}^{k}x_i$$
即求m个数中选k个数,这k个数乘积的期望
求出和再除以C(m,k)就是期望
构造生成函数 $F(x)=\prod_{i=1}^{m}(1+b_ix)$，其 $x^k$ 的系数即为所有选法中 $k$ 个数乘积之和
这可以看作 $m$ 个多项式相乘，第 $i$ 个多项式的系数为 $\{1,b_i\}$
分治+NTT求解
复杂度 $O(m\log^2 m)$
代码
#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define ll long long
// Every `int` below is really 64-bit, so products of two residues never overflow.
#define int ll
using namespace std;
const int maxn = 1e5 + 5;   // capacity of the value / factorial tables
const int mod = 998244353;  // NTT-friendly prime (primitive root 3)
// Binary exponentiation: a^b mod `mod` for b >= 0. Relies on `#define int ll`
// earlier in this program, so res * a and a * a are 64-bit products.
int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
int fac[maxn];  // fac[i] = i! mod `mod`
int inv[maxn];  // inv[i] = (i!)^{-1} mod `mod`
// Fills fac[] and inv[] for indices 0..maxn-1.
void init() {
    fac[0] = 1;
    for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % mod;
    // One Fermat inversion at the top, then walk down using (i-1)!^{-1} = i!^{-1} * i.
    inv[maxn - 1] = qpow(fac[maxn - 1], mod - 2);
    for (int i = maxn - 1; i > 0; i--) inv[i - 1] = inv[i] * i % mod;
}
// Binomial coefficient C(n, m) mod `mod` (note: lower index m first).
// Returns 0 when m lies outside [0, n], matching the combinatorial definition;
// the previous -1 is not a valid count and silently corrupts any modular
// expression it feeds (e.g. qpow(C(...), mod - 2)).
int C(int m, int n)
{
    if (m < 0 || m > n) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int r[3 * maxn];  // bit-reversal permutation for the current length; filled in by poly::mul
// In-place iterative NTT of a[0..lim-1] modulo `mod` with primitive root 3.
// b == 0: forward transform; b != 0: inverse transform including the 1/lim scaling.
// Expects a.size() >= lim.
void ntt(vector<int>& a, int lim, int b = 0)
{
    int i, j, k, l, * p, * q, u, v, w;
    for (i = 0; i < lim; ++i)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    // i: half-length of the butterflies, l = 2i: block length, u: primitive l-th root.
    for (i = 1; i < lim; i = l)
        for (j = 0, l = i * 2, u = qpow(3, (mod - 1) / l); j < lim; j += l)
            for (k = 0, v = 1, p = &a[j], q = &p[i]; k < i; ++k, v = v * 1ll * u % mod)
            {
                w = v * 1ll * q[k] % mod;
                // Lower output p[k] - w and upper output p[k] + w, both reduced into [0, mod).
                q[k] = p[k] < w ? p[k] - w + mod : p[k] - w;
                p[k] += w;
                if (p[k] >= mod) p[k] -= mod;
            }
    if (b) {
        // Use iterators: &a[lim] subscripts one past the end of the vector (and &a[1]
        // does so when lim == 1), which is undefined behaviour for operator[].
        reverse(a.begin() + 1, a.begin() + lim);
        for (i = 0, u = qpow(lim, mod - 2); i < lim; ++i)
            a[i] = a[i] * 1ll * u % mod;
    }
}
// Polynomial of degree l with coefficients a[0..l] modulo `mod`.
struct poly
{
    int l;          // degree
    vector<int> a;  // coefficients, lowest degree first
    poly() : l(0) {}
    //poly(int x){for(l=x;~x;--x)a.push_back(1);}
    //bool operator<(poly x)const{return l>x.l;}
    // *this = *this * x. x is taken by value because its buffer gets resized and transformed.
    void mul(poly x) {
        l += x.l;
        int lim = 1;
        while (lim <= l) lim *= 2;
        // Bit-reversal table for length lim; r[0] is never written and stays 0.
        for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
        a.resize(lim);
        x.a.resize(lim);
        ntt(a, lim);
        ntt(x.a, lim);
        for (int i = 0; i < lim; ++i) a[i] = a[i] * 1ll * x.a[i] % mod;
        ntt(a, lim, 1);
    }
};
int a[maxn];  // a[1..n]: the input values b_i
// Returns prod_{i=l}^{r} (1 + a[i] x), split in halves so that each level
// multiplies polynomials of balanced size. An empty range (l > r) yields the
// constant polynomial 1 instead of recursing forever (binary(1, 0) used to loop).
poly binary(int l, int r) {
    if (l > r) {
        poly p;
        p.l = 0;
        p.a = { 1 };
        return p;
    }
    if (l == r) {
        poly p;
        p.l = 1;
        p.a = { 1,a[l] };
        return p;
    }
    int mid = (l + r) >> 1;
    poly p1 = binary(l, mid), p2 = binary(mid + 1, r);
    p1.mul(p2);
    return p1;
}
// Reads n, k and b_1..b_n, then prints e_k / C(n, k) / k^(k/2) mod `mod`, where
// e_k is the coefficient of x^k in prod (1 + b_i x), i.e. the sum over all
// k-subsets of the product of their values.
void solve() {
    init();
    int n, k; cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> a[i];
    // No k-subset exists when k is outside [0, n]: the sum is empty, so print 0
    // (previously res.a[k] could be read past the end of the coefficient vector).
    if (k < 0 || k > n) {
        cout << 0;
        return;
    }
    poly res = binary(1, n);
    // NOTE(review): k / 2 is integer division; for odd k this drops a factor of
    // sqrt(k) from k^{k/2} — confirm against the problem statement.
    cout << res.a[k] * qpow(C(k, n), mod - 2) % mod * qpow(qpow(k, k / 2), mod - 2) % mod;
}
// Program entry: fast I/O, then a single test case.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    solve();
}