多项式FFT、NTT


【模板】多项式乘法(FFT)

https://www.luogu.com.cn/problem/P3803

//luogu P3803 输入两个多项式的系数(从低次到高次)
#include<bits/stdc++.h>
const int N = 4e6 + 10;
const double pi = acos(-1.0);
using namespace std;
typedef long long ll;
// Minimal complex-number type used by the FFT below.
// The operators are const member functions so they also work on const
// operands and temporaries.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) { x = xx, y = yy; }
    Complex operator +(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator -(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i a.y) = (x a.x - y a.y) + i (x a.y + y a.x)
    Complex operator *(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
int r[N];       // bit-reversal permutation of [0, n), filled by the caller
Complex a[N];   // transform buffer (the caller stores its input here)
Complex b[N];   // second transform buffer

// In-place iterative radix-2 FFT over a[0..n).
// n must be a power of two and r[] must already hold the bit-reversal
// permutation of [0, n). type = 1 computes the forward transform, type = -1
// the inverse one. The inverse result is NOT divided by n here; the caller
// does that.
void fft(Complex* a, int n, int type) {
    // Move every element to its bit-reversed position (each pair swapped once).
    for (int i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);

    for (int half = 1; half < n; half *= 2) {
        // Root of unity for blocks of length 2 * half.
        const double ang = pi / half;
        Complex step(cos(ang), type * sin(ang));
        for (int start = 0; start < n; start += 2 * half) {
            Complex tw(1, 0);
            for (int k = 0; k < half; k++) {
                Complex lo = a[start + k];
                Complex hi = tw * a[start + half + k];
                a[start + k] = lo + hi;
                a[start + half + k] = lo - hi;
                tw = tw * step;
            }
        }
    }
}
int main() {
    int n, m, x;
    cin >> n >> m;

    for (int i = 0; i <= n; i++) cin >> x, a[i].x = x, a[i].y = 0;
    for (int i = 0; i <= m; i++) cin >> x, b[i].x = x, b[i].y = 0;
    int len = 1, l = 0;
    while (len <= (n + m)) len *= 2, l++;
    for (int i = 0; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    fft(a, len, 1);
    fft(b, len, 1);
    for (int i = 0; i <= len; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, len, -1);
    for (int i = 0; i <= (n + m); i++) printf("%d ", int(a[i].x / (len)+0.5));
    return 0;
}

【模板】A*B Problem 升级版(FFT)

https://www.luogu.com.cn/problem/P1919
给你两个正整数 $a,b$，求 $a \times b$。
$1 \le a,b \le 10^{1000000}$

#include<bits/stdc++.h>
const int N = 4e6 + 10;
const double pi = acos(-1.0);
using namespace std;
typedef long long ll;
// Minimal complex-number type used by the FFT below.
// The operators are const member functions so they also work on const
// operands and temporaries.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) { x = xx, y = yy; }
    Complex operator +(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator -(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i a.y) = (x a.x - y a.y) + i (x a.y + y a.x)
    Complex operator *(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
int r[N];       // bit-reversal permutation of [0, n), filled by the caller
Complex a[N];   // transform buffer (the caller stores its input here)
Complex b[N];   // second transform buffer

// In-place iterative radix-2 FFT over a[0..n).
// n must be a power of two and r[] must already hold the bit-reversal
// permutation of [0, n). type = 1 computes the forward transform, type = -1
// the inverse one. The inverse result is NOT divided by n here; the caller
// does that.
void fft(Complex* a, int n, int type) {
    // Move every element to its bit-reversed position (each pair swapped once).
    for (int i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);

    for (int half = 1; half < n; half *= 2) {
        // Root of unity for blocks of length 2 * half.
        const double ang = pi / half;
        Complex step(cos(ang), type * sin(ang));
        for (int start = 0; start < n; start += 2 * half) {
            Complex tw(1, 0);
            for (int k = 0; k < half; k++) {
                Complex lo = a[start + k];
                Complex hi = tw * a[start + half + k];
                a[start + k] = lo + hi;
                a[start + half + k] = lo - hi;
                tw = tw * step;
            }
        }
    }
}
int ans[2000005];
int main() {
    string s, t;
    cin >> s >> t;
    int n = s.size() - 1, m = t.size() - 1;
    a[0].x = 0, b[0].x = 0, a[0].y = 0, b[0].y = 0;
    for (int i = 0; i <= n; i++) a[i].x = s[n - i] - '0', a[i].y = 0;
    for (int i = 0; i <= m; i++) b[i].x = t[m - i] - '0', b[i].y = 0;

    int len = 1, l = 0;
    while (len <= (n + m)) len *= 2, l++;
    for (int i = 0; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    fft(a, len, 1);
    fft(b, len, 1);
    for (int i = 0; i <= len; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, len, -1);

    int carry = 0;

    for (int i = 0; i <= (n + m); i++) {
        carry /= 10;
        carry += int(a[i].x / len + 0.5);
        ans[i] = carry % 10;
    }
    int id = n + m;
    while (carry) {
        carry /= 10;
        ans[++id] = carry;
    }
    while (ans[id] == 0)id--;
    for (int i = id; i >= 0; i--)cout << ans[i];
    return 0;
}

3-idiots(FFT)

题意
给定n个数字,求任取其中3个数,能组成三角形的概率
$3 \le n \le 10^5$

思路
处理出 $sum[i]$：两个数之和为 $i$ 的方案数。
设 $d[i]$ 为值为 $i$ 的数字的个数，则 $sum[k]=\sum_{i=1}^{k-1}d[i]\cdot d[k-i]$，即 $sum$ 是 $d$ 与自身的卷积。
用 FFT 可求得 $sum[i]$。
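重复部分:
每个数会和自己配对（$2x=x+x$），并且每对数被算了 2 次（$x+y=y+x$）。
因此先对每个 $x_i$ 执行 sum[x[i]*2]--，再令 sum[i]/=2。
最后，三个数不能组成三角形当且仅当最大的数不小于另外两数之和。设 $cnt[i]$ 为不小于 $i$ 的数的个数，枚举较小两数之和 $i$，答案为 $1-\dfrac{\sum_i sum[i]\cdot cnt[i]}{\binom{n}{3}}$。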

代码

#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
const int N = 4e5 + 10;
const double pi = acos(-1.0);
using namespace std;
typedef long long ll;
// Minimal complex-number type used by the FFT below.
// The operators are const member functions so they also work on const
// operands and temporaries.
struct Complex {
    double x, y;  // real part, imaginary part
    Complex(double xx = 0, double yy = 0) { x = xx, y = yy; }
    Complex operator +(const Complex& a) const { return Complex(x + a.x, y + a.y); }
    Complex operator -(const Complex& a) const { return Complex(x - a.x, y - a.y); }
    // (x + iy)(a.x + i a.y) = (x a.x - y a.y) + i (x a.y + y a.x)
    Complex operator *(const Complex& a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); }
};
int r[N];       // bit-reversal permutation of [0, n), filled by the caller
Complex a[N];   // transform buffer (the caller stores its input here)
Complex b[N];   // second transform buffer

// In-place iterative radix-2 FFT over a[0..n).
// n must be a power of two and r[] must already hold the bit-reversal
// permutation of [0, n). type = 1 computes the forward transform, type = -1
// the inverse one. The inverse result is NOT divided by n here; the caller
// does that.
void fft(Complex* a, int n, int type) {
    // Move every element to its bit-reversed position (each pair swapped once).
    for (int i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);

    for (int half = 1; half < n; half *= 2) {
        // Root of unity for blocks of length 2 * half.
        const double ang = pi / half;
        Complex step(cos(ang), type * sin(ang));
        for (int start = 0; start < n; start += 2 * half) {
            Complex tw(1, 0);
            for (int k = 0; k < half; k++) {
                Complex lo = a[start + k];
                Complex hi = tw * a[start + half + k];
                a[start + k] = lo + hi;
                a[start + half + k] = lo - hi;
                tw = tw * step;
            }
        }
    }
}
ll x[N];    // values of the current test case, 1-indexed
ll sum[N];  // sum[s]: number of unordered pairs of distinct elements whose values add up to s
ll cnt[N];  // cnt[v]: number of elements with value >= v

// HDU "3-idiots": probability that three elements chosen at random can be the
// side lengths of a triangle.
void solve() {
    int tot; cin >> tot;
    ll n = 0;  // largest value

    // a[v].x = number of elements equal to v.
    for (int i = 1; i <= tot; i++) cin >> x[i], a[x[i]].x++, a[x[i]].y = 0, n = max(n, x[i]);
    for (int i = n; i >= 0; i--) cnt[i] = cnt[i + 1] + a[i].x;
    int len = 1, l = 0;
    while (len <= (n + n)) len *= 2, l++;
    for (int i = 0; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));

    // Self-convolution of the value histogram: ordered pairs by total value.
    fft(a, len, 1);
    for (int i = 0; i < len; i++) {
        a[i] = a[i] * a[i];
    }
    fft(a, len, -1);
    // Pair counts can exceed INT_MAX (e.g. 1e5 equal values give ~1e10
    // ordered pairs), so round straight to long long.
    for (int i = 0; i <= (n + n); i++) sum[i] = std::llround(a[i].x / len);
    // Remove each element paired with itself, then halve because every
    // unordered pair {i, j} was counted as both (i, j) and (j, i).
    for (int i = 1; i <= tot; i++) sum[x[i] * 2]--;
    for (int i = 0; i <= 2 * n; i++) sum[i] /= 2;
    ll ans = 1ll * tot * (tot - 1) * (tot - 2) / 6;  // C(tot, 3)
    ll res = ans;

    // A triple fails iff its largest value >= sum of the other two: for each
    // pair sum i, every element with value >= i completes a failing triple.
    for (int i = 1; i <= n; i++) {
        res -= sum[i] * cnt[i];
    }
    printf("%.7lf\n", 1.0 * res / ans);

    // Reset the global buffers for the next test case.
    for (int i = 0; i <= len; i++)a[i].x = 0, a[i].y = 0, cnt[i] = 0, sum[i] = 0;

    return;
}
// Reads the number of test cases and answers each with solve().
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int cases;
    cin >> cases;
    while (cases--) {
        solve();
    }
}

【模板】NTT

https://www.luogu.com.cn/problem/P3803

代码

// NTT模板 模数只能是998244353
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 4000005;          // capacity of every global array
const int G = 3, mod = 998244353;  // primitive root and NTT prime
int n, m, L, limit = 1, R[maxn];   // input degrees, log2(limit), transform size, bit-reversal table
int A[maxn], B[maxn];              // coefficient arrays; A receives the product

// Modular exponentiation: a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
	long long result = 1, base = a;
	while (b) {
		if (b & 1) result = result * base % mod;
		base = base * base % mod;
		b >>= 1;
	}
	return (int)result;
}
// In-place NTT of a[0..limit) modulo `mod` (primitive root G).
// Requires the globals limit (a power of two) and R[] (bit-reversal table).
// f == 1: forward transform. Any other value: inverse transform, including
// the division by limit, so a[] holds exact coefficients afterwards.
void NTT(int* a, int f) {
	for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
	for (int i = 1; i < limit; i <<= 1) {
		int gn = qpow(G, (mod - 1) / (i << 1));  // primitive (2i)-th root of unity
		for (int j = 0; j < limit; j += (i << 1)) {
			int g = 1;
			for (int k = 0; k < i; k++, g = 1ll * g * gn % mod) {
				int x = a[j + k], y = 1ll * g * a[j + k + i] % mod;
				a[j + k] = (x + y) % mod; a[j + k + i] = (x - y + mod) % mod;
			}
		}
	}
	if (f == 1) return;
	// Inverse = forward transform with a[1..limit) reversed, then scaled by
	// 1/limit. Scale the whole transform, not just the first n + m + 1
	// entries (globals from main), so the routine is correct for any caller.
	int inv = qpow(limit, mod - 2); reverse(a + 1, a + limit);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv % mod;
}
// Multiplies the polynomials stored in a[] and b[]. The product overwrites
// a[]; b[] is left in NTT (point-value) form. deg is the degree of the
// product. Entries of a[] and b[] beyond their degrees are assumed to be
// zero up to the transform size.
void poly_mul(int* a, int* b, int deg) {
	L = 0, limit = 1;
	while (limit <= deg) {
		limit <<= 1;
		L++;
	}
	// R[0] is always 0. Starting at 1 also avoids a shift by L - 1 = -1 when deg == 0.
	R[0] = 0;
	for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
	// Transform the arguments (not the globals A/B) so any arrays can be passed.
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
	NTT(a, -1);
}
// Luogu P3803 via NTT: read degrees n, m and both coefficient lists (lowest
// degree first), then print the n + m + 1 coefficients of the product.
int main() {
	cin >> n >> m;
	for (int i = 0; i <= n; ++i) cin >> A[i];
	for (int j = 0; j <= m; ++j) cin >> B[j];
	const int deg = n + m;
	poly_mul(A, B, deg);
	for (int i = 0; i <= deg; ++i) cout << A[i] << ' ';
	cout << endl;
	return 0;
}
//模板2
#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define ll long long
#define int ll
using namespace std;
const int maxn = 3e6 + 5;   // size of the bit-reversal table r[]
const int mod = 998244353;  // NTT prime; 3 is used as its primitive root
// a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
    int acc = 1;
    for (; b; b >>= 1) {
        if (b & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
int r[maxn];  // bit-reversal permutation, filled in by poly::mul

// In-place NTT of a[0..lim) modulo `mod` with primitive root 3.
// lim must be a power of two, a.size() >= lim, and r[] must hold the
// bit-reversal permutation for lim. b == 0: forward transform; b != 0:
// inverse transform, including the division by lim.
void ntt(vector<int>& a, int lim, int b = 0)
{
	int i, j, k, l, * p, * q, u, v, w;
	for (i = 0; i < lim; ++i)
		if (i < r[i])
			swap(a[i], a[r[i]]);
	for (i = 1; i < lim; i = l)
		for (j = 0, l = i * 2, u = qpow(3, (mod - 1) / l); j < lim; j += l)
			for (k = 0, v = 1, p = &a[j], q = &p[i]; k < i; ++k, v = v * 1ll * u % mod)
			{
				// Butterfly (p, q) <- (p + v*q, p - v*q), kept in [0, mod).
				w = v * 1ll * q[k] % mod;
				q[k] = p[k] < w ? p[k] - w + mod : p[k] - w;
				p[k] += w;
				if (p[k] >= mod) p[k] -= mod;
			}
	if (b) {
		// Inverse = forward transform with a[1..lim) reversed, scaled by 1/lim.
		// Use iterators: &a[lim] indexes one past the end of a vector of
		// size lim (and &a[1] does too when lim == 1).
		reverse(a.begin() + 1, a.begin() + lim);
		u = qpow(lim, mod - 2);
		for (i = 0; i < lim; ++i)
			a[i] = a[i] * 1ll * u % mod;
	}
}
// Polynomial of degree l with coefficients a[0..] (lowest degree first).
struct poly
{
	int l;          // degree
	vector<int> a;  // coefficients; after mul() zero-padded to the transform size
	poly() : l(0) {}
	// Replaces *this with (*this) * x. x is taken by value because its
	// coefficient vector is resized and transformed in place.
	void mul(poly x) {
		l += x.l;
		int lim = 1;
		while (lim <= l) lim *= 2;  // smallest power of two > product degree
		for (int i = 1; i < lim; ++i)
			r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
		a.resize(lim);
		x.a.resize(lim);
		ntt(a, lim);
		ntt(x.a, lim);
		for (int i = 0; i < lim; ++i)
			a[i] = a[i] * 1ll * x.a[i] % mod;
		ntt(a, lim, 1);
	}
};

void solve() {
	int n, m; cin >> n >> m;
	poly p1, p2;
	p1.l = n, p2.l = m;
	int x;
	for (int i = 0; i <= n; i++) {
		cin >> x; p1.a.push_back(x);
	}
	for (int i = 0; i <= m; i++) {
		cin >> x; p2.a.push_back(x);
	}
	p1.mul(p2);
	for (int i = 0; i <= n + m; i++)cout << p1.a[i] << ' ';
	
}
// Single test case with fast I/O.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
}
/*

*/

【模板】分治FFT(NTT)

https://www.luogu.com.cn/problem/P4721

代码

//给定序列g[1..n-1],求序列f[0...n-1]
//f[i]=\sum^i_{j=1}f[i-j]g[j]  边界f[0]=1
//答案对998244353取模
#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define int ll
using namespace std;
const int maxn = 4e6 + 5;                  // capacity of every global array
const int G = 3, mod = 998244353;          // primitive root and NTT prime
int n, L, limit = 1, R[maxn];              // sequence length, log2(limit), transform size, bit-reversal table
int A[maxn], B[maxn], f[maxn], g[maxn];    // NTT buffers, result sequence f, input sequence g
// a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
	int acc = 1;
	for (; b; b >>= 1) {
		if (b & 1) acc = acc * a % mod;
		a = a * a % mod;
	}
	return acc;
}

// In-place NTT of a[0..limit) modulo `mod` using primitive root G.
// Relies on the globals limit (a power of two) and R[] (bit-reversal table).
// f == 1: forward transform; any other value: inverse transform, already
// divided by limit.
void NTT(int* a, int f) {
	for (int i = 0; i < limit; i++)
		if (i < R[i]) swap(a[i], a[R[i]]);
	for (int half = 1; half < limit; half <<= 1) {
		const int len = half << 1;
		const int root = qpow(G, (mod - 1) / len);  // primitive len-th root of unity
		for (int start = 0; start < limit; start += len) {
			int tw = 1;
			for (int k = 0; k < half; k++) {
				const int u = a[start + k];
				const int v = 1ll * tw * a[start + k + half] % mod;
				a[start + k] = (u + v) % mod;
				a[start + k + half] = (u - v + mod) % mod;
				tw = 1ll * tw * root % mod;
			}
		}
	}
	if (f == 1) return;
	// Inverse: reverse a[1..limit), then multiply by 1/limit.
	reverse(a + 1, a + limit);
	const int scale = qpow(limit, mod - 2);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * scale % mod;
}
// Multiplies the polynomials stored in a[] and b[]. The product overwrites
// a[]; b[] is left in NTT (point-value) form. deg is the degree of the
// product. Entries of a[] and b[] beyond their degrees are assumed to be
// zero up to the transform size.
void poly_mul(int* a, int* b, int deg) {
	L = 0, limit = 1;
	while (limit <= deg) {
		limit <<= 1;
		L++;
	}
	// R[0] is always 0. Starting at 1 also avoids a shift by L - 1 = -1 when deg == 0.
	R[0] = 0;
	for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
	// Transform the arguments (not the globals A/B) so any arrays can be passed.
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
	NTT(a, -1);
}
// CDQ divide and conquer over f[l..r]. Once f[l..mid] is final, one NTT
// multiplication adds sum_{j=l..mid} f[j] * g[i - j] to every f[i] with
// i in (mid, r]; then the right half is processed.
void solve(int l, int r) {
	if (l == r) {
		// Leaf: f[l] is complete here. A recurrence that needs an extra
		// per-index step (for l > 0) would compute it in this branch.
		return;
	}
	const int mid = (l + r) >> 1;
	solve(l, mid);

	const int deg = (mid - l) + (r - l);  // degree of the product below
	L = 0, limit = 1;
	while (limit <= deg) L++, limit <<= 1;
	for (int i = 0; i < limit; i++) A[i] = B[i] = 0;
	for (int i = l; i <= mid; i++) A[i - l] = f[i];
	for (int d = 1; d <= r - l; d++) B[d] = g[d];
	poly_mul(A, B, deg);
	for (int i = mid + 1; i <= r; i++) f[i] = (f[i] + A[i - l]) % mod;

	solve(mid + 1, r);
}
// Luogu P4721: read n and g[1..n-1], then print f[0..n-1] with f[0] = 1.
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	cin >> n;
	for (int i = 1; i < n; i++) cin >> g[i];
	f[0] = 1;
	solve(0, n - 1);
	for (int i = 0; i < n; i++) cout << f[i] << ' ';
}

2022牛客多校 Falfa with Substring (容斥+NTT)

https://ac.nowcoder.com/acm/contest/33187/E

题意
$F_{n,k}$：长度为 $n$ 的串（字符集大小为 26）中 "bit" 恰好出现 $k$ 次的串的数量
计算 $F_{n,0},F_{n,1},\dots,F_{n,n}$，答案对 $998244353$ 取模
$1\le n\le 10^6$

思路
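（原文此处为图片，以下根据代码整理）
钦定 $i$ 个互不重叠的 "bit"，其余位置任意填，方案数为 $G_i=\binom{n-2i}{i}\cdot 26^{\,n-3i}$。
由二项式反演，$F_{n,k}=\sum_{i\ge k}(-1)^{i-k}\binom{i}{k}G_i=\dfrac{1}{k!}\sum_{i\ge k}\dfrac{(-1)^{i-k}}{(i-k)!}\cdot i!\,G_i$。
令 $A_i=i!\,G_i$，$B_j=\dfrac{(-1)^{n-j}}{(n-j)!}$，则 $F_{n,k}=\dfrac{1}{k!}[x^{n+k}](A\cdot B)$，一次 NTT 即可求出全部答案。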

代码

#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define int ll
using namespace std;
const int maxn = 4e6 + 5;                                   // capacity of every global array
const int G = 3, mod = 998244353;                           // primitive root and NTT prime
int n, L, limit = 1, R[maxn];                               // string length, log2(limit), transform size, bit-reversal table
int A[maxn], B[maxn], F[maxn], fac[maxn], inv[maxn];        // NTT buffers, factorials, inverse factorials
// a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
	int acc = 1;
	for (; b; b >>= 1) {
		if (b & 1) acc = acc * a % mod;
		a = a * a % mod;
	}
	return acc;
}

// In-place NTT of a[0..limit) modulo `mod` (primitive root G).
// Requires the globals limit (a power of two) and R[] (bit-reversal table).
// f == 1: forward transform. Any other value: inverse transform, including
// the division by limit.
void NTT(int* a, int f) {
	for (int i = 0; i < limit; i++) if (i < R[i]) swap(a[i], a[R[i]]);
	for (int i = 1; i < limit; i <<= 1) {
		int gn = qpow(G, (mod - 1) / (i << 1));  // primitive (2i)-th root of unity
		for (int j = 0; j < limit; j += (i << 1)) {
			int g = 1;
			for (int k = 0; k < i; k++, g = 1ll * g * gn % mod) {
				int x = a[j + k], y = 1ll * g * a[j + k + i] % mod;
				a[j + k] = (x + y) % mod; a[j + k + i] = (x - y + mod) % mod;
			}
		}
	}
	if (f == 1) return;
	// Inverse = forward transform with a[1..limit) reversed, then scaled by
	// 1/limit. Scale the whole transform, not only the first 2n + 1 entries
	// (global n), so the routine works for any product degree.
	// (Named inv_limit to avoid shadowing the global inv[] array.)
	int inv_limit = qpow(limit, mod - 2); reverse(a + 1, a + limit);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * inv_limit % mod;
}
// Multiplies the polynomials stored in a[] and b[]. The product overwrites
// a[]; b[] is left in NTT (point-value) form. deg is the degree of the
// product. Entries of a[] and b[] beyond their degrees are assumed to be
// zero up to the transform size.
void poly_mul(int* a, int* b, int deg) {
	L = 0, limit = 1;
	while (limit <= deg) {
		limit <<= 1;
		L++;
	}
	// R[0] is always 0. Starting at 1 also avoids a shift by L - 1 = -1 when deg == 0.
	R[0] = 0;
	for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
	// Transform the arguments (not the globals A/B) so any arrays can be passed.
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
	NTT(a, -1);
}
// Binomial coefficient C(n, m) = n! / (m! (n - m)!) mod `mod`. Note the
// argument order (m, n). Returns 0 unless 0 <= m <= n, which also keeps
// negative arguments from indexing fac[] / inv[] out of range.
// Requires fac[] and inv[] (inverse factorials) to be filled up to n.
int C(int m, int n)
{
	if (m < 0 || n < 0 || m > n) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
// Inclusion-exclusion: G_i = C(n - 2i, i) * 26^(n - 3i) (choose i
// non-overlapping "bit" positions, fill the rest freely). The exact counts
// F_k = sum_{i >= k} (-1)^(i-k) C(i, k) G_i come out of one convolution of
// A[i] = i! * G_i with B[j] = (-1)^(n-j) / (n-j)!: F_k = A[n + k] / k!.
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	cin >> n;

	// Factorials and inverse factorials up to n.
	fac[0] = 1;
	for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
	for (int i = 0; i <= n; i++) inv[i] = qpow(fac[i], mod - 2);

	for (int i = 0; 3 * i <= n; i++)
		A[i] = fac[i] * C(i, n - 2 * i) % mod * qpow(26, n - 3 * i) % mod;
	for (int j = 0; j <= n; j++) {
		const int sign = ((n - j) % 2 == 0) ? 1 : -1;
		B[j] = (sign * inv[n - j] % mod + mod) % mod;
	}

	poly_mul(A, B, n + n);
	for (int k = 0; k <= n; k++) cout << A[n + k] * inv[k] % mod << ' ';
	cout << endl;
	return 0;
}

2022杭电多校3 Equipment Upgrade(期望dp+分治NTT)

https://acm.hdu.edu.cn/showproblem.php?pid=7162
题意
要将武器从0级升到n级
当前等级为 $i\ (0\le i<n)$，升级花费 $c_i$，升到 $i+1$ 级的概率为 $p_i$，降到 $i-j$ 级的概率为 $(1-p_i)\dfrac{w_j}{\sum_{k=1}^{i}w_k}$
给定c、p、w数组
求升级到n的期望花费

思路
$dp[i]$ 为从 0 级升到 $i$ 级的期望花费，$prew[i]$ 为 $w$ 的前缀和
$dp[i+1]=dp[i]+c[i]+(1-p[i])\dfrac{\sum_{j=1}^{i}\left(dp[i+1]-dp[i-j]\right)w[j]}{prew[i]}$
由于 $\sum_{j=1}^{i}w[j]=prew[i]$，将含 $dp[i+1]$ 的项移到左边（其系数为 $1-(1-p[i])=p[i]$），化简得到:
$dp[i+1]=\dfrac{1}{p[i]}\left(dp[i]+c[i]-\dfrac{1-p[i]}{prew[i]}\sum_{j=1}^{i}dp[i-j]\,w[j]\right)$
其中 $\sum_{j=1}^{i}dp[i-j]\,w[j]$ 是卷积形式
分治NTT加速计算即可

代码

#include<bits/stdc++.h>
#define ll long long
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define int ll
using namespace std;
const int maxn = 4e6 + 5;                               // capacity of every global array
const int G = 3, mod = 998244353;                       // primitive root and NTT prime
int n, L, limit = 1, R[maxn];                           // target level, log2(limit), transform size, bit-reversal table
int A[maxn], B[maxn], f[maxn];                          // NTT buffers; f[i] accumulates sum_{j<i} res[j] * w[i-j]
int w[maxn], prew[maxn], p[maxn], c[maxn], res[maxn];   // weights, their prefix sums, probabilities, costs, DP values
// a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
    int acc = 1;
    for (; b; b >>= 1) {
        if (b & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
// Modular inverse via Fermat's little theorem (mod is prime).
int inv(int x) {
    return qpow(x, mod - 2);
}
// In-place NTT of a[0..limit) modulo `mod` using primitive root G.
// Relies on the globals limit (a power of two) and R[] (bit-reversal table).
// f == 1: forward transform; any other value: inverse transform, already
// divided by limit.
void NTT(int* a, int f) {
    for (int i = 0; i < limit; i++)
        if (i < R[i]) swap(a[i], a[R[i]]);
    for (int half = 1; half < limit; half <<= 1) {
        const int len = half << 1;
        const int root = qpow(G, (mod - 1) / len);  // primitive len-th root of unity
        for (int start = 0; start < limit; start += len) {
            int tw = 1;
            for (int k = 0; k < half; k++) {
                const int u = a[start + k];
                const int v = 1ll * tw * a[start + k + half] % mod;
                a[start + k] = (u + v) % mod;
                a[start + k + half] = (u - v + mod) % mod;
                tw = 1ll * tw * root % mod;
            }
        }
    }
    if (f == 1) return;
    // Inverse: reverse a[1..limit), then multiply by 1/limit.
    reverse(a + 1, a + limit);
    const int scale = qpow(limit, mod - 2);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * scale % mod;
}
// Multiplies the polynomials stored in a[] and b[]. The product overwrites
// a[]; b[] is left in NTT (point-value) form. deg is the degree of the
// product. Entries of a[] and b[] beyond their degrees are assumed to be
// zero up to the transform size.
void poly_mul(int* a, int* b, int deg) {
    L = 0, limit = 1;
    while (limit <= deg) {
        limit <<= 1;
        L++;
    }
    // R[0] is always 0. Starting at 1 also avoids a shift by L - 1 = -1 when deg == 0.
    R[0] = 0;
    for (int i = 1; i < limit; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    // Transform the arguments (not the globals A/B) so any arrays can be passed.
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, -1);
}
// CDQ divide and conquer over levels [l, r] of the expected-cost DP.
// f[i] accumulates sum_{j<i} res[j] * w[i - j]. Once it is complete at a
// leaf l > 0, the next value follows (all mod `mod`) from
//   res[l+1] = (res[l] + c[l] - (1 - p[l]) / prew[l] * f[l]) / p[l].
void solve(int l, int r) {
    if (l == r) {
        if (l) {
            const int fail = ((1 - p[l]) % mod + mod);                     // 1 - p[l]
            const int back = fail * inv(prew[l]) % mod * f[l] % mod;       // drop-back term
            const int numer = ((res[l] + c[l] - back) % mod + mod) % mod;
            res[l + 1] = numer * inv(p[l]) % mod;
        }
        return;
    }
    const int mid = (l + r) >> 1;
    solve(l, mid);

    const int deg = (mid - l) + (r - l);  // degree of the product below
    L = 0, limit = 1;
    while (limit <= deg) L++, limit <<= 1;
    for (int i = 0; i < limit; i++) A[i] = B[i] = 0;
    for (int i = l; i <= mid; i++) A[i - l] = res[i];
    for (int d = 1; d <= r - l; d++) B[d] = w[d];
    poly_mul(A, B, deg);
    for (int i = mid + 1; i <= r; i++) f[i] = (f[i] + A[i - l]) % mod;

    solve(mid + 1, r);
}
// Multiple test cases. p[i] is read in percent and converted to a modular
// probability; w[1..n-1] are the drop weights and prew their prefix sums.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while (t--) {
        cin >> n;
        const int inv100 = inv(100);
        for (int i = 0; i < n; i++) {
            cin >> p[i] >> c[i];
            p[i] = p[i] * inv100 % mod;
        }
        for (int i = 1; i < n; i++) cin >> w[i];
        for (int i = 1; i < n; i++) prew[i] = (prew[i - 1] + w[i]) % mod;
        for (int i = 0; i <= n; i++) f[i] = 0;
        // NOTE(review): res[1] = c[0] assumes the upgrade from level 0 always
        // succeeds (p[0] = 100%); confirm against the problem statement.
        res[0] = 0, res[1] = c[0];
        solve(0, n - 1);

        cout << res[n] << '\n';
    }
}

2022广东省赛M (生成函数+分治NTT)

https://pintia.cn/problem-sets/1534086341544497152/problems/1534088931057451020
题意
（原文此处为题面图片，未能保留。）

思路
基本不等式得:
$1=\sum_{i=1}^{k}\dfrac{x_i^2}{b_i^2}\ge k\sqrt[k]{\dfrac{x_1^2x_2^2\cdots x_k^2}{b_1^2b_2^2\cdots b_k^2}}$
化简得到:
$\dfrac{\prod_{i=1}^{k}b_i}{k^{k/2}}\ge\prod_{i=1}^{k}x_i$
即求m个数中选k个数,这k个数乘积的期望
求出和再除以C(m,k)就是期望
构造生成函数 $F(x)=\prod_{i=1}^{m}(1+b_ix)$，其 $x^k$ 的系数即为所有选法的乘积之和
这可以看作 $m$ 个多项式相乘，第 $i$ 个多项式的系数为 $\{1,b_i\}$
分治+NTT求解
复杂度 $O(m\log^2 m)$（分治共 $O(\log m)$ 层，每层 NTT 总代价 $O(m\log m)$）
代码

#include<bits/stdc++.h>
#define show(x) cerr<<#x<<" : "<<x<<endl;
#define ll long long
#define int ll
using namespace std;
const int maxn = 1e5 + 5;   // size of the factorial tables and input array
const int mod = 998244353;  // NTT prime; 3 is used as its primitive root
// a^b mod `mod` by square-and-multiply.
int qpow(int a, int b) {
    int acc = 1;
    for (; b; b >>= 1) {
        if (b & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
int fac[maxn];  // fac[i] = i! mod `mod`
int inv[maxn];  // inv[i] = (i!)^{-1} mod `mod`
// Fills fac[] and inv[] for every i < maxn: one modular inverse for the
// largest factorial, then downwards via 1/(i-1)! = i / i!.
void init() {
	fac[0] = 1;
	for (int i = 1; i < maxn; ++i) fac[i] = fac[i - 1] * i % mod;
	inv[maxn - 1] = qpow(fac[maxn - 1], mod - 2);
	for (int i = maxn - 1; i > 0; --i) inv[i - 1] = inv[i] * i % mod;
}
// Binomial coefficient C(n, m) = n! / (m! (n - m)!) mod `mod`. Note the
// argument order (m, n). Returns 0 outside 0 <= m <= n, which is the
// combinatorial value and a valid residue. The previous -1 sentinel was not
// a residue; fed into qpow() it produced a negative "answer".
// Requires init() to have filled fac[] and inv[].
int C(int m, int n)
{
	if (m < 0 || n < 0 || m > n) return 0;
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int r[3 * maxn];  // bit-reversal permutation, filled in by poly::mul

// In-place NTT of a[0..lim) modulo `mod` with primitive root 3.
// lim must be a power of two, a.size() >= lim, and r[] must hold the
// bit-reversal permutation for lim. b == 0: forward transform; b != 0:
// inverse transform, including the division by lim.
void ntt(vector<int>& a, int lim, int b = 0)
{
	int i, j, k, l, * p, * q, u, v, w;
	for (i = 0; i < lim; ++i)
		if (i < r[i])
			swap(a[i], a[r[i]]);
	for (i = 1; i < lim; i = l)
		for (j = 0, l = i * 2, u = qpow(3, (mod - 1) / l); j < lim; j += l)
			for (k = 0, v = 1, p = &a[j], q = &p[i]; k < i; ++k, v = v * 1ll * u % mod)
			{
				// Butterfly (p, q) <- (p + v*q, p - v*q), kept in [0, mod).
				w = v * 1ll * q[k] % mod;
				q[k] = p[k] < w ? p[k] - w + mod : p[k] - w;
				p[k] += w;
				if (p[k] >= mod) p[k] -= mod;
			}
	if (b) {
		// Inverse = forward transform with a[1..lim) reversed, scaled by 1/lim.
		// Use iterators: &a[lim] indexes one past the end of a vector of
		// size lim (and &a[1] does too when lim == 1).
		reverse(a.begin() + 1, a.begin() + lim);
		u = qpow(lim, mod - 2);
		for (i = 0; i < lim; ++i)
			a[i] = a[i] * 1ll * u % mod;
	}
}
// Polynomial of degree l with coefficients a[0..] (lowest degree first).
struct poly
{
	int l;          // degree
	vector<int> a;  // coefficients; after mul() zero-padded to the transform size
	poly() : l(0) {}
	// Replaces *this with (*this) * x. x is taken by value because its
	// coefficient vector is resized and transformed in place.
	void mul(poly x) {
		l += x.l;
		int lim = 1;
		while (lim <= l) lim *= 2;  // smallest power of two > product degree
		for (int i = 1; i < lim; ++i)
			r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
		a.resize(lim);
		x.a.resize(lim);
		ntt(a, lim);
		ntt(x.a, lim);
		for (int i = 0; i < lim; ++i)
			a[i] = a[i] * 1ll * x.a[i] % mod;
		ntt(a, lim, 1);
	}
};
int a[maxn];  // input numbers, 1-indexed

// Divide-and-conquer product of the linear factors (1 + a[i] x), i in [l, r].
poly binary(int l, int r) {
	if (l != r) {
		const int mid = (l + r) >> 1;
		poly left = binary(l, mid);
		poly right = binary(mid + 1, r);
		left.mul(right);
		return left;
	}
	poly leaf;
	leaf.l = 1;
	leaf.a = { 1, a[l] };
	return leaf;
}

void solve() {
	init();
	int n, k; cin >> n >> k;
	for (int i = 1; i <= n; i++)cin >> a[i];
	poly res = binary(1, n);
	cout << res.a[k] * qpow(C(k, n), mod - 2) % mod * qpow(qpow(k, k / 2),mod-2) % mod;
}
// Single test case with fast I/O.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

General.song

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值