#34. 多项式乘法

UOJ的FFT板子

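我分别写了 FFT 与 NTT 两个版本。

NTT 与 FFT 的蝶形结构完全相同,差别只在单位根及其运算。FFT 在复数域上取 $\omega_{2i} = \cos\frac{\pi}{i} + \mathrm{i}\,\sin\frac{\pi}{i}$;NTT 在模 $p = 998244353$ 意义下取原根 $g = 3$ 的幂 $g^{(p-1)/(2i)}$。因此 NTT 没有浮点误差,但结果是对 $p$ 取模的。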

FFT代码如下:

#include<bits/stdc++.h>
// Pi, computed once instead of re-evaluating acos(-1) at every use.
const double PI = std::acos(-1.0);
// rep(i, x, y): i = x, x+1, ..., y (inclusive); repd(i, x, y): i = x, x-1, ..., y.
// Arguments are parenthesized so expressions such as `c ? a : b` bind correctly.
// `register` is omitted: deprecated since C++11 and removed in C++17.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads one integer from stdin into x.
// Characters before the first digit are skipped; if a '-' appears among them
// the value is negated. The first non-digit after the number is consumed.
// c is kept as an int so EOF is distinguishable from a real byte (and is a
// valid argument to isdigit); on EOF before any digit, x is left as 0 and the
// function returns instead of looping forever.
template<typename T>inline void read(T&x)
{
	x = 0; int c; int sign = 1;
	do {
		c = getchar();
		if(c == EOF) return;
		if(c == '-') sign = -1;
	} while(!isdigit(c));
	do { x = x * 10 + (c - '0'); c = getchar(); } while(isdigit(c));
	x *= sign;
}

const int N = 4e5 + 500;

// Minimal complex number for the FFT: only the operations the transform uses.
struct cpx
{
	double x, y; // real and imaginary parts

	cpx() : x(0), y(0) {}
	cpx(double a, double b) : x(a), y(b) {}

	// (x + yi)(b.x + b.y i) = (x*b.x - y*b.y) + (b.x*y + b.y*x) i
	inline cpx operator * (const cpx &b) const { return cpx(x * b.x - y * b.y, b.x * y + b.y * x); }
	// Compound assignment returns *this by reference. A value-returning
	// function that falls off its end without `return` is undefined behaviour.
	inline cpx &operator *= (const cpx &b) { *this = *this * b; return *this; }
	inline cpx operator + (const cpx &b) const { return cpx(x + b.x, y + b.y); }
	inline cpx operator - (const cpx &b) const { return cpx(x - b.x, y - b.y); }
} A[N], B[N]; // the two input polynomials, transformed in place

// m: degree of B, then the degree of the product; n: degree of A, then the
// transform length; L: log2 of the transform length; R: bit-reversal table.
// M is not used by the code in this program.
int m, M, n, L, R[N];

// Iterative radix-2 FFT, in place on a[0..n-1].
// f == 1 : forward transform.
// f == -1: inverse transform; afterwards only the real parts are divided by n
//          (the caller reads only .x).
// Relies on the globals n (transform length) and R (bit-reversal permutation).
inline void fft(cpx*a,int f)
{
	// Reorder into bit-reversed positions so the butterflies can run in place.
	for (int idx = 0; idx < n; ++idx)
		if (idx < R[idx]) swap(a[idx], a[R[idx]]);

	for (int half = 1; half < n; half <<= 1)
	{
		const int span = half << 1;
		// Root of unity for blocks of length 2*half; the sign of f picks its conjugate.
		cpx root = cpx(cos(PI / half), f * sin(PI / half));
		for (int base = 0; base < n; base += span)
		{
			cpx w = cpx(1, 0);
			for (int k = 0; k < half; ++k)
			{
				cpx *lo = a + base + k;
				cpx *hi = lo + half;
				cpx u = *lo;
				cpx v = w * *hi;
				*lo = u + v;
				*hi = u - v;
				w = w * root;
			}
		}
	}

	if (f == -1)
		for (int idx = 0; idx < n; ++idx) a[idx].x /= n;
}

// Reads degrees n, m and the coefficients of A and B, prints the n+m+1
// coefficients of A*B separated by spaces.
int main()
{ 
	read(n); read(m);
	rep(i,0,n) scanf("%lf",&A[i].x);
	rep(i,0,m) scanf("%lf",&B[i].x);
	
	m = n + m;                               // degree of the product
	for(n = 1;  n <= m; n <<= 1,++L);        // smallest power of two > m
	// Bit-reversal table. When both inputs are constants, n == 1 and L == 0;
	// R[0] is already 0 (global), and computing (i&1)<<(L-1) would shift by -1,
	// which is undefined behaviour, so the table is skipped.
	if(L > 0) rep(i,0,n - 1) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
	
	fft(A,1); fft(B,1);
	rep(i,0,n - 1) A[i] = A[i] * B[i];       // pointwise product of point values
	fft(A,-1);
	// llround rounds to nearest for negative values too, unlike (int)(x + 0.5),
	// and long long leaves room for larger coefficient products.
	rep(i,0,m) printf("%lld ",llround(A[i].x));
	
	return 0;
}

NTT代码如下:

#include<bits/stdc++.h>
// rep(i, x, y): i = x, x+1, ..., y (inclusive); repd(i, x, y): i = x, x-1, ..., y.
// repd now actually uses its loop-variable parameter (it previously ignored it
// and always declared `i`). Arguments are parenthesized, and `register`
// (removed in C++17) is dropped.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads one integer from stdin into x.
// Characters before the first digit are skipped; if a '-' appears among them
// the value is negated. The first non-digit after the number is consumed.
// c is kept as an int so EOF is distinguishable from a real byte (and is a
// valid argument to isdigit); on EOF before any digit, x is left as 0 and the
// function returns instead of looping forever.
template<typename T>inline void read(T&x)
{
	x = 0; int c; int sign = 1;
	do {
		c = getchar();
		if(c == EOF) return;
		if(c == '-') sign = -1;
	} while(!isdigit(c));
	do { x = x * 10 + (c - '0'); c = getchar(); } while(isdigit(c));
	x *= sign;
}

// NTT parameters: 998244353 = 119 * 2^23 + 1, and the code uses g = 3 as its
// primitive root, so power-of-two transform lengths up to 2^23 are available.
const int g = 3;
const int mod = 998244353;
const int N = 400000 + 50;   // capacity of the arrays below

int n, m, L;                 // input degrees (n later becomes the transform length); L = log2(length)
int R[N];                    // bit-reversal permutation
long long a[N], b[N];        // coefficients of the two polynomials, transformed in place

// Binary exponentiation: returns x^y modulo the global `mod`.
// Products are formed in 64-bit arithmetic so they cannot overflow.
inline int quick_pow(int x,int y)
{
	long long acc = 1;
	long long cur = x;
	while (y != 0)
	{
		if (y & 1)
			acc = acc * cur % mod;
		cur = cur * cur % mod;
		y >>= 1;
	}
	return static_cast<int>(acc);
}

// Iterative number-theoretic transform modulo `mod`, in place on a[0..n-1].
// f == 1 : forward transform with roots g^((mod-1)/len).
// f == -1: inverse transform (inverted roots, then every entry times n^{-1}).
// Relies on the globals n (transform length) and R (bit-reversal permutation).
inline void ntt(ll*a,int f)
{
	const bool inverse = (f == -1);

	for (int idx = 0; idx < n; ++idx)
		if (idx < R[idx]) swap(a[idx], a[R[idx]]);

	for (int half = 1; half < n; half <<= 1)
	{
		const int len = half << 1;
		ll step = quick_pow(g, (mod - 1) / len);
		if (inverse) step = quick_pow(step, mod - 2); // inverse root via Fermat's little theorem

		for (int base = 0; base < n; base += len)
		{
			ll w = 1;
			for (int k = 0; k < half; ++k)
			{
				ll &lo = a[base + k];
				ll &hi = a[base + half + k];
				const ll u = lo;
				const ll v = w * hi % mod;
				// The extra "+ mod" keeps the stored value non-negative.
				lo = ((u + v) % mod + mod) % mod;
				hi = ((u - v) % mod + mod) % mod;
				w = w * step % mod;
			}
		}
	}

	if (inverse)
	{
		const int inv_n = quick_pow(n, mod - 2);
		for (int idx = 0; idx < n; ++idx) a[idx] = 1ll * a[idx] * inv_n % mod;
	}
}

// Reads degrees n, m and the coefficients of both polynomials, prints the
// n+m+1 coefficients of their product modulo 998244353.
int main()
{
	read(n); read(m);
	rep(i,0,n) read(a[i]);
	rep(i,0,m) read(b[i]);
	m = n + m;                          // degree of the product
	for(n = 1;n <= m; n <<= 1,L++);     // smallest power of two > m
	// Bit-reversal table. When both inputs are constants, n == 1 and L == 0;
	// R[0] is already 0 (global), and (i & 1) << (L - 1) would shift by -1,
	// which is undefined behaviour, so the table is skipped.
	if(L > 0) rep(i,0,n - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1)<<(L - 1)); 
	ntt(a,1); ntt(b,1);
	rep(i,0,n - 1) a[i] = 1ll * a[i] * b[i] %mod;   // pointwise product
	ntt(a,-1);
	rep(i,0,m) printf("%lld ",a[i]);
	
	return 0;
}

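分治 FFT(代码实际用 NTT 实现,模 998244353):给定 $g_1,\dots,g_{n-1}$,令 $f_0 = 1$,求 $f_i = \sum_{j=1}^{i} f_{i-j}\,g_j \pmod{998244353}$($1 \le i < n$)。做法是 CDQ 分治:先递归求出左半段,再用一次卷积把左半段对右半段的贡献加上,最后递归右半段,总复杂度 $O(n\log^2 n)$。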

#include<bits/stdc++.h>
// rep(i, x, y): i = x, x+1, ..., y (inclusive); repd(i, x, y): i = x, x-1, ..., y.
// Arguments are parenthesized so expressions bind correctly; `register`
// (deprecated in C++11, removed in C++17) is dropped.
#define rep(i,x,y) for(int i = (x); i <= (y); ++ i)
#define repd(i,x,y) for(int i = (x); i >= (y); -- i)
using namespace std;
typedef long long ll;
// Reads one integer from stdin into x.
// Characters before the first digit are skipped; if a '-' appears among them
// the value is negated. The first non-digit after the number is consumed.
// c is kept as an int so EOF is distinguishable from a real byte (and is a
// valid argument to isdigit); on EOF before any digit, x is left as 0 and the
// function returns instead of looping forever.
template<typename T>inline void read(T&x)
{
	x = 0; int c; int sign = 1;
	do {
		c = getchar();
		if(c == EOF) return;
		if(c == '-') sign = -1;
	} while(!isdigit(c));
	do { x = x * 10 + (c - '0'); c = getchar(); } while(isdigit(c));
	x *= sign;
}

// NTT parameters: 998244353 = 119 * 2^23 + 1, and the code uses G = 3 as its primitive root.
const int G = 3;
const int mod = 998244353;

// Binary exponentiation: returns x^y modulo `mod`, using 64-bit intermediates.
inline int ksm(int x,int y)
{
	long long result = 1;
	long long base = x;
	for (; y != 0; y >>= 1)
	{
		if (y & 1) result = result * base % mod;
		base = base * base % mod;
	}
	return static_cast<int>(result);
}

const int N = 400000 + 50;   // capacity of every array below
int R[N];                    // bit-reversal permutation for the current transform length
int g[N];                    // given sequence g[1..n-1]; g[0] stays 0
int f[N];                    // answer sequence, f[0] = 1 set in main
int a[N];                    // NTT scratch buffer (left half of f)
int b[N];                    // NTT scratch buffer (prefix of g)

inline void ntt(int*a,int f,int n)
{
	rep(i,0,n - 1) if(R[i] < i) swap(a[i],a[R[i]]);
	for(register int i = 1;i < n; i <<= 1)
	{
		int wn = ksm(G,(mod-1)/(i<<1));
		if(!(~f)) wn = ksm(wn,mod-2);
		for(register int j = 0;j < n;j += (i << 1))
		{
			int w = 1;
			for(register int k = 0;k < i; ++ k,w = 1ll * w * wn % mod)
			{
				int x = a[j + k],y = 1ll * w * a[i + j + k] % mod;
				a[j + k] = x + y; a[i + j + k] = x - y;
				if(a[j + k] >= mod) a[j + k] -= mod;
				if(a[i + j + k] < 0) a[i + j + k] += mod;
			}
		}
	}
	if(!(~f))
	{
		int inv = ksm(n,mod - 2);
		rep(i,0,n-1) a[i] = 1ll * a[i] * inv % mod;
	}
}

inline void solve(int l,int r)
{
	if(l == r) return ;
	int mid = l + r >> 1;
	solve(l,mid);
	int m = mid - l + 2 + r - l,len,L = 0;
	for(len = 1;len <= m;len <<= 1) ++ L;
	rep(i,0,len-1) a[i] = b[i] = 0;
	rep(i,0,mid - l) a[i] = f[i + l];
	rep(i,0,r - l) b[r - l - i] = g[r - i - l];
	rep(i,0,len - 1) R[i] = (R[i>>1]>>1)|((i&1)<<(L-1));
	ntt(a,1,len); ntt(b,1,len);
	rep(i,0,len - 1) a[i] = 1ll * a[i] * b[i] % mod;
	ntt(a,-1,len); rep(i,mid+1,r) f[i] = (f[i] + a[i-l])%mod;
	solve(mid+1,r);
}

int main()
{
	int n;
	read(n);
	rep(i,1,n-1) read(g[i]);
	f[0] = 1; solve(0,n-1);
	rep(i,0,n-1) printf("%d ",f[i]);
	return 0;
}

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值