所谓常系数齐次线性递推,就是系数为常数的齐次线性递推。
(逃)
前言
sto Asta orz!
又是一个名字高大上,实则小清新的算法!
解析
考虑一个 k 次的线性递推:
a
n
=
∑
i
=
1
k
f
i
a
n
−
i
a_n=\sum_{i=1}^kf_ia_{n-i}
an=i=1∑kfian−i
不断的把高次的项按照定义式拆成低次项,最终必然可以写成
∑
i
=
0
k
−
1
p
i
a
i
\sum_{i=0}^{k-1}p_ia_i
∑i=0k−1piai 的形式。
以最常见的斐波拉契为例:
f
5
=
f
4
+
f
3
=
2
f
3
+
f
2
=
3
f
2
+
2
f
1
=
5
f
1
+
3
f
0
f_5=f_4+f_3=2f_3+f_2=3f_2+2f_1=5f_1+3f_0
f5=f4+f3=2f3+f2=3f2+2f1=5f1+3f0。
然后带入
a
0...
k
−
1
a_{0...k-1}
a0...k−1 的值计算即可。
注意到,我们分解高次项的过程,其实在本质上也就等价于令
x
n
x^n
xn 不断向多项式
x
k
−
f
1
x
k
−
1
−
f
2
x
k
−
2
−
.
.
.
−
f
k
x
0
x^k-f_1x^{k-1}-f_2x^{k-2}-...-f_kx^0
xk−f1xk−1−f2xk−2−...−fkx0 取模的过程。
既然如此,我们就可以直接使用类似于快速幂的方法,不断把
x
x
x 平方并向上述的多项式取模即可。
暴力取模
O
(
k
2
log
n
)
O(k^2\log n)
O(k2logn),使用多项式科技可以做到
O
(
k
log
k
log
n
)
O(k\log k\log n)
O(klogklogn)。
代码
(过度封装非常严重,常数极大。)
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
inline ll read() {
ll x(0),f(1);char c=getchar();
while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}
while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
const int N=2e5+100;
const int mod=998244353;
int n,m,k;
inline ll ksm(ll x,ll k){
ll res=1;
while(k){
if(k&1) res=res*x%mod;
x=x*x%mod;k>>=1;
}
return res;
}
int niv2=ksm(2,mod-2);
int r[N];
void init(int n,int &lim){
lim=1;int L=0;
while(lim<n) lim<<=1,L++;
for(int i=1;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
}
void NTT(ll *x,int lim,int op){
for(int i=0;i<lim;i++) if(i<r[i]) swap(x[i],x[r[i]]);
for(int l=1;l<lim;l<<=1){
ll w=ksm(3,(mod-1)/(l<<1));if(op==-1) w=ksm(w,mod-2);
for(int st=0;st<lim;st+=(l<<1)){
for(ll i=0,now=1;i<l;i++,now=now*w%mod){
ll u=x[st+i],v=now*x[st+i+l]%mod;
x[st+i]=u+v>=mod?u+v-mod:u+v;
x[st+i+l]=u-v<0?u-v+mod:u-v;
}
}
}
if(op==-1){
ll ni=ksm(lim,mod-2);
for(int i=0;i<lim;i++) x[i]=x[i]*ni%mod;
}
return;
}
void copy(ll *a,ll *b,int n,int lim){
assert(n<=lim);
memcpy(a,b,sizeof(ll)*n);
fill(a+n,a+lim,0);return;
}
void mul(ll *a,ll *b,ll *c,int n,int m){
static ll u[N],v[N];
static int lim;
init(n+m-1,lim);
copy(u,a,n,lim);
copy(v,b,m,lim);
NTT(u,lim,1);NTT(v,lim,1);
for(int i=0;i<lim;i++) c[i]=u[i]*v[i]%mod;
NTT(c,lim,-1);
//for(int i=0;i<n+m-1;i++) printf("%lld ",c[i]);putchar('\n');
//putchar('\n');
return;
}
void inv(ll *h,ll *f,int n){
static ll t1[N],t2[N];
static int lim;
if(n==1){
f[0]=ksm(h[0],mod-2);return;
}
inv(h,f,(n+1)>>1);
init(n<<1,lim);
fill(f+((n+1)>>1),f+lim,0);
copy(t1,f,n,lim);copy(t2,h,n,lim);
NTT(t1,lim,1);NTT(t2,lim,1);
for(int i=0;i<lim;i++) t1[i]=(2*t1[i]-t1[i]*t1[i]%mod*t2[i]%mod+mod)%mod;
NTT(t1,lim,-1);
memcpy(f,t1,sizeof(ll)*n);
return;
}
void chu(ll *f,ll *g,ll *q,ll *r,int n,int m){
static ll F[N],G[N],ff[N],gg[N],Q[N],tmp[N];
static int lim;
--n;--m;
init(n+n,lim);
copy(F,f,n+1,lim);
copy(ff,f,n+1,lim);
copy(G,g,m+1,lim);
copy(gg,g,m+1,lim);
reverse(F,F+1+n);
reverse(G,G+1+m);
inv(G,tmp,n-m+1);mul(tmp,F,Q,n-m+1,n-m+1);
reverse(Q,Q+n-m+1);
//fill(Q+n-m+1,Q+n+1,0);
for(int i=n-m+1;i<=n;i++) Q[i]=0;
copy(q,Q,n-m+1,lim);
mul(Q,gg,tmp,n+1,n+1);
for(int i=0;i<n;i++) r[i]=(ff[i]+mod-tmp[i])%mod;
return;
}
ll F[N],G[N],Q[N],R[N];
void mul_mod(ll *x,ll *y,ll *Mod,int n){
static ll t1[N],t2[N],t[N];
static int lim;
init(n+n,lim);
copy(t1,x,n,lim);
copy(t2,y,n,lim);
mul(t1,t2,t,n,n);
chu(t,Mod,t1,t,2*n-1,n+1);
memcpy(x,t,sizeof(t));
}
ll LinearRecurrence(ll *a,ll *ff,int k,int n){
static ll res[N],tmp[N],f[N];
memset(tmp,0,sizeof(ll)*(k*2+5));
memset(res,0,sizeof(ll)*(k*2+5));
for(int i=0;i<k;i++) f[i]=(mod-ff[k-i])%mod;
f[k]=1;
tmp[1]=1;res[0]=1;
while(n){
if(n&1){
mul_mod(res,tmp,f,k);
}
mul_mod(tmp,tmp,f,k);
n>>=1;
}
ll ans(0);
for(int i=0;i<k;i++) ans=(ans+a[i]*res[i])%mod;
return ans;
}
ll f[N],a[N];
signed main() {
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
n=read();k=read();
for(int i=1;i<=k;i++) f[i]=(read()%mod+mod)%mod;
for(int i=0;i<k;i++) a[i]=(read()%mod+mod)%mod;
printf("%lld\n",LinearRecurrence(a,f,k,n));
return 0;
}
/*
3 1
2 3 3 1
1 1
*/