引入
先看一道例题:
给定
n,k
n
,
k
,求:
对大质数取模。 n≤109 n ≤ 10 9 。
Part 1: k≤2 k ≤ 2 。
Part 2: k≤300 k ≤ 300 。
Part 3: k≤2000 k ≤ 2000 。
Part 4: k≤105 k ≤ 10 5 。
先看每一部分的解题方法。
Part 1
根据高中数学介绍的公式, O(1) O ( 1 ) 算出。
Part 2
我们已经不能手算了。由
k≤2
k
≤
2
的情况,猜想答案可能是一个
k+1
k
+
1
次多项式。于是可以高斯消元消出每一项系数,
O(k3)
O
(
k
3
)
得到答案。
具体证明可以参见差分与有限微积分——阮行止的博客。
Part 3
接下来这一部分,高斯消元已经无能为力了。接下来就要开始进入正题——拉格朗日插值法了。
什么?你已经推出来了?
设
f(k)=∑ni=1ik
f
(
k
)
=
∑
i
=
1
n
i
k
,我们有:
将所有式子求和,得:
然后就可以 O(n2) O ( n 2 ) 算了。
Part 4
现在,上面那种奇♂怪的算法也不能做了,下面介绍拉格朗日插值法。
算法原理
所谓插值,就是将一些点值
(xi,yi)
(
x
i
,
y
i
)
代入反解还原出多项式的过程,上文介绍的高斯消元解出多项式的系数,就是插值的过程。
拉格朗日插值法是一种高效的插值多项式的算法,以以法国数学家约瑟夫·拉格朗日命名。
先回到我们高斯消元解多项式的过程,我们对于一个
n
n
次多项式(加上常数项共 项),我们需要
n+1
n
+
1
个点值来列方程,由线性代数的知识可得,一个含有
n+1
n
+
1
个未知数的线性方程组有唯一解的必要条件是方程的个数
≤n+1
≤
n
+
1
,所以我们需要
n+1
n
+
1
个点值才能解出唯一对应的多项式。
在拉格朗日插值法中,我们也需要
n+1
n
+
1
个点值。设点值为
(x0,y0),(x2,y2),…,(xnyn)
(
x
0
,
y
0
)
,
(
x
2
,
y
2
)
,
…
,
(
x
n
y
n
)
拉格朗日插值法的原理是构造一个拉格朗日基本多项式
lj(x)
l
j
(
x
)
,满足
lj(xj)=1
l
j
(
x
j
)
=
1
,
∀i=0…n,i≠j,lj(xi)=0
∀
i
=
0
…
n
,
i
≠
j
,
l
j
(
x
i
)
=
0
。
那么所得到的拉格朗日插值多项式为:
发现这个多项式对于所有的点值均成立。那么我们想办法构造出满足条件的 lj(x) l j ( x ) 。这里用到了一个十分暴力的方法:
十分巧妙。
这样朴素计算显然是 O(n2) O ( n 2 ) 的。 但是我们可以取原函数在 0,1,…,n 0 , 1 , … , n 处的取值,这样就把 lj(x) l j ( x ) 的分数线下化成阶乘的形式,分数线上化成下降阶乘幂的形式。于是就可以在 O(nlog2n) O ( n log 2 n ) ( log2n log 2 n 为快速幂)的复杂度来求解这个问题了。
例题
「BZOJ3453」「Tyvj1858」 XLkxc
题意:
给定
k,a,n,d,p
k
,
a
,
n
,
d
,
p
f(i)=1k+2k+3k+......+ik
f
(
i
)
=
1
k
+
2
k
+
3
k
+
.
.
.
.
.
.
+
i
k
g(x)=f(1)+f(2)+f(3)+....+f(x)
g
(
x
)
=
f
(
1
)
+
f
(
2
)
+
f
(
3
)
+
.
.
.
.
+
f
(
x
)
求
(g(a)+g(a+d)+g(a+2d)+......+g(a+nd)) modp
(
g
(
a
)
+
g
(
a
+
d
)
+
g
(
a
+
2
d
)
+
.
.
.
.
.
.
+
g
(
a
+
n
d
)
)
mod
p
1≤k≤123
1
≤
k
≤
123
0≤a,n,d≤123456789
0
≤
a
,
n
,
d
≤
123456789
p=1234567891
p
=
1234567891
题解:
显然
f(i)
f
(
i
)
是一个
k+1
k
+
1
次多项式。
显然
g(i)
g
(
i
)
是
f(i)
f
(
i
)
的前缀和, 是一个
k+2
k
+
2
次多项式。
显然
ans
a
n
s
是
g(i)
g
(
i
)
的前缀和, 是一个
k+3
k
+
3
次多项式。
于是就可以大力插值,算出
g
g
,然后算出 。
由于中间结果可能超过 p,预处理阶乘搞可能绘算出
0
0
,于是我们可以暴力插值,同样能通过此题。
复杂度 。
My Code
/**************************************************************
Problem: 3453
User: infinityedge
Language: C++
Result: Accepted
Time:520 ms
Memory:1300 kb
****************************************************************/
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1234567891ll;
ll qpow(ll a, ll b){
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod){
if(b & 1) ret = ret * a % mod;
}
return ret;
}
ll dy[150];
ll fac[150], ifac[150];
void pre(ll k){
for(int i = 1; i <= k + 3; i ++){
dy[i] = dy[i - 1];
for(int j = 1; j <= i; j ++){
dy[i] = (dy[i] + qpow(j, k)) % mod;
}
// printf("%lld ", dy[i]);
}
// printf("\n");
fac[0] = 1; ifac[0] = 1;
for(int i = 1; i <= k + 4; i ++){
fac[i] = fac[i - 1] * i % mod;
ifac[i] = qpow(fac[i], mod - 2);
}
}
ll facx[150], ifacx[150];
ll calG(ll k, ll x){
if(x <= k + 3) return dy[x];
facx[0] = x % mod; ifacx[0] = qpow(x % mod, mod - 2);
for(int i = 1; i <= k + 3; i ++){
facx[i] = facx[i - 1] * (x % mod - i + mod) % mod;
ifacx[i] = qpow(facx[i], mod - 2);
}
ll ret = 0;
for(int i = 1; i <= k + 3; i ++){
ll tmp = 1;
for(int j = 1; j <= k + 3; j ++){
if(j == i) continue;
tmp = tmp * (x % mod - j + mod) % mod;
}
// tmp = tmp * facx[k + 3] * ifacx[i] % mod * facx[i - 1] % mod * ifacx[0] % mod;
tmp = tmp * qpow(fac[i - 1] * fac[k + 3 - i] % mod, mod - 2) % mod;
if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
ret = (ret + tmp * dy[i]) % mod;
}
return ret;
}
//1 3 6 10 15 21
//1 4 10 20 35 56
//1 5 15 35 70
ll y[150];
ll solve(ll k, ll a, ll n, ll d){
for(int i = 0; i <= k + 3; i ++){
y[i] = (y[i - 1] + calG(k, a + d * i)) % mod;
// printf("%lld ", y[i]);
}
// printf("\n");
if(n <= k + 3) return y[n];
ll ret = 0;
facx[0] = n % mod; ifacx[0] = qpow(n, mod - 2);
for(int i = 1; i <= k + 3; i ++){
facx[i] = facx[i - 1] * (n % mod - i + mod) % mod;
ifacx[i] = qpow(facx[i], mod - 2);
// printf("%lld ", facx[i]);
}
// printf("\n");
for(int i = 0; i <= k + 3; i ++){
ll tmp = 1;
tmp = tmp * facx[k + 3] * ifacx[i] % mod;
if(i != 0) tmp = tmp * facx[i - 1] % mod;
if(i == 0) tmp = tmp * qpow(fac[k + 3] % mod, mod - 2) % mod;
else tmp = tmp * qpow(fac[i] * fac[k + 3 - i] % mod, mod - 2) % mod;
if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
ret = (ret + tmp * y[i]) % mod;
//printf("%lld\n", y[i]);
}
return ret;
}
int main(){
int T;
scanf("%d", &T);
while(T--){
ll k, a, n, d;
scanf("%lld%lld%lld%lld", &k, &a, &n, &d);
pre(k);
printf("%lld\n", solve(k, a, n, d));
}
return 0;
}