前言
多项式求逆与多项式指数函数(exp)在生成函数中有着广泛的应用,可以用来解决 OGF 和 EGF 的计数问题。
一、求逆例题 P4238(【模板】多项式乘法逆)
二、思路及代码
1.思路
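很单纯的一道模板题。核心是倍增(牛顿迭代):若已求得 $B_0(x) \equiv A(x)^{-1} \pmod{x^{\lceil n/2 \rceil}}$,则 $B(x) \equiv B_0(x)\,\bigl(2 - A(x)B_0(x)\bigr) \pmod{x^{n}}$;递归到 $n = 1$ 时 $B(x) = a_0^{-1}$。直接上代码: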
2.代码
代码如下:
#include <cstdio>    // scanf / printf, used by main; do not rely on <iostream> to pull it in
#include <iostream>
#include <utility>   // std::swap, used by NTT
using namespace std;
// Every `int` below is 64-bit so that the product of two residues (< mod)
// cannot overflow before the `% mod` reduction.
#define int long long
// Capacity of every global coefficient / scratch array. It must exceed the
// largest index touched by any NTT or scratch fill (a small multiple of the
// input length, rounded up to a power of two).
// NOTE(review): 1e7 + 7 is far more than inputs of ~1e5 coefficients need;
// with 7 arrays of (typically) 8-byte elements this reserves roughly 560 MB of
// static storage. Consider shrinking it to fit the judge's memory limit.
const int maxn = 1e7 + 7;
const int mod = 998244353;  // NTT-friendly prime: mod - 1 = 119 * 2^23
const int g = 3;            // primitive root modulo mod
int n;                      // number of input coefficients read by main
int N, len;                 // current padded transform length (power of two) and log2(N)
int rev[maxn];              // bit-reversal permutation for the current N
int a[maxn], b[maxn], _c[maxn];  // input, output, and the scratch _c used by polyinv
int _A[maxn], _B[maxn], _C[maxn];
// _A and _B are scratch for polyln; _C is scratch for polyexp.
// Modular exponentiation by repeated squaring: returns a^n mod `mod`.
// n is consumed bit by bit from the least significant end. With the file's
// `#define int long long`, a should already be reduced modulo `mod` so the
// products below stay within 64 bits.
int quickpow(int a, int n) {
    int result = 1;
    for (int base = a, e = n; e != 0; e >>= 1) {
        if (e & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Modular inverse via Fermat's little theorem (mod is prime):
// a^(mod - 2) is the inverse of a for a not divisible by mod.
// getinv(0) returns 0, i.e. there is no error signal for a non-invertible input.
int getinv(int a) {
    const int exponent = mod - 2;
    return quickpow(a, exponent);
}
// In-place iterative number-theoretic transform over Z/mod.
//   a   : coefficient array, transformed in place over its first N entries
//   deg : minimum transform length; rounded up to the power of two N
//   inv : 1 for the forward transform, -1 for the inverse (which also
//         multiplies by 1/N)
// Side effects: sets the globals N (padded length) and len (= log2 N), and
// rebuilds rev[] for that N; callers read N afterwards. Entries a[deg..N)
// take part in the transform, so callers must zero them if they are not
// meaningful. For the roots of unity to be exact, 2k must divide
// mod - 1 = 119 * 2^23, which limits N to 2^23.
void NTT(int a[], int deg, int inv) {
    N = 1, len = 0;
    while (N < deg) N <<= 1, len++;
    // A length-1 transform is the identity. Returning here also avoids the
    // shift `<< (len - 1)` below, which is undefined when len == 0.
    if (N == 1) return;
    for (int i = 0; i < N; i++)  // rebuilt on every call for the current N
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (len - 1)));
    for (int i = 0; i < N; i++)  // e.g. N = 8: 0 1 2 3 4 5 6 7 -> 0 4 2 6 1 5 3 7
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int k = 1; k < N; k <<= 1) {
        // Primitive (2k)-th root of unity (its inverse for the inverse transform).
        int wn = quickpow(g, (mod - 1) / (2 * k));
        if (inv == -1) wn = getinv(wn);
        for (int i = 0; i < N; i += 2 * k) {
            int w = 1, x, y;
            for (int j = 0; j < k; j++) {  // butterfly
                x = a[i + j], y = w * a[i + j + k] % mod;
                a[i + j] = (x + y + mod) % mod, a[i + j + k] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if (inv == -1) {
        const int scale = getinv(N);  // 1 / N
        for (int i = 0; i < N; i++) a[i] = a[i] * scale % mod;
    }
}
// Power-series inverse: computes b with a(x) * b(x) == 1 (mod x^deg) by
// Newton iteration, doubling the precision at each level:
//   b <- b0 * (2 - a * b0)  (mod x^deg),  b0 = a^{-1} mod x^ceil(deg/2).
// Preconditions: a[0] must be invertible mod `mod` (getinv(0) silently gives 0).
// b must be zero beyond the coefficients produced by the recursion (true
// for a zero-initialised global). On return b[0..deg) holds the result and
// b[deg..N) is zero. Uses the global scratch _c and leaves it zeroed;
// overwrites the globals N, len and rev through NTT. No-op for deg <= 0.
void polyinv(int a[], int b[], int deg) {
    if (deg <= 0) return;  // nothing to invert; also stops the (deg + 1) >> 1 recursion at 0
    if (deg == 1) {
        b[0] = getinv(a[0]);
        return;
    }
    polyinv(a, b, (deg + 1) >> 1);
    // Same padded length NTT will choose: smallest power of two >= 2 * deg.
    // a * b0 * b0 has degree < 2 * deg, so the cyclic product does not wrap.
    int size = 1;
    while (size < (deg << 1)) size <<= 1;
    for (int i = 0; i < size; i++) _c[i] = i < deg ? a[i] : 0;  // a mod x^deg
    NTT(b, deg << 1, 1);
    NTT(_c, deg << 1, 1);  // global N == size from here on
    for (int i = 0; i < N; i++)
        b[i] = (2 - _c[i] * b[i] % mod + mod) * b[i] % mod;
    NTT(b, deg << 1, -1);
    for (int i = deg; i < N; i++) b[i] = 0;  // truncate to mod x^deg before returning up the recursion
    for (int i = 0; i < N; i++) _c[i] = 0;   // leave the scratch clean for the next level / caller
}
// Formal derivative truncated to deg coefficients:
// b[i-1] = i * a[i] (mod `mod`) for 1 <= i < deg, and b[deg-1] = 0.
// Does nothing for deg <= 0; otherwise it would write b[-1].
void polydif(int a[], int b[], int deg) {
    if (deg <= 0) return;
    for (int i = 1; i < deg; i++) b[i - 1] = a[i] * i % mod;
    b[deg - 1] = 0;
}
// Formal integral with zero constant term:
// b[i] = a[i-1] / i (mod `mod`) for 1 <= i < deg, filled in ascending i,
// then b[0] = 0. b[0] is the integration constant, unlike polyinv where
// b[0] is the inverse of a[0].
void polyint(int a[], int b[], int deg) {
    int i = 1;
    while (i < deg) {
        const int inv_i = getinv(i);
        b[i] = a[i - 1] * inv_i % mod;
        ++i;
    }
    b[0] = 0;
}
// Polynomial product via NTT: a <- a * b, stored untruncated.
// The result is a cyclic convolution of the padded length N (smallest power
// of two >= 2 * deg). It equals the ordinary product as long as the number
// of coefficients of a plus that of b, minus one, is at most N; this holds
// when each has at most deg nonzero coefficients (everything from there up
// to N zero).
// Side effects: b is left in the NTT (point-value) domain, so treat it as
// clobbered. The globals N, len and rev are overwritten by NTT.
void polymul(int a[], int b[], int deg) {
    NTT(a, deg << 1, 1);
    NTT(b, deg << 1, 1);
    // Multiply over the whole padded length chosen by NTT; it exceeds 2 * deg
    // whenever deg is not a power of two.
    for (int i = 0; i < N; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, deg << 1, -1);
}
void polyln(int a[], int b[], int deg) {
polydif(a, _A, deg);
polyinv(a, _B, deg);
polymul(_A, _B, deg);
polyint(_A, b, deg << 1);
for (int i = 0; i < deg << 1; i++) _A[i] = _B[i] = 0;
}
void polyexp(int a[], int b[], int deg) {
if (deg == 1) return (void)(b[0] = 1);
polyexp(a, b, (deg + 1) >> 1);
polyln(b, _C, deg);
_C[0] = (a[0] + 1 - _C[0] + mod) % mod;
for (int i = 1; i < deg; i++) _C[i] = (a[i] - _C[i] + mod) % mod;
polymul(b, _C, deg);
for (int i = deg; i < (deg << 1); i++) b[i] = _C[i] = 0;
}
// Reads n and the coefficients a[0..n) from stdin, then prints the first n
// coefficients of the inverse series of A(x) (mod x^n), each followed by a space.
signed main() {
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    if (scanf("%lld", &n) != 1) return 0;  // no input
    for (int i = 0; i < n; i++) scanf("%lld", &a[i]);
    // polyinv needs deg >= 1; an empty input just prints an empty line.
    if (n > 0) polyinv(a, b, n);
    for (int i = 0; i < n; i++) printf("%lld ", b[i]);
    printf("\n");
    return 0;
}
三、多项式 exp 例题 P4726(【模板】多项式指数函数)
四、思路及代码
1.思路
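模板题,套模板。同样使用牛顿迭代:若已求得 $B_0(x) \equiv e^{A(x)} \pmod{x^{\lceil n/2 \rceil}}$,则 $B(x) \equiv B_0(x)\,\bigl(1 - \ln B_0(x) + A(x)\bigr) \pmod{x^{n}}$。其中 $\ln B_0 = \int B_0' / B_0$,由求导、求逆、乘法、积分组合得到;要求 $a_0 = 0$。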
2.代码
代码如下:
#include <cstdio>    // scanf / printf, used by main; do not rely on <iostream> to pull it in
#include <iostream>
#include <utility>   // std::swap, used by NTT
using namespace std;
// Every `int` below is 64-bit so that the product of two residues (< mod)
// cannot overflow before the `% mod` reduction.
#define int long long
// Capacity of every global coefficient / scratch array. It must exceed the
// largest index touched by any NTT or scratch fill (a small multiple of the
// input length, rounded up to a power of two).
// NOTE(review): 1e7 + 7 is far more than inputs of ~1e5 coefficients need;
// with 7 arrays of (typically) 8-byte elements this reserves roughly 560 MB of
// static storage. Consider shrinking it to fit the judge's memory limit.
const int maxn = 1e7 + 7;
const int mod = 998244353;  // NTT-friendly prime: mod - 1 = 119 * 2^23
const int g = 3;            // primitive root modulo mod
int n;                      // number of input coefficients read by main
int N, len;                 // current padded transform length (power of two) and log2(N)
int rev[maxn];              // bit-reversal permutation for the current N
int a[maxn], b[maxn], _c[maxn];  // input, output, and the scratch _c used by polyinv
int _A[maxn], _B[maxn], _C[maxn];
// _A and _B are scratch for polyln; _C is scratch for polyexp.
// Modular exponentiation by repeated squaring: returns a^n mod `mod`.
// n is consumed bit by bit from the least significant end. With the file's
// `#define int long long`, a should already be reduced modulo `mod` so the
// products below stay within 64 bits.
int quickpow(int a, int n) {
    int result = 1;
    for (int base = a, e = n; e != 0; e >>= 1) {
        if (e & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Modular inverse via Fermat's little theorem (mod is prime):
// a^(mod - 2) is the inverse of a for a not divisible by mod.
// getinv(0) returns 0, i.e. there is no error signal for a non-invertible input.
int getinv(int a) {
    const int exponent = mod - 2;
    return quickpow(a, exponent);
}
// In-place iterative number-theoretic transform over Z/mod.
//   a   : coefficient array, transformed in place over its first N entries
//   deg : minimum transform length; rounded up to the power of two N
//   inv : 1 for the forward transform, -1 for the inverse (which also
//         multiplies by 1/N)
// Side effects: sets the globals N (padded length) and len (= log2 N), and
// rebuilds rev[] for that N; callers read N afterwards. Entries a[deg..N)
// take part in the transform, so callers must zero them if they are not
// meaningful. For the roots of unity to be exact, 2k must divide
// mod - 1 = 119 * 2^23, which limits N to 2^23.
void NTT(int a[], int deg, int inv) {
    N = 1, len = 0;
    while (N < deg) N <<= 1, len++;
    // A length-1 transform is the identity. Returning here also avoids the
    // shift `<< (len - 1)` below, which is undefined when len == 0.
    if (N == 1) return;
    for (int i = 0; i < N; i++)  // rebuilt on every call for the current N
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (len - 1)));
    for (int i = 0; i < N; i++)  // e.g. N = 8: 0 1 2 3 4 5 6 7 -> 0 4 2 6 1 5 3 7
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int k = 1; k < N; k <<= 1) {
        // Primitive (2k)-th root of unity (its inverse for the inverse transform).
        int wn = quickpow(g, (mod - 1) / (2 * k));
        if (inv == -1) wn = getinv(wn);
        for (int i = 0; i < N; i += 2 * k) {
            int w = 1, x, y;
            for (int j = 0; j < k; j++) {  // butterfly
                x = a[i + j], y = w * a[i + j + k] % mod;
                a[i + j] = (x + y + mod) % mod, a[i + j + k] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if (inv == -1) {
        const int scale = getinv(N);  // 1 / N
        for (int i = 0; i < N; i++) a[i] = a[i] * scale % mod;
    }
}
// Power-series inverse: computes b with a(x) * b(x) == 1 (mod x^deg) by
// Newton iteration, doubling the precision at each level:
//   b <- b0 * (2 - a * b0)  (mod x^deg),  b0 = a^{-1} mod x^ceil(deg/2).
// Preconditions: a[0] must be invertible mod `mod` (getinv(0) silently gives 0).
// b must be zero beyond the coefficients produced by the recursion (true
// for a zero-initialised global). On return b[0..deg) holds the result and
// b[deg..N) is zero. Uses the global scratch _c and leaves it zeroed;
// overwrites the globals N, len and rev through NTT. No-op for deg <= 0.
void polyinv(int a[], int b[], int deg) {
    if (deg <= 0) return;  // nothing to invert; also stops the (deg + 1) >> 1 recursion at 0
    if (deg == 1) {
        b[0] = getinv(a[0]);
        return;
    }
    polyinv(a, b, (deg + 1) >> 1);
    // Same padded length NTT will choose: smallest power of two >= 2 * deg.
    // a * b0 * b0 has degree < 2 * deg, so the cyclic product does not wrap.
    int size = 1;
    while (size < (deg << 1)) size <<= 1;
    for (int i = 0; i < size; i++) _c[i] = i < deg ? a[i] : 0;  // a mod x^deg
    NTT(b, deg << 1, 1);
    NTT(_c, deg << 1, 1);  // global N == size from here on
    for (int i = 0; i < N; i++)
        b[i] = (2 - _c[i] * b[i] % mod + mod) * b[i] % mod;
    NTT(b, deg << 1, -1);
    for (int i = deg; i < N; i++) b[i] = 0;  // truncate to mod x^deg before returning up the recursion
    for (int i = 0; i < N; i++) _c[i] = 0;   // leave the scratch clean for the next level / caller
}
// Formal derivative truncated to deg coefficients:
// b[i-1] = i * a[i] (mod `mod`) for 1 <= i < deg, and b[deg-1] = 0.
// Does nothing for deg <= 0; otherwise it would write b[-1].
void polydif(int a[], int b[], int deg) {
    if (deg <= 0) return;
    for (int i = 1; i < deg; i++) b[i - 1] = a[i] * i % mod;
    b[deg - 1] = 0;
}
// Formal integral with zero constant term:
// b[i] = a[i-1] / i (mod `mod`) for 1 <= i < deg, filled in ascending i,
// then b[0] = 0. b[0] is the integration constant, unlike polyinv where
// b[0] is the inverse of a[0].
void polyint(int a[], int b[], int deg) {
    int i = 1;
    while (i < deg) {
        const int inv_i = getinv(i);
        b[i] = a[i - 1] * inv_i % mod;
        ++i;
    }
    b[0] = 0;
}
// Polynomial product via NTT: a <- a * b, stored untruncated.
// The result is a cyclic convolution of the padded length N (smallest power
// of two >= 2 * deg). It equals the ordinary product as long as the number
// of coefficients of a plus that of b, minus one, is at most N; this holds
// when each has at most deg nonzero coefficients (everything from there up
// to N zero).
// Side effects: b is left in the NTT (point-value) domain, so treat it as
// clobbered. The globals N, len and rev are overwritten by NTT.
void polymul(int a[], int b[], int deg) {
    NTT(a, deg << 1, 1);
    NTT(b, deg << 1, 1);
    // Multiply over the whole padded length chosen by NTT; it exceeds 2 * deg
    // whenever deg is not a power of two.
    for (int i = 0; i < N; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, deg << 1, -1);
}
void polyln(int a[], int b[], int deg) {
polydif(a, _A, deg);
polyinv(a, _B, deg);
polymul(_A, _B, deg);
polyint(_A, b, deg << 1);
for (int i = 0; i < deg << 1; i++) _A[i] = _B[i] = 0;
}
void polyexp(int a[], int b[], int deg) {
if (deg == 1) return (void)(b[0] = 1);
polyexp(a, b, (deg + 1) >> 1);
polyln(b, _C, deg);
_C[0] = (a[0] + 1 - _C[0] + mod) % mod;
for (int i = 1; i < deg; i++) _C[i] = (a[i] - _C[i] + mod) % mod;
polymul(b, _C, deg);
for (int i = deg; i < (deg << 1); i++) b[i] = _C[i] = 0;
}
// Reads n and the coefficients a[0..n) from stdin, then prints the first n
// coefficients of exp(A(x)), each followed by a space, and a final newline.
signed main() {
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    scanf("%lld", &n);
    for (int i = 0; i < n; i++) scanf("%lld", &a[i]);
    // polyexp is driven with a power-of-two length: the smallest power of two
    // strictly greater than n. The extra coefficients are computed but not printed.
    int padded = 1;
    while (padded <= n) padded <<= 1;
    polyexp(a, b, padded);
    for (int i = 0; i < n; i++) printf("%lld ", b[i]);
    printf("\n");
    return 0;
}