#34. 多项式乘法
统计这是一道模板题。
给你两个多项式,请输出乘起来后的多项式。
输入格式
第一行两个整数 nn 和 mm,分别表示两个多项式的次数。
第二行 n+1n+1 个整数,表示第一个多项式的 00 到 nn 次项系数。
第三行 m+1m+1 个整数,表示第二个多项式的 00 到 mm 次项系数。
输出格式
一行 n+m+1n+m+1 个整数,表示乘起来后的多项式的 00 到 n+mn+m 次项系数。
样例一
input
1 2 1 2 1 2 1
output
1 4 5 2
模板题。。。
code:
#include<iostream>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn = 1e6+10;
struct cp{
double r,i;
cp(double _r=0,double _i=0):r(_r),i(_i){}
cp operator + (cp x) { return cp(r+x.r,i+x.i);}
cp operator - (cp x) { return cp(r-x.r,i-x.i); }
cp operator * (cp x) { return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
};
cp a[maxn],b[maxn],A[maxn],x,y,c[maxn];
char s1[maxn],s2[maxn];
int sum[maxn],a1[maxn],a2[maxn],dig[maxn];
int len1,len2,rev[maxn],N,L;
void FFT(cp a[],int flag){
//flag=1为 dft,flag=-1为idft
for(int i=0;i<N;i++) A[i]=a[rev[i]];
for(int i=0;i<N;i++) a[i]=A[i];
for(int i=2;i<=N;i<<=1){
//单位根
cp wn(cos(2*M_PI/i),flag*sin(2*M_PI/i));
for(int k=0;k<N;k+=i){
cp w(1,0);
for(int j=k;j<k+i/2;j++){
x=a[j];
y=a[j+i/2]*w;
a[j]=x+y;
a[j+i/2]=x-y;
w=w*wn;
}
}
}
if(flag==-1) for(int i=0;i<N;i++) a[i].r/=N;
}
int main(){
scanf("%d%d",&len1,&len2);
len1++; len2++;
for(N=1,L=0;N<max(len1,len2);N<<=1,L++); N<<=1;L++;
//rev为求二进制反转后顺序
for(int i=0;i<N;i++){
int len = 0;
for(int t=i;t;t>>=1) dig[len++]=t&1;
for(int j=0;j<L;j++) rev[i]=(rev[i]<<1)|dig[j];
}
for(int i=0;i<len1;i++) scanf("%d",&a1[i]);
for(int i=0;i<len2;i++) scanf("%d",&a2[i]);
for(int i=0;i<N;i++) a[i]=cp(a1[i]);
for(int i=0;i<N;i++) b[i]=cp(a2[i]);
FFT(a,1);
FFT(b,1);
for(int i=0;i<N;i++) c[i]=a[i]*b[i];
FFT(c,-1);
//四舍五入避免精度误差
for(int i=0;i<N;i++) sum[i]=c[i].r+0.5;
int l=len1+len2-2;
for(int i=0;i<=l;i++) {
if(i) printf(" ");
printf("%d",sum[i]);
}
putchar('\n');
return 0;
}
高精度乘法:
http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1027
code:
#include<iostream>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn = 1e6+10;
struct cp{
double r,i;
cp(double _r=0,double _i=0):r(_r),i(_i){}
cp operator + (cp x) { return cp(r+x.r,i+x.i);}
cp operator - (cp x) { return cp(r-x.r,i-x.i); }
cp operator * (cp x) { return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
};
cp a[maxn],b[maxn],A[maxn],x,y,c[maxn];
char s1[maxn],s2[maxn];
int sum[maxn],a1[maxn],a2[maxn],dig[maxn];
int len1,len2,rev[maxn],N,L;
void FFT(cp a[],int flag){
//flag=1为 dft,flag=-1为idft
for(int i=0;i<N;i++) A[i]=a[rev[i]];
for(int i=0;i<N;i++) a[i]=A[i];
for(int i=2;i<=N;i<<=1){
//单位根
cp wn(cos(2*M_PI/i),flag*sin(2*M_PI/i));
for(int k=0;k<N;k+=i){
cp w(1,0);
for(int j=k;j<k+i/2;j++){
x=a[j];
y=a[j+i/2]*w;
a[j]=x+y;
a[j+i/2]=x-y;
w=w*wn;
}
}
}
if(flag==-1) for(int i=0;i<N;i++) a[i].r/=N;
}
int main(){
scanf("%s%s",s1,s2);
len1=strlen(s1);
len2=strlen(s2);
for(N=1,L=0;N<max(len1,len2);N<<=1,L++); N<<=1;L++;
//rev为求二进制反转后顺序
for(int i=0;i<N;i++){
int len = 0;
for(int t=i;t;t>>=1) dig[len++]=t&1;
for(int j=0;j<L;j++) rev[i]=(rev[i]<<1)|dig[j];
}
for(int i=0;i<len1;i++) a1[len1-i-1]=s1[i]-'0';
for(int i=0;i<len2;i++) a2[len2-i-1]=s2[i]-'0';
for(int i=0;i<N;i++) a[i]=cp(a1[i]);
for(int i=0;i<N;i++) b[i]=cp(a2[i]);
FFT(a,1);
FFT(b,1);
for(int i=0;i<N;i++) c[i]=a[i]*b[i];
FFT(c,-1);
//四舍五入避免精度误差
for(int i=0;i<N;i++) sum[i]=c[i].r+0.5;
for(int i=0;i<N;i++){
sum[i+1]+=sum[i]/10;
sum[i]%=10;
}
int l=len1+len2-1;
while(sum[l]==0&&l) l--;
for(int i=l;i>=0;i--) putchar(sum[i]+'0');
putchar('\n');
return 0;
}
update:
51nod 大数乘法
稍微封装一下:
FFT:
#define _USE_MATH_DEFINES
#include <iostream>
#include <cstring>
#include <string>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 1e6 + 10;
const int mod = 998244353;
typedef long long ll;
struct cp{
double r,i;
cp(double _r=0,double _i=0):r(_r),i(_i){}
cp operator + (cp x) { return cp(r+x.r,i+x.i);}
cp operator - (cp x) { return cp(r-x.r,i-x.i); }
cp operator * (cp x) { return cp(r*x.r-i*x.i,r*x.i+i*x.r);}
};
struct FTT
{
int rev[maxn], dig[maxn];
int N, L;
void init_rev(int n)
{
for(N=1,L=0;N<n;N<<=1,L++); N<<=1;L++;
for (int i = 0; i < N; i++)
{
rev[i]=0;
int len = 0;
for (int t = i; t; t >>= 1)
dig[len++] = t & 1;
for (int j = 0; j < L; j++)
rev[i] = (rev[i] << 1) | dig[j];
}
}
void DFT(cp a[], int flag)
{
for (int i = 0; i < N; i++)
if (i<rev[i])
swap(a[i], a[rev[i]]);
for (int l = 2; l <= N; l <<= 1)
{
cp wn(cos(2*M_PI/l),flag*sin(2*M_PI/l));
for (int k = 0; k < N; k += l)
{
cp w(1,0);
cp x,y;
for (int j = k; j < k + l / 2; j++)
{
x = a[j];
y = a[j + l / 2] * w;
a[j] = x + y;
a[j + l / 2] = x-y;
w = w * wn;
}
}
}
if (flag == -1) for(int i=0;i<N;i++) a[i].r/=N;
}
void mul(cp a[],cp b[],cp c[]){
DFT(a,1); DFT(b,1);
for(int i=0;i<N;i++) c[i]=a[i]*b[i];
DFT(c,-1);
}
}fft;
cp a[maxn],b[maxn],c[maxn];
int a1[maxn],a2[maxn];
char s1[maxn],s2[maxn];
int sum[maxn];
int main(){
scanf("%s%s",s1,s2);
int n1=strlen(s1);
int n2=strlen(s2);
for(int i=0;i<n1;i++) a1[i] = s1[n1-i-1]-'0';
for(int i=0;i<n2;i++) a2[i] = s2[n2-i-1]-'0';
fft.init_rev(max(n1,n2));
for(int i=0;i<fft.N;i++) a[i]=cp(a1[i]);
for(int i=0;i<fft.N;i++) b[i]=cp(a2[i]);
fft.mul(a,b,c);
for(int i=0;i<fft.N;i++) sum[i]=c[i].r+0.5;
for(int i=0;i<fft.N;i++) {
sum[i+1]+=sum[i]/10;
sum[i]%=10;
}
int l=fft.N;
while(sum[l]==0&&l) l--;
for(int i=l;i>=0;i--) putchar(sum[i]+'0');
puts("");
return 0;
}
NTT:
#include <iostream>
#include <cstring>
#include <string>
#include <algorithm>
using namespace std;
const int maxn = 1e6 + 10;
const int mod = 998244353;
typedef long long ll;
ll qpow(ll a, ll b)
{
ll sum = 1;
while (b)
{
if (b & 1) sum = sum * a % mod;
b >>= 1;
a = a * a % mod;
}
return sum;
}
ll Inv(ll a, ll _mod)
{
return qpow(a, _mod - 2);
}
struct NTT
{
int rev[maxn], dig[maxn];
int N, L, P;
ll g;
void init_rev(int n)
{
//初始化原根
g = 3;
P = mod;
for(N=1,L=0;N<n;N<<=1,L++); N<<=1;L++;
for (int i = 0; i < N; i++)
{
rev[i]=0;
int len = 0;
for (int t = i; t; t >>= 1)
dig[len++] = t & 1;
for (int j = 0; j < L; j++)
rev[i] = (rev[i] << 1) | dig[j];
}
}
void DFT(ll a[], int flag)
{
for (int i = 0; i < N; i++)
if (i<rev[i])
swap(a[i], a[rev[i]]);
for (int l = 2; l <= N; l <<= 1)
{
ll wn;
if (flag == 1)
wn = qpow(g, (P - 1) / l);
else
wn = qpow(g, P - 1 - (P - 1) / l);
for (int k = 0; k < N; k += l)
{
ll w = 1;
ll x, y;
for (int j = k; j < k + l / 2; j++)
{
x = a[j];
y = a[j + l / 2] * w % P;
a[j] = (x + y) % P;
a[j + l / 2] = (x - y + P) % P;
w = w * wn % P;
}
}
}
if (flag == -1)
{
ll inv = Inv(N, P);
for (int i = 0; i < N; i++)
a[i] = a[i] * inv % P;
}
}
void mul(ll a[],ll b[],ll c[],int m){
init_rev(m);
DFT(a,1); DFT(b,1);
for(int i=0;i<N;i++) c[i]=a[i]*b[i];
DFT(c,-1);
}
}ntt;
char s1[maxn],s2[maxn];
ll a[maxn],b[maxn],c[maxn];
int main(){
// freopen("in.txt","r",stdin);
scanf("%s%s",s1,s2);
int n1=strlen(s1);
int n2=strlen(s2);
for(int i=0;i<n1;i++) a[i]=s1[n1-i-1]-'0';
for(int i=0;i<n2;i++) b[i]=s2[n2-i-1]-'0';
ntt.mul(a,b,c,max(n1,n2));
for(int i=0;i<ntt.N;i++) {
c[i+1]+=c[i]/10;
c[i]%=10;
}
int l=ntt.N;
while(c[l]==0&&l) l--;
for(int i=l;i>=0;i--) printf("%lld",c[i]);
puts("");
return 0;
}