文章目录
拉格朗日插值法
1.算法分析
给定 n n n个点 P ( x i , y i ) P(x_i, y_i) P(xi,yi), 将过这 n n n个点的最多 n − 1 n - 1 n−1次多项式记为 f ( X ) f(X) f(X),求 f ( k ) f(k) f(k), 答案模上 998244353 998244353 998244353返回。
拉格朗日插值的公式 f ( x ) f(x) f(x)为:
$ f(x)=\sum_{i=1}^{n} y_i\prod_{j\neq i }\frac{x-x_j}{x_i-x_j} $
对于某个点 x = k x = k x=k的 f ( k ) f(k) f(k)为:
$ f(k)=\sum_{i=1}^{n} y_i\prod_{j\neq i }\frac{k-x_j}{x_i-x_j} $
比如:现在已经有三个点 ( 2 , 5 ) , ( 3 , 7 ) , ( 9 , 11 ) (2, 5), (3, 7), (9, 11) (2,5),(3,7),(9,11),那么利用公式可以得到:
f ( k ) = 5 ∗ ( k − 3 ) ∗ ( k − 9 ) ( 2 − 3 ) ∗ ( 2 − 9 ) + 7 ∗ ( k − 2 ) ∗ ( k − 9 ) ( 3 − 2 ) ∗ ( 3 − 9 ) + 11 ∗ ( k − 2 ) ∗ ( k − 3 ) ( 9 − 2 ) ∗ ( 9 − 3 ) f(k) = 5 * \frac{(k-3)*(k-9)}{(2-3)*(2-9)} + 7*\frac{(k-2)*(k-9)}{(3-2)*(3-9)}+11*\frac{(k-2)*(k-3)}{(9-2)*(9-3)} f(k)=5∗(2−3)∗(2−9)(k−3)∗(k−9)+7∗(3−2)∗(3−9)(k−2)∗(k−9)+11∗(9−2)∗(9−3)(k−2)∗(k−3)
如果给定的点为连续的点,那么可以使用前缀后缀优化。
例如给定 ( 0 , f 0 ) , ( 1 , f 1 ) , ( 2 , f 2 ) , . . . , ( n , f n ) (0, f_0), (1, f_1), (2, f_2), ..., (n, f_n) (0,f0),(1,f1),(2,f2),...,(n,fn)
则记: p r e [ i ] = ∏ j = 0 i ( k − j ) , s u f f [ i ] = ∏ j = i n ( k − j ) , f a c t [ i ] = i ! pre[i] = \prod_{j=0}^{i}(k - j), suff[i] = \prod_{j=i}^{n}(k-j), fact[i] = i! pre[i]=∏j=0i(k−j),suff[i]=∏j=in(k−j),fact[i]=i!
则$ f(k)=\sum_{i=0}^{n} y_i\frac{pre[i-1]*suff[i+1]}{fact[i]*fact[n-i]} , 但 是 分 母 , 但是分母 ,但是分母fact[n-i]$可能出现负数,需要去判断,当 n − i n-i n−i 为奇数时,分母应该取负号。这样就可以预处理得到 f a c t fact fact数组,然后根据输入的k来 O ( n ) O(n) O(n)计算得到 p r e 、 s u f f pre、suff pre、suff数组,从而 O ( n ) O(n) O(n)完成 f ( k ) f(k) f(k)计算。
2.模板
2.1 随机给定n个点插值: O ( n 2 l o g n ) O(n^2log_n) O(n2logn)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2010;
int mod = 998244353;
int n, k, x[MAXN], y[MAXN], res, s1, s2;
int qmi(int a, int k, int p) {
int res = 1 % p; // res记录答案, 模上p是为了防止k为0,p为1的特殊情况
while(k) { // 只要还有剩下位数
if (k & 1) res = res * a % p; // 判断最后一位是否为1,如果为1就乘上a,模上p, 乘法时可能爆int,所以变成long long
k >>= 1; // 右移一位
a = a * a % p; // 当前a等于上一次的a平方,取模,平方时可能爆int,所以变成long long
}
return res;
}
int get_inv(int a, int p) {
return a % p == 0? -1: qmi(a, p - 2, p);
}
// n个点(x[i], y[i]),求f(k)
int Lagrange(int x[], int y[], int n, int k) {
int res = 0;
for (int i = 1; i <= n; i++) {
s1 = y[i] % mod; // s1计算分子
s2 = 1; // s2计算分母
for (int j = 1; j <= n; j++)
if (i != j)
s1 = s1 * (k - x[j]) % mod, s2 = s2 * (x[i] - x[j]) % mod; // 分别计算分子和分母
res += s1 * get_inv(s2, mod) % mod;
}
return (res % mod + mod) % mod;
}
signed main() {
scanf("%lld%lld", &n, &k);
for (int i = 1; i <= n; i++) scanf("%lld%lld", x + i, y + i); // 读入n个点
cout << Lagrange(x, y, n, k);
return 0;
}
2.2 给定连续n个点插值: O ( n ) O(n) O(n)
void init(int n, int mod) { // 预处理逆元
LL res = 1;
for (int i = 1; i <= n + 2; ++i) res = res * i % mod; // 计算阶乘
fact[n + 2] = qmi(res, mod - 2, mod);
for (int i = n + 1; i >= 0; --i) fact[i] = (i + 1) * fact[i + 1] % mod; // 计算逆元
}
//给定最高次为n的多项式的n+1个点分别为:(0,f0),(1,f1),(2,f2)...(n,fn),求f(k)的值
LL Lagrange(LL f[], int n, int k) {
if (k <= n) return f[k];
pre[0] = suff[n + 1] = 1;
for (int i = 0; i <= n; ++i) pre[i + 1] = pre[i] * (k - i) % mod;
for (int i = n; i >= 0; --i) suff[i] = suff[i + 1] * (k - i) % mod;
LL fk = 0;
for (int i = 0; i <= n; ++i) {
LL tmp = f[i] * pre[i] % mod * suff[i + 1] % mod * fact[i] % mod * fact[n - i] % mod; // 计算 (pre[i-1]*suff[i+1]) / (fact[i]*fact[n-i])
if ((n - i) & 1) // 当 n-i 为奇数时,分母应该取负号。
fk = (fk - tmp + mod) % mod;
else
fk = (fk + tmp) % mod;
}
return fk;
}
int main() {
...
init(MAXN - 10, mod); // 预处理逆元
for (int i = 0; i <= n; ++i) cin >> f[i]; // 读入n+1个点
f[n + 1] = Lagrange(f, n, n + 1); // 得到f[n+1]
...
}
3.典型例题
P4781 【模板】拉格朗日插值
题意: 给定 n n n个点 P ( x i , y i ) P(x_i, y_i) P(xi,yi), 将过这 n n n个点的最多 n − 1 n - 1 n−1次多项式记为 f ( X ) f(X) f(X),求 f ( k ) f(k) f(k), 答案模上 998244353 998244353 998244353返回。 1 < = n < = 2 ∗ 1 0 3 , x i 1 <= n <= 2*10^3, x_i 1<=n<=2∗103,xi两两不同。
题解: 按照拉格朗日插值法公式得到:$ f(x)=\sum_{i=1}^{n} y_i\prod_{j\neq i }\frac{x-x_j}{x_i-x_j} , 那 么 带 入 x = k , 得 到 , 那么带入x = k,得到 ,那么带入x=k,得到 f(k)=\sum_{i=1}^{n} y_i\prod_{j\neq i }\frac{k-x_j}{x_i-x_j} , 那 么 就 可 以 ,那么就可以 ,那么就可以O(n^2log_n)$计算式子,其中log为求逆元的复杂度。
代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2010;
int mod = 998244353;
int n, k, x[MAXN], y[MAXN], res, s1, s2;
int qmi(int a, int k, int p) {
int res = 1 % p; // res记录答案, 模上p是为了防止k为0,p为1的特殊情况
while(k) { // 只要还有剩下位数
if (k & 1) res = res * a % p; // 判断最后一位是否为1,如果为1就乘上a,模上p, 乘法时可能爆int,所以变成long long
k >>= 1; // 右移一位
a = a * a % p; // 当前a等于上一次的a平方,取模,平方时可能爆int,所以变成long long
}
return res;
}
int get_inv(int a, int p) {
return a % p == 0? -1: qmi(a, p - 2, p);
}
int Lagrange(int x[], int y[], int n, int k) {
int res = 0;
for (int i = 1; i <= n; i++) {
s1 = y[i] % mod; // s1计算分子
s2 = 1; // s2计算分母
for (int j = 1; j <= n; j++)
if (i != j)
s1 = s1 * (k - x[j]) % mod, s2 = s2 * (x[i] - x[j]) % mod;
res += s1 * get_inv(s2, mod) % mod;
}
return (res % mod + mod) % mod;
}
signed main() {
scanf("%lld%lld", &n, &k);
for (int i = 1; i <= n; i++) scanf("%lld%lld", x + i, y + i); // 读入n个点
cout << Lagrange(x, y, n, k);
return 0;
}
2019 南昌邀请赛 B. Polynomial
题意: 给出一个 n n n次多项式的前 n + 1 n+1 n+1项的值: f 0 , f 1 , f 2 , . . . , f n f_0, f_1, f_2, ..., f_n f0,f1,f2,...,fn,求 ∑ i = L R f ( i ) m o d 9999991 \sum_{i=L}^{R}f(i)\ mod\ 9999991 ∑i=LRf(i) mod 9999991
题解: ∑ i = L R f ( i ) \sum_{i=L}^{R}f(i) ∑i=LRf(i) 转化为: S R − S L − 1 , S i S_R - S_{L-1}, S_i SR−SL−1,Si为前i项的前缀和。那么求出每次询问求出 S R S_R SR 和 S L − 1 S_{L-1} SL−1即可。而 S i S_i Si是 n + 1 n+1 n+1次多项式,那么我们需要 n + 2 n+2 n+2个点。因此,可以使用拉格朗日插值法先求出 f n + 1 f_{n+1} fn+1,然后就能得到 n + 2 n+2 n+2个前缀点( f i f_i fi累加起来即可)。利用这些点再次做插值得到 S R S_R SR 和 S L − 1 S_{L-1} SL−1相减。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod = 9999991;
const int MAXN = 1e3 + 23;
LL pre[MAXN], suff[MAXN], fact[MAXN], f[MAXN], sum[MAXN];
LL qmi(LL a, LL k, LL p) {
LL res = 1 % p; // res记录答案, 模上p是为了防止k为0,p为1的特殊情况
while (k) { // 只要还有剩下位数
if (k & 1)
res = res * a % p; // 判断最后一位是否为1,如果为1就乘上a,模上p,
// 乘法时可能爆int,所以变成long long
k >>= 1; // 右移一位
a = a * a % p; // 当前a等于上一次的a平方,取模,平方时可能爆int,所以变成long
// long
}
return res;
}
void init(int n, int mod) {
LL res = 1;
for (int i = 1; i <= n + 2; ++i) res = res * i % mod; // 计算阶乘
fact[n + 2] = qmi(res, mod - 2, mod);
for (int i = n + 1; i >= 0; --i) fact[i] = (i + 1) * fact[i + 1] % mod; // 计算逆元
}
//给定最高次为n的多项式的n+1个点分别为:(0,f0),(1,f1),(2,f2)...(n,fn),求f(k)的值
LL Lagrange(LL f[], int n, int k) {
if (k <= n) return f[k];
pre[0] = suff[n + 1] = 1;
for (int i = 0; i <= n; ++i) pre[i + 1] = pre[i] * (k - i) % mod;
for (int i = n; i >= 0; --i) suff[i] = suff[i + 1] * (k - i) % mod;
LL fk = 0;
for (int i = 0; i <= n; ++i) {
LL tmp = f[i] * pre[i] % mod * suff[i + 1] % mod * fact[i] % mod * fact[n - i] % mod; // 计算 (pre[i-1]*suff[i+1]) / (fact[i]*fact[n-i])
if ((n - i) & 1) // 当 n-i 为奇数时,分母应该取负号。
fk = (fk - tmp + mod) % mod;
else
fk = (fk + tmp) % mod;
}
return fk;
}
int main() {
init(MAXN - 10, mod);
int T, L, R, n, m;
cin >> T;
while (T--) {
cin >> n >> m;
for (int i = 0; i <= n; ++i) cin >> f[i];
f[n + 1] = Lagrange(f, n, n + 1); // 得到f[n+1]
sum[0] = f[0] % mod;
for (int i = 1; i <= n + 1; ++i) sum[i] = (sum[i - 1] + f[i]) % mod; // 得到 n + 2个前缀和点
while (m--) {
cin >> L >> R;
cout << (Lagrange(sum, n + 1, R) - Lagrange(sum, n + 1, L - 1) + mod) % mod << '\n'; // 两个前缀和相减记为: L ~ R 这段
}
}
return 0;
}