//P3803
第一道靠自己实力AC的省选题。无数次漏掉位运算、无数次忘记重载运算符之后,终于从别人的博客里学会了FFT。
然而我发现我可以HACK自己的第一个代码[捂脸]
终于有一天我醒悟进化了!!!
上自己的FFT模板
#include<bits/stdc++.h>
using namespace std;
// Reads the next decimal integer from stdin into x.
// Every character before the first digit is skipped; if any skipped character
// is '-', the parsed value is negated. The first non-digit after the number is
// consumed as well.
// The character is kept in an int so EOF is distinguishable from data: if the
// input ends before any digit is seen, x becomes 0 instead of looping forever.
template<class T>inline void read(T &x)
{
    long long value=0;
    bool negative=false;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-')negative=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    x=negative?-value:value;
}
// Minimal complex-number value type used by the FFT code in this program.
struct cpx
{
    double r,i;// real part, imaginary part

    // Builds r + i*sqrt(-1).
    cpx(double r,double i)
    {
        this->r=r;
        this->i=i;
    }
    // Default constructor: deliberately leaves r and i untouched.
    cpx(){}

    // complex + complex
    cpx operator +(const cpx& b)const
    {
        return cpx(this->r+b.r,this->i+b.i);
    }
    // complex - complex
    cpx operator -(const cpx& b)const
    {
        return cpx(this->r-b.r,this->i-b.i);
    }
    // complex * complex
    cpx operator *(const cpx& b)const
    {
        return cpx(this->r*b.r-this->i*b.i,this->i*b.r+this->r*b.i);
    }
    // complex * real
    cpx operator *(const double& b)const
    {
        return cpx(this->r*b,this->i*b);
    }
    // complex / real
    inline cpx operator /(const double& b)const
    {
        return cpx(this->r/b,this->i/b);
    }
    // complex / complex: multiply by the conjugate of b and divide by |b|^2.
    // This overload was missing, which made operator/=(const cpx&) ill-formed.
    inline cpx operator /(const cpx& b)const
    {
        const double norm=b.r*b.r+b.i*b.i;
        return cpx((this->r*b.r+this->i*b.i)/norm,(this->i*b.r-this->r*b.i)/norm);
    }
    // complex *= complex
    inline cpx operator *=(const cpx& b)
    {
        return *this=*this*b;
    }
    // complex /= real
    inline cpx operator /=(const double& b)
    {
        return *this=*this/b;
    }
    // assignment from a real number (imaginary part becomes 0)
    inline cpx operator =(const double& b)
    {
        return *this=cpx(b,0);
    }
    // complex /= complex
    inline cpx operator /=(const cpx& b)
    {
        return *this=*this/b;
    }
};
// Capacity of each global array below (1<<22 entries).
const int maxl=1<<22;
// Pi, used by fft() to build the roots of unity.
const double PI=3.1415926535897932384626433832795028841971693993751;
// Read first by main(); main() then reads n+1 values into a and m+1 into b.
int n,m;
// Working arrays handed to fft(); main() stores the input in the real parts.
cpx a[maxl],b[maxl];
// Bit-reversal permutation table: written by getrev(), read by fft().
int rev[maxl];
// Fills rev[0 .. 2^bit - 1]: each entry is derived from the entry for idx/2
// shifted right by one, with the lowest bit of idx moved to position bit-1.
void getrev(int bit)
{
    const int size=1<<bit;
    const int top=bit-1;
    for(int idx=0;idx<size;++idx)
    {
        const int shifted=rev[idx>>1]>>1;
        const int low=idx&1;
        rev[idx]=shifted|(low<<top);
    }
}
// In-place iterative radix-2 transform of a[0..n-1].
//   a   : data, overwritten with the result
//   n   : transform length (the global rev[] must already be filled for it)
//   dft : sign applied to the sine of the root angle; when it is -1 every
//         element is additionally divided by n at the end
void fft(cpx *a,int n,int dft)
{
    // Reorder by the rev[] table; idx<partner makes each pair swap only once.
    for(int idx=0;idx<n;++idx)
    {
        const int partner=rev[idx];
        if(idx<partner)swap(a[idx],a[partner]);
    }
    // Butterfly passes: half is the distance between the two butterfly inputs.
    for(int half=1;half<n;half*=2)
    {
        const int span=half*2;
        const cpx root(cos(PI/half),sin(dft*PI/half));
        for(int base=0;base<n;base+=span)
        {
            cpx twiddle(1,0);
            for(int pos=base;pos<base+half;++pos)
            {
                cpx &lo=a[pos];
                cpx &hi=a[pos+half];
                const cpx even=lo;
                const cpx odd=hi*twiddle;
                lo=even+odd;
                hi=even-odd;
                twiddle*=root;
            }
        }
    }
    if(dft!=-1)return;
    for(int idx=0;idx<n;++idx)a[idx]/=n;
}
// Reads n and m, then n+1 and m+1 integers into the real parts of a and b,
// multiplies the two sequences through fft() and prints the n+m+1 results,
// each rounded by adding 0.5 and truncating to int.
int main()
{
    read(n);
    read(m);
    for(int idx=0;idx<=n;++idx)read(a[idx].r);
    for(int idx=0;idx<=m;++idx)read(b[idx].r);

    // Smallest power of two (at least 2) that is not below n+m+1.
    int bit=1;
    while((1<<bit)<n+m+1)++bit;
    const int s=1<<bit;

    getrev(bit);
    fft(a,s,1);
    fft(b,s,1);
    for(int idx=0;idx<s;++idx)a[idx]*=b[idx];
    fft(a,s,-1);

    const int last=n+m;
    for(int idx=0;idx<=last;++idx)
    {
        const int rounded=(int)(a[idx].r+0.5);
        printf("%d ",rounded);
    }
}
丑陋的复数模板
然而,FFT涉及整数与实数、实数与复数的运算,
所以难免有精度问题(其实是我嫌处理输入输出太麻烦)。
于是我再次进化了!!!
不用写复数,不用处理输入输出的NTT。嗯,真好吃~~比糕点还好吃(才怪)
#include<cstdio>
// Buffered replacement for getchar(): when the read cursor p1 reaches p2 the
// buffer is refilled from stdin with one fread of up to 1<<21 bytes; if that
// fread yields nothing the macro evaluates to EOF, otherwise to the next byte.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// Exchanges the values of a and b.
// Uses a temporary rather than the XOR trick so that swap(x,x) leaves x
// unchanged (the XOR version would zero it when both references alias).
inline void swap(long long& a,long long& b)
{
    const long long tmp=a;
    a=b;
    b=tmp;
}
char buf[1<<21],*p1=buf,*p2=buf;// input buffer for the getchar macro: p1 = next unread byte, p2 = end of the data last read
// Reads the next decimal integer from the input into x.
// Every character before the first digit is skipped; if any skipped character
// is '-', the parsed value is negated. The first non-digit after the number is
// consumed as well.
// The character is kept in an int so EOF is distinguishable from data: if the
// input ends before any digit is seen, x becomes 0 instead of looping forever.
template<class T>inline void read(T &x)
{
    long long value=0;
    bool negative=false;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-')negative=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    x=negative?-value:value;
}
// Prints x in decimal with putchar, most significant digit first.
// Digits are peeled off from the low end into a small stack while x>=10 and
// emitted in reverse afterwards; the leading digit is printed directly.
template<class T>inline void write(T x)
{
    char pending[40];
    int count=0;
    while(x>=10)
    {
        pending[count++]=(char)((x%10)^48);
        x/=10;
    }
    putchar((x%10)^48);
    while(count>0)putchar(pending[--count]);
}
// Capacity of each global array below.
const int maxl=4004000;
// Read first by main(); main() then reads n+1 values into a and m+1 into b.
int n,m;
// mod: modulus of all arithmetic; g: base from which ntt() derives its roots;
// gi: inverse of g modulo mod (3*332748118 = 998244354 = mod+1).
const long long mod=998244353,g=3,gi=332748118;
// Working arrays handed to ntt().
long long a[maxl],b[maxl];
// Bit-reversal permutation table: written by getrev(), read by ntt().
int rev[maxl];
// Returns a raised to the n-th power modulo mod, by binary exponentiation:
// the base is squared once per bit of n and multiplied in when the bit is set.
long long power(long long a,long long n)
{
    long long result=1;
    while(n)
    {
        if(n&1)result=result*a%mod;
        a=a*a%mod;
        n>>=1;
    }
    return result;
}
// Fills rev[0 .. 2^bit - 1]: each entry is derived from the entry for idx/2
// shifted right by one, with the lowest bit of idx moved to position bit-1.
void getrev(int bit)
{
    const int size=1<<bit;
    const int top=bit-1;
    for(int idx=0;idx<size;++idx)
    {
        const int shifted=rev[idx>>1]>>1;
        const int low=idx&1;
        rev[idx]=shifted|(low<<top);
    }
}
// In-place iterative radix-2 number-theoretic transform of a[0..n-1] modulo mod.
//   a   : data, overwritten with the result
//   n   : transform length (the global rev[] must already be filled for it)
//   dft : 1 selects roots derived from g, anything else roots derived from gi;
//         when it is -1 every element is additionally multiplied by n^(mod-2)
void ntt(long long *a,int n,int dft)
{
    // Reorder by the rev[] table; idx<partner makes each pair swap only once.
    for(int idx=0;idx<n;++idx)
    {
        const int partner=rev[idx];
        if(idx<partner)swap(a[idx],a[partner]);
    }
    // Butterfly passes: half is the distance between the two butterfly inputs.
    for(int half=1;half<n;half*=2)
    {
        const int span=half*2;
        const long long root=power(dft==1?g:gi,(mod-1)/span);
        for(int base=0;base<n;base+=span)
        {
            long long twiddle=1;
            for(int pos=base;pos<base+half;++pos)
            {
                long long &lo=a[pos];
                long long &hi=a[pos+half];
                const long long even=lo;
                const long long odd=hi*twiddle%mod;
                lo=(even+odd)%mod;
                hi=(even-odd+mod)%mod;// +mod keeps the difference non-negative
                twiddle=twiddle*root%mod;
            }
        }
    }
    if(dft!=-1)return;
    const long long scale=power(n,mod-2);
    for(int idx=0;idx<n;++idx)a[idx]=a[idx]*scale%mod;
}
// Reads n and m, then n+1 and m+1 integers into a and b, multiplies the two
// sequences through ntt() and prints the n+m+1 results separated by spaces.
int main()
{
    read(n);
    read(m);
    for(int idx=0;idx<=n;++idx)read(a[idx]);
    for(int idx=0;idx<=m;++idx)read(b[idx]);

    // Smallest power of two (at least 2) that is not below n+m+1.
    int bit=1,s=2;
    while(s<n+m+1)
    {
        s<<=1;
        ++bit;
    }

    getrev(bit);
    ntt(a,s,1);
    ntt(b,s,1);
    for(int idx=0;idx<s;++idx)a[idx]=a[idx]*b[idx]%mod;
    ntt(a,s,-1);

    const int last=n+m;
    for(int idx=0;idx<=last;++idx)
    {
        write(a[idx]);
        putchar(' ');
    }
}
蒟蒻实力和时间有限,写不了证明和推导过程……但看着我这骚气清新的码风,这个拿去当板子还是很不错哒!
另附Pascal版NTT题解(会TLE)
作为一个P党转C++的蒟蒻,打了一下午却发现……这题用Pascal是一定会T的。
不得不说FPC编译出来的代码效率远低于GCC和G++
{ p: modulus of all arithmetic; g: base from which ntt derives its roots;
  gi: inverse of g modulo p (3*332748118 = 998244354 = p+1). }
const p:int64=998244353;
g:int64=3;
gi:int64=332748118;
{ n, m: read first by the main program; i: loop counter;
  bit, s: exponent and value of the transform length chosen by the main program. }
var n,m,i,bit,s:longint;
{ Working arrays handed to ntt. }
a,b:array[0..(1 shl 21)] of int64;
{ Returns a raised to the b-th power modulo p, by binary exponentiation:
  the base is squared once per bit of b and multiplied in when the bit is set. }
function pow(a,b:int64):int64;
var acc:int64;
begin
  acc:=1;
  while b>0 do
  begin
    if odd(b) then acc:=acc*a mod p;
    a:=a*a mod p;
    b:=b shr 1;
  end;
  pow:=acc;
end;
{ Bit-reversal permutation table: written by getrev, read by ntt. }
var rev:array[0..(1 shl 21)]of longint;
{ Fills rev[0 .. 2^bit - 1]: each entry is derived from the entry for idx div 2
  shifted right by one, with the lowest bit of idx moved to position bit-1. }
procedure getrev(bit:longint);
var idx,size:longint;
begin
  size:=1 shl bit;
  idx:=0;
  while idx<size do
  begin
    rev[idx]:=(rev[idx shr 1] shr 1) or ((idx and 1) shl (bit-1));
    inc(idx);
  end;
end;
{ In-place iterative radix-2 number-theoretic transform of a[0..n-1] modulo p.
    a   : data, overwritten with the result
    n   : transform length (the global rev table must already be filled for it)
    dft : 1 selects roots derived from g, anything else roots derived from gi;
          when it is -1 every element is additionally multiplied by n^(p-2) }
procedure ntt(var a:array of int64;n:longint;dft:integer);
var idx,half,span,base,k:longint;
    root,tw,gen,lo,hi,ninv,tmp:int64;
begin
  gen:=gi;
  if dft=1 then gen:=g;
  { Reorder by the rev table; idx<rev[idx] makes each pair swap only once. }
  for idx:=0 to n-1 do
    if idx<rev[idx] then
    begin
      tmp:=a[idx];
      a[idx]:=a[rev[idx]];
      a[rev[idx]]:=tmp;
    end;
  { Butterfly passes: half is the distance between the two butterfly inputs. }
  half:=1;
  while half<n do
  begin
    span:=half shl 1;
    root:=pow(gen,(p-1) div span);
    base:=0;
    while base<n do
    begin
      tw:=1;
      for k:=base to base+half-1 do
      begin
        lo:=a[k];
        hi:=a[k+half]*tw mod p;
        a[k]:=(lo+hi) mod p;
        a[k+half]:=(lo-hi+p) mod p; { +p keeps the difference non-negative }
        tw:=tw*root mod p;
      end;
      base:=base+span;
    end;
    half:=span;
  end;
  if dft=-1 then
  begin
    ninv:=pow(n,p-2);
    for idx:=0 to n-1 do a[idx]:=a[idx]*ninv mod p;
  end;
end;
{ Main program: reads n and m, then n+1 and m+1 integers into a and b,
  multiplies the two sequences through ntt and prints the n+m+1 results. }
begin
  read(n,m);
  for i:=0 to n do read(a[i]);
  for i:=0 to m do read(b[i]);
  { Smallest power of two strictly greater than n+m. }
  bit:=0;
  s:=1;
  while not (s>n+m) do
  begin
    inc(bit);
    s:=s shl 1;
  end;
  getrev(bit);
  ntt(a,s,1);
  ntt(b,s,1);
  for i:=0 to s-1 do a[i]:=a[i]*b[i] mod p;
  ntt(a,s,-1);
  for i:=0 to n+m do write(a[i],' ');
end.
我已经尽力在用位运算了。。。
//P4238
懒得写,直接上代码
#include<cstdio>
//using namespace std;
// Buffered replacement for getchar(): when the read cursor p1 reaches p2 the
// buffer is refilled from stdin with one fread of up to 1<<21 bytes; if that
// fread yields nothing the macro evaluates to EOF, otherwise to the next byte.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
// Exchanges the values of a and b.
// Uses a temporary rather than the XOR trick so that swap(x,x) leaves x
// unchanged (the XOR version would zero it when both references alias).
inline void swap(long long& a,long long& b)
{
    const long long tmp=a;
    a=b;
    b=tmp;
}
char buf[1<<21],*p1=buf,*p2=buf;// input buffer for the getchar macro: p1 = next unread byte, p2 = end of the data last read
// Reads the next decimal integer from the input into x.
// Every character before the first digit is skipped; if any skipped character
// is '-', the parsed value is negated. The first non-digit after the number is
// consumed as well.
// The character is kept in an int so EOF is distinguishable from data: if the
// input ends before any digit is seen, x becomes 0 instead of looping forever.
template<class T>inline void read(T &x)
{
    long long value=0;
    bool negative=false;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if(ch=='-')negative=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    x=negative?-value:value;
}
// Prints x in decimal with putchar, most significant digit first.
// Digits are peeled off from the low end into a small stack while x>=10 and
// emitted in reverse afterwards; the leading digit is printed directly.
template<class T>inline void write(T x)
{
    char pending[40];
    int count=0;
    while(x>=10)
    {
        pending[count++]=(char)(x%10+48);
        x/=10;
    }
    putchar(x%10+48);
    while(count>0)putchar(pending[--count]);
}
// Capacity of each global array below (1<<21 entries).
const int maxl=1<<21;
// Number of coefficients; read first by main().
int n;
// P: modulus of all arithmetic; g: base from which ntt() derives its roots;
// gi: inverse of g modulo P (3*332748118 = 998244354 = P+1).
const long long P=998244353,g=3,gi=332748118;
// a: input read by main(); b: output of work(); c: scratch copy used inside work().
long long a[maxl],b[maxl],c[maxl];
// Bit-reversal permutation table: written by getrev(), read by ntt().
int rev[maxl];
// Returns a raised to the n-th power modulo P, by binary exponentiation:
// the base is squared once per bit of n and multiplied in when the bit is set.
inline long long power(long long a,long long n)
{
    long long result=1;
    while(n)
    {
        if(n&1)result=result*a%P;
        a=a*a%P;
        n>>=1;
    }
    return result;
}
// Fills rev[0 .. 2^bit - 1]: each entry is derived from the entry for idx/2
// shifted right by one, with the lowest bit of idx moved to position bit-1.
void getrev(int bit)
{
    const int size=1<<bit;
    const int top=bit-1;
    for(int idx=0;idx<size;++idx)
    {
        const int shifted=rev[idx>>1]>>1;
        const int low=idx&1;
        rev[idx]=shifted|(low<<top);
    }
}
// In-place iterative radix-2 number-theoretic transform of a[0..n-1] modulo P.
//   a   : data, overwritten with the result
//   n   : transform length (the global rev[] must already be filled for it)
//   dft : 1 (the default) selects roots derived from g, anything else roots
//         derived from gi; when it is -1 every element is additionally
//         multiplied by n^(P-2)
void ntt(long long *a,int n,int dft=1)
{
    // Reorder by the rev[] table; idx<partner makes each pair swap only once.
    for(int idx=0;idx<n;++idx)
    {
        const int partner=rev[idx];
        if(idx<partner)swap(a[idx],a[partner]);
    }
    // Butterfly passes: half is the distance between the two butterfly inputs.
    for(int half=1;half<n;half*=2)
    {
        const int span=half*2;
        const long long root=power(dft==1?g:gi,(P-1)/span);
        for(int base=0;base<n;base+=span)
        {
            long long twiddle=1;
            for(int pos=base;pos<base+half;++pos)
            {
                long long &lo=a[pos];
                long long &hi=a[pos+half];
                const long long even=lo;
                const long long odd=hi*twiddle%P;
                lo=(even+odd)%P;
                hi=(even-odd+P)%P;// +P keeps the difference non-negative
                twiddle=twiddle*root%P;
            }
        }
    }
    if(dft!=-1)return;
    const long long scale=power(n,P-2);
    for(int idx=0;idx<n;++idx)a[idx]=a[idx]*scale%P;
}
// Recursive doubling step that fills b[0..deg-1] from a[0..deg-1] modulo P.
//   deg : number of coefficients wanted; deg<=0 is a no-op
//   a   : input coefficients (only the first deg are read at each level)
//   b   : output; on entry to the top-level call it is assumed to be all zero
//         beyond what earlier levels wrote -- NOTE(review): relies on b being
//         a zero-initialised global, confirm if reused
// Base case: b[0] = a[0]^(P-2). Otherwise the result for ceil(deg/2) terms is
// computed first and then refined pointwise as b*(2-a*b) in the transformed
// domain, using the global scratch array c for the truncated copy of a.
void work(int deg,long long *a,long long *b)
{
    // Guard: (0+1)>>1 is 0 again, so deg<=0 would otherwise recurse forever.
    if(deg<=0)return;
    if(deg==1)
    {
        b[0]=power(a[0],P-2);
        return;
    }
    work((deg+1)>>1,a,b);

    // Transform length: smallest power of two not below 2*deg.
    int bit=0,len=1;
    for(;len<(deg<<1);++bit)len<<=1;
    getrev(bit);

    // c = first deg coefficients of a, zero-padded to len.
    for(int i=0;i<deg;++i)c[i]=a[i];
    for(int i=deg;i<len;++i)c[i]=0;

    ntt(b,len);
    ntt(c,len);
    for(int i=0;i<len;++i)b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
    ntt(b,len,-1);

    // Keep only the first deg coefficients.
    for(int i=deg;i<len;++i)b[i]=0;
}
// Reads n and then n coefficients into a, runs work() to fill b, and prints
// the first n entries of b separated by spaces.
int main()
{
    read(n);
    int idx=0;
    while(idx<n)
    {
        read(a[idx]);
        ++idx;
    }
    work(n,a,b);
    for(idx=0;idx<n;++idx)
    {
        printf("%lld ",b[idx]);
    }
}
//P4717
也直接上代码,不解释
#include<cstdio>
#include<cstring>
using namespace std;
// Modulus applied to every sum, difference and product in the transforms below.
const int P=998244353;
// Copies the first n ints from b to a, front to back; does nothing for n<=0.
inline void cpy(int*a,int*b,int n)
{
    while(n-->0)*a++=*b++;
}
// In-place pass over a[0..n-1]: for every block of size 2*half, each element
// of the lower half receives (lower + upper) mod P; the upper half is untouched.
void FWT_and(int* a,int n)
{
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                lo[t]=(lo[t]+hi[t])%P;
            }
        }
    }
}
// In-place pass over a[0..n-1]: for every block of size 2*half, each element
// of the upper half receives (lower + upper) mod P; the lower half is untouched.
void FWT_or(int* a,int n)
{
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                hi[t]=(lo[t]+hi[t])%P;
            }
        }
    }
}
// In-place pass over a[0..n-1]: for every block of size 2*half, the pair
// (lower, upper) is replaced by (lower+upper, lower-upper), both mod P.
void FWT_xor(int* a,int n)
{
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                const int u=lo[t];
                const int v=hi[t];
                lo[t]=(u+v)%P;
                hi[t]=(u-v+P)%P;// +P keeps the difference non-negative
            }
        }
    }
}
// In-place pass over a[0..n-1]: for every block of size 2*half, each element
// of the lower half receives (lower - upper) mod P; the upper half is untouched.
void IFWT_and(int* a,int n)
{
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                lo[t]=(lo[t]-hi[t]+P)%P;// +P keeps the difference non-negative
            }
        }
    }
}
// In-place pass over a[0..n-1]: for every block of size 2*half, each element
// of the upper half receives (upper - lower) mod P; the lower half is untouched.
void IFWT_or(int* a,int n)
{
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                hi[t]=(hi[t]-lo[t]+P)%P;// +P keeps the difference non-negative
            }
        }
    }
}
// In-place pass over a[0..n-1]: for every block of size 2*half, the pair
// (lower, upper) is replaced by ((lower+upper)*inv, (lower-upper)*inv), both
// mod P, where inv = 499122177 satisfies 2*inv = 998244354 = P+1.
void IFWT_xor(int* a,int n)
{
    const int inv=499122177;
    for(int half=1;half<n;half<<=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            int* lo=a+base;
            int* hi=lo+half;
            for(int t=0;t<half;++t)
            {
                const int u=lo[t];
                const int v=hi[t];
                // Sum and difference are formed in int first, then widened
                // so the product with inv cannot overflow.
                const long long sum=u+v;
                const long long diff=u-v+P;
                lo[t]=sum*inv%P;
                hi[t]=diff*inv%P;
            }
        }
    }
}
// n: exponent read by main(), then replaced by 1<<n (the array length in use).
// a, b: the two input sequences; _a, _b: scratch copies that get transformed;
// c: pointwise product and final result of each convolution.
int n,a[1<<17],b[1<<17],_a[1<<17],_b[1<<17],c[1<<17];
// Reads an exponent, sets n to 2^exponent, reads two sequences of n ints, and
// prints three result lines: for the or-, and-, then xor-transform pair it
// transforms copies of both inputs, multiplies them pointwise mod P, applies
// the matching inverse transform and prints the n values.
int main()
{
    scanf("%d",&n);
    n=1<<n;
    for(int i=0;i<n;++i)scanf("%d",a+i);
    for(int i=0;i<n;++i)scanf("%d",b+i);

    typedef void (*transform_fn)(int*,int);
    // Same order as the output lines: or, and, xor.
    transform_fn forward[3]={FWT_or,FWT_and,FWT_xor};
    transform_fn backward[3]={IFWT_or,IFWT_and,IFWT_xor};

    for(int op=0;op<3;++op)
    {
        cpy(_a,a,n);
        cpy(_b,b,n);
        forward[op](_a,n);
        forward[op](_b,n);
        for(int i=0;i<n;++i)c[i]=(long long)_a[i]*_b[i]%P;
        backward[op](c,n);
        for(int i=0;i<n;++i)printf("%d ",c[i]);
        putchar('\n');
    }
}