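UOJ 多项式乘法的 FFT 模板
下面分别给出 FFT 与 NTT 两种写法。
两者的迭代蝴蝶结构完全相同，区别只在于单位根和运算所在的数域：FFT 在复数域上以 $\omega_{2k} = e^{2\pi i/(2k)}$ 为单位根，使用浮点运算，最后需要四舍五入；NTT 在模 $p = 998244353$ 的整数域上用原根的幂 $g^{(p-1)/(2k)}$（$g = 3$）代替单位根，结果是模 $p$ 意义下的精确整数。
FFT 代码如下：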
#include<bits/stdc++.h>
// pi, evaluated at each use site.
#define PI (acos(-1.0))
// Inclusive ascending / descending integer loops.
// `register` was dropped: that storage class is ill-formed since C++17.
// Arguments are parenthesized so expression bounds (e.g. `a ? b : c`) bind correctly.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads the next integer from stdin into x.
// Leading non-digit characters are skipped; a '-' among them makes the value negative.
// The character immediately after the number is consumed.
// If EOF is hit before any digit, x is set to 0 instead of looping forever.
template<typename T>inline void read(T&x)
{
    x = 0;
    int sign = 1;
    // int, not char: getchar() reports EOF out of band, and isdigit() is undefined
    // for negative arguments other than EOF (e.g. bytes >= 0x80 when char is signed).
    int c = getchar();
    while(c != EOF && !isdigit(c))
    {
        if(c == '-') sign = -1;
        c = getchar();
    }
    while(c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}
const int N = 4e5 + 500;
// Minimal complex number used by the FFT; A and B hold the two input polynomials.
struct cpx
{
    double x,y;   // real and imaginary parts
    cpx() : x(0), y(0) {}
    cpx(double a,double b) : x(a), y(b) {}
    inline cpx operator * (const cpx &b) const { return cpx(x*b.x - y*b.y, b.x*y + b.y*x); }
    // Must return *this: the previous version fell off the end of a non-void
    // function, which is undefined behaviour in C++ (GCC -O2 may miscompile it).
    inline cpx &operator *= (const cpx &b) { *this = *this * b; return *this; }
    inline cpx operator + (const cpx &b) const { return cpx(x + b.x, y + b.y); }
    inline cpx operator - (const cpx &b) const { return cpx(x - b.x, y - b.y); }
}A[N],B[N];
// Globals shared with main(): n = transform length, L = log2(n), R = bit-reversal
// table, m = product degree. M is not referenced by the visible code.
int m,M,n,L,R[N];
// In-place iterative radix-2 FFT of a[0..n-1] (length taken from global n).
// f = 1: forward transform; f = -1: inverse transform, scaled by 1/n.
// Requires n to be a power of two and R[] to hold the bit-reversal permutation for n.
inline void fft(cpx*a,int f)
{
    for(int i = 0; i < n; ++i)
        if(i < R[i]) swap(a[i], a[R[i]]);
    const double pi = acos(-1.0);
    for(int half = 1; half < n; half <<= 1)
    {
        // Principal root of unity for blocks of length 2*half (conjugated when f == -1).
        const cpx wn(cos(pi / half), f * sin(pi / half));
        for(int start = 0; start < n; start += (half << 1))
        {
            cpx w(1, 0);
            for(int k = 0; k < half; ++k)
            {
                cpx x = a[start + k], y = w * a[start + half + k];
                a[start + k] = x + y;
                a[start + half + k] = x - y;
                w = w * wn;
            }
        }
    }
    // Scale both components: previously only .x was divided, leaving the imaginary
    // parts of the inverse transform n times too large.
    if(f == -1)
        for(int i = 0; i < n; ++i) { a[i].x /= n; a[i].y /= n; }
}
int main()
{
read(n); read(m);
rep(i,0,n) scanf("%lf",&A[i].x);
rep(i,0,m) scanf("%lf",&B[i].x);
m = n + m;
for(n = 1; n <= m; n <<= 1,++L);
rep(i,0,n - 1) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
fft(A,1); fft(B,1);
rep(i,0,n - 1) A[i] *= B[i];
fft(A,-1);
rep(i,0,m) printf("%d ",(int)(A[i].x + 0.5));
return 0;
}
NTT 代码如下（模数 998244353，原根 3）：
#include<bits/stdc++.h>
// Inclusive ascending / descending integer loops (`register` dropped: ill-formed since C++17).
// repd's loop-variable parameter was named `I` while its body declared `i`, so
// repd(j, ...) silently looped over `i` and left `j` undeclared.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads the next integer from stdin into x.
// Leading non-digit characters are skipped; a '-' among them makes the value negative.
// The character immediately after the number is consumed.
// If EOF is hit before any digit, x is set to 0 instead of looping forever.
template<typename T>inline void read(T&x)
{
    x = 0;
    int sign = 1;
    // int, not char: getchar() reports EOF out of band, and isdigit() is undefined
    // for negative arguments other than EOF (e.g. bytes >= 0x80 when char is signed).
    int c = getchar();
    while(c != EOF && !isdigit(c))
    {
        if(c == '-') sign = -1;
        c = getchar();
    }
    while(c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}
// NTT parameters: mod = 998244353 = 119 * 2^23 + 1 is prime and g = 3 is a primitive root.
const int g = 3;
const int mod = 998244353;
const int N = 4e5+50;
// n: transform length, m: product degree, L: log2(n), R: bit-reversal table.
int n, m, L, R[N];
// The two polynomials being multiplied (coefficients mod `mod` after the transform).
long long a[N], b[N];
// Binary exponentiation: returns x^y mod `mod` (intended for x >= 0, y >= 0).
inline int quick_pow(int x,int y)
{
    int result = 1;
    for(; y != 0; y >>= 1)
    {
        if(y & 1) result = 1ll * result * x % mod;
        x = 1ll * x * x % mod;
    }
    return result;
}
// In-place iterative NTT of a[0..n-1] modulo `mod` (length taken from global n).
// f = 1: forward transform; f = -1: inverse transform including the 1/n factor.
// Assumes n is a power of two and R[] holds the matching bit-reversal permutation.
inline void ntt(ll*a,int f)
{
    // Bit-reversal reordering so the butterflies can run bottom-up.
    for(int idx = 0; idx < n; ++idx)
        if(idx < R[idx]) swap(a[idx], a[R[idx]]);
    for(int half = 1; half < n; half <<= 1)
    {
        const int len = half << 1;
        // Primitive len-th root of unity; inverted for the inverse transform.
        ll step = quick_pow(g, (mod - 1) / len);
        if(f == -1) step = quick_pow(step, mod - 2);
        for(int base = 0; base < n; base += len)
        {
            ll w = 1;
            for(int k = 0; k < half; ++k)
            {
                const ll u = a[base + k];
                const ll v = w * a[base + half + k] % mod;
                // Double reduction keeps results in [0, mod) even for negative inputs.
                a[base + k] = ((u + v) % mod + mod) % mod;
                a[base + half + k] = ((u - v) % mod + mod) % mod;
                w = w * step % mod;
            }
        }
    }
    if(f == -1)
    {
        const int inv_n = quick_pow(n, mod - 2);
        for(int idx = 0; idx < n; ++idx) a[idx] = 1ll * a[idx] * inv_n % mod;
    }
}
int main()
{
read(n); read(m);
rep(i,0,n) read(a[i]);
rep(i,0,m) read(b[i]);
m = n + m;
for(n = 1;n <= m; n <<= 1,L++);
rep(i,0,n - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1)<<(L - 1));
ntt(a,1); ntt(b,1);
rep(i,0,n - 1) a[i] = 1ll * a[i] * b[i] %mod;
ntt(a,-1);
rep(i,0,m) printf("%lld ",a[i]);
return 0;
}
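分治 FFT（实际用 NTT 实现）：已知 $g_1,\dots,g_{n-1}$ 和 $f_0 = 1$，对 $1 \le i < n$ 求 $f_i = \sum_{j=1}^{i} f_{i-j}\,g_j \bmod 998244353$。代码如下：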
#include<bits/stdc++.h>
// Inclusive ascending / descending integer loops.
// `register` was dropped: that storage class is ill-formed since C++17.
// Arguments are parenthesized so expression bounds (e.g. `a ? b : c`) bind correctly.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads the next integer from stdin into x.
// Leading non-digit characters are skipped; a '-' among them makes the value negative.
// The character immediately after the number is consumed.
// If EOF is hit before any digit, x is set to 0 instead of looping forever.
template<typename T>inline void read(T&x)
{
    x = 0;
    int sign = 1;
    // int, not char: getchar() reports EOF out of band, and isdigit() is undefined
    // for negative arguments other than EOF (e.g. bytes >= 0x80 when char is signed).
    int c = getchar();
    while(c != EOF && !isdigit(c))
    {
        if(c == '-') sign = -1;
        c = getchar();
    }
    while(c != EOF && isdigit(c))
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}
// NTT modulus 998244353 = 119 * 2^23 + 1 and its primitive root G = 3.
const int G = 3;
const int mod = 998244353;
// Binary exponentiation: returns x^y mod `mod` (intended for x >= 0, y >= 0).
inline int ksm(int x,int y)
{
    int res = 1;
    for(; y != 0; y >>= 1)
    {
        if(y & 1) res = 1ll * res * x % mod;
        x = 1ll * x * x % mod;
    }
    return res;
}
const int N = 4e5 + 50;   // capacity of every working array below
int R[N];                 // bit-reversal permutation for the current NTT length
int g[N];                 // input sequence; main() fills g[1..n-1], g[0] stays 0
int f[N];                 // answer sequence accumulated by solve()
int a[N], b[N];           // scratch buffers for one convolution
inline void ntt(int*a,int f,int n)
{
rep(i,0,n - 1) if(R[i] < i) swap(a[i],a[R[i]]);
for(register int i = 1;i < n; i <<= 1)
{
int wn = ksm(G,(mod-1)/(i<<1));
if(!(~f)) wn = ksm(wn,mod-2);
for(register int j = 0;j < n;j += (i << 1))
{
int w = 1;
for(register int k = 0;k < i; ++ k,w = 1ll * w * wn % mod)
{
int x = a[j + k],y = 1ll * w * a[i + j + k] % mod;
a[j + k] = x + y; a[i + j + k] = x - y;
if(a[j + k] >= mod) a[j + k] -= mod;
if(a[i + j + k] < 0) a[i + j + k] += mod;
}
}
}
if(!(~f))
{
int inv = ksm(n,mod - 2);
rep(i,0,n-1) a[i] = 1ll * a[i] * inv % mod;
}
}
inline void solve(int l,int r)
{
if(l == r) return ;
int mid = l + r >> 1;
solve(l,mid);
int m = mid - l + 2 + r - l,len,L = 0;
for(len = 1;len <= m;len <<= 1) ++ L;
rep(i,0,len-1) a[i] = b[i] = 0;
rep(i,0,mid - l) a[i] = f[i + l];
rep(i,0,r - l) b[r - l - i] = g[r - i - l];
rep(i,0,len - 1) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
ntt(a,1,len); ntt(b,1,len);
rep(i,0,len - 1) a[i] = 1ll * a[i] * b[i] % mod;
ntt(a,-1,len); rep(i,mid+1,r) f[i] = (f[i] + a[i-l])%mod;
solve(mid+1,r);
}
int main()
{
int n;
read(n);
rep(i,1,n-1) read(g[i]);
f[0] = 1; solve(0,n-1);
rep(i,0,n-1) printf("%d ",f[i]);
return 0;
}