[bzoj4818]序列计数
倍增
- 代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n;
int m,p;
ll mod=20170408;
ll f[40][100];
ll g[40][100];
const int N=2e7+1;
ll ff[100],gg[100],ss[100];
int pcnt,prime[4000010];
bool vis[N];
void init(){
for(int i=2;i<N;i++){
if(!vis[i])prime[++pcnt]=i;
for(int j=1;j<=pcnt&&1ll*i*prime[j]<N;j++){
vis[prime[j]*i]=true;
if(i%prime[j]==0)break;
}
}
}
int main()
{
init();
scanf("%lld%d%d",&n,&m,&p);
for(int i=1;i<=m;i++){
f[0][i%p]++;
g[0][i%p]++;
}
for(int i=1;i<=pcnt;i++)if(prime[i]<=m){
g[0][prime[i]%p]--;
}
for(int i=1;i<=34;i++){
for(int j=0;j<p;j++){
for(int k=0;k<p;k++){
int l=(j+k+p)%p;
f[i][l]+=f[i-1][j]*f[i-1][k];
g[i][l]+=g[i-1][j]*g[i-1][k];
f[i][l]%=mod,g[i][l]%=mod;
}
}
}
memset(ff,0,sizeof(ff));
memset(gg,0,sizeof(gg));
ff[0]=gg[0]=1;
for(int i=0;i<=31;i++)if((n>>i)&1){
for(int j=0;j<p;j++){
for(int k=0;k<p;k++){
int l=(j+k)%p;
ss[l]+=ff[j]*f[i][k];
ss[l]%=mod;
}
}
for(int j=0;j<p;j++){ff[j]=ss[j];ss[j]=0;}
for(int j=0;j<p;j++){
for(int k=0;k<p;k++){
int l=(j+k)%p;
ss[l]+=gg[j]*g[i][k];
ss[l]%=mod;
}
}
for(int j=0;j<p;j++){gg[j]=ss[j];ss[j]=0;}
}
printf("%lld\n",(ff[0]-gg[0]+mod)%mod);
}