题目
思路来源
dls
题解
先离散化,把a和b搞进一个序列,然后排序去重,把a和b赋为rank值
然后遍历这个不重复的序列,两个大于1W就建2e5数列用NTT搞,否则直接暴力
既然要在ai和bj值相同,下标i+j==k时求i*j的值,
那么就把位置i和j分别放入值的vector里,对相同值的vector进行NTT
注意多项式乘法一定会遍历所有的情形,所以就令f(x)的第i位为i,g(x)的第j位为j
是一种很自然的做法,注意多项式乘法的定义
这个1W卡的真是刚刚好啊……昨天T了若干发
注意 NTT里面取模不要判if 直接模
上面的是有if的 下面的是没if的
NTT和FFT类似,处理整数取模的情况,
可以把大素数搞成它的原根,998244353的原根是3
FFT卡精度过不了
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <queue>
#include <map>
#define g 3
using namespace std;
typedef long long ll;
const double PI = acos(-1.0);
const int mod=998244353;
const int maxn=1e5+10;
const int G=3;
ll inv,N;
ll n,m,tmp1[maxn*3],tmp2[maxn*3],rev[maxn*3],back[maxn*3];
vector<ll>q[maxn*2],p[maxn*2];
vector<ll>num;
ll a[maxn],b[maxn],c[maxn*2];
ll power(ll x,ll y)
{
ll res=1ll;
for(;y;y>>=1)
{
if(y&1)res=res*x%mod;
x=x*x%mod;
}
return res;
}
void init()
{
int len=0;
while((n+m+2)>=(1<<len))len++;
N=(1<<len);
inv=power(N,mod-2);
for(int i=0;i<N;i++)
{
ll pos=0;
ll temp=i;
for(int j=1;j<=len;j++)
{
pos<<=1;pos |= temp&1;temp>>=1;
}
back[i]=rev[i]=pos;
}
}
void init2()
{
for(int i=0;i<N;++i)
rev[i]=back[i];
}
void ntt(ll *a,ll n,ll re)
{
for(int i=0;i<n;i++)
{
if(rev[i]>i)
{
swap(a[i],a[rev[i]]);
}
}
for(int i=2;i<=n;i<<=1)
{
ll mid=i>>1;
ll wn=power(G,(mod-1)/i);
if(re) wn=power(wn,(mod-2));
for(int j=0;j<n;j+=i)
{
ll w=1;
for(int k=0;k<mid;k++)
{
ll temp1=a[j+k];
ll temp2=a[j+k+mid]*w%mod;
a[j+k]=(temp1+temp2)%mod;
a[j+k+mid]=(temp1-temp2);
a[j+k+mid]=(a[j+k+mid]%mod+mod)%mod;
w=w*wn%mod;
}
}
}
if(re)
{
for(int i=0;i<n;i++)
{
a[i]=(a[i]*inv)%mod;
}
}
}
inline void solve1(ll v,ll len1,ll len2)
{
for(ll i=0;i<len1;++i)
{
ll pos1=p[v][i];
for(ll j=0;j<len2;++j)
{
ll pos2=q[v][j];
c[pos1+pos2]=(c[pos1+pos2]+pos1*pos2)%mod;
}
}
}
void solve2(ll v,ll len1,ll len2)
{
init2();
for(int i=0;i<N;++i)
tmp1[i]=tmp2[i]=0;
for(int i=0;i<len1;++i)tmp1[p[v][i]]=p[v][i];
for(int i=0;i<len2;++i)tmp2[q[v][i]]=q[v][i];
ntt(tmp1,N,0);
ntt(tmp2,N,0);
for(int i=0;i<N;++i)
{
tmp1[i]=tmp1[i]*tmp2[i]%mod;
}
ntt(tmp1,N,1);
for(int i=0;i<=p[v][len1-1]+q[v][len2-1];i++)
{
c[i]=(c[i]+tmp1[i])%mod;
}
}
int main()
{
scanf("%lld%lld",&n,&m);
init();
for(ll i=0;i<=n;++i)scanf("%lld",&a[i]),num.push_back(a[i]);
for(ll i=0;i<=m;++i)scanf("%lld",&b[i]),num.push_back(b[i]);
sort(num.begin(),num.end());
num.resize(unique(num.begin(),num.end())-num.begin());
for(ll i=0;i<=n;++i)a[i]=lower_bound(num.begin(),num.end(),a[i])-num.begin(),p[a[i]].push_back(i);
for(ll i=0;i<=m;++i)b[i]=lower_bound(num.begin(),num.end(),b[i])-num.begin(),q[b[i]].push_back(i);
for(ll i=0;i<num.size();++i)//保证无重复,rank==i
{
ll len1=p[i].size(),len2=q[i].size();
if(!len1||!len2)continue;
//printf("%lld:\n",i);
//for(int j=0;j<len1;++j)
//printf("%lld ",p[i][j]);puts("");
///for(int j=0;j<len2;++j)
///printf("%lld ",q[i][j]);puts("");
if(len1+len2<=10000)solve1(i,len1,len2);
else solve2(i,len1,len2);
}
for(int i=0;i<=n+m;++i)
printf("%lld%c",c[i],i==n+m?'\n':' ');
return 0;
}