解释一下官方题解:
因为y中为0的位置,x中若为1,肯定a,b该位置上都为1,若为0, a和b该位置都为0,即y中为0的位置,确定x后,a和b的该位置上的值都是确定的, 而y中为1的位置,x中肯定为1,a,b有 a1b0,a0b1两种选择。于是最后有
2bit(y)
种选择方法。
重写式子的前三步很容易理解
最后一步可以证明
若x xor y =k 则 x and y =y <==> bit[x]-bit[y]=bit[k]
从左向右 显然
从右向左 k中为0的位置,x和y一定相同。 而k中为1的位置,x和y的位置肯定不同,这些位置有bit[k]个, 只有当这些位置 x均为1,y均为0时 可以使条件成立 ,此时x and y = y。
如下代码实际执行了一个按位中1个数分类,然后根据位数之差FWT的过程
因为FWT可以线性相加,c[i]数组其实同时执行了b和a位数差为i的所有的FWT,它们的和累加在了一起。
注释里是错误写法,它没有把之前对应数组的所有FWT后的值累加到一个数组中
#include <bits/stdc++.h>
using namespace std;
const int MAXN=1<<20;
const int MOD=998244353;
const int inv2=(MOD+1)>>1;
void fwt(int a[],int len,int mode)
{
if(mode)
for(int d=1;d<len;d<<=1)
for(int m=d<<1,i=0;i<len;i+=m)
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
a[i+j]=(x+y)%MOD,a[i+j+d]=(x-y)%MOD;
}
else
for(int d=1;d<len;d<<=1)
for(int m=d<<1,i=0;i<len;i+=m)
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
a[i+j]=1ll*(x+y)*inv2%MOD,a[i+j+d]=1ll*(x-y)*inv2%MOD;
}
}
int m,len,bit[MAXN];
void init()
{
len=1<<m;
for(int i=0;i<len;i++)
bit[i]=bit[i>>1]+(i&1);
}
int a[22][MAXN],b[22][MAXN],c[22][MAXN];
int main()
{
scanf("%d",&m);
init();
int ta,tb;
for(int i=0;i<len;i++)
{
scanf("%d",&ta);
a[bit[i]][i]=1ll*ta*(1<<bit[i])%MOD;
}
for(int i=0;i<len;i++)
{
scanf("%d",&ta);
b[bit[i]][i]=ta;
}
for(int i=0;i<=m;i++)
{
fwt(a[i],len,1);
fwt(b[i],len,1);
}
for(int i=0;i<=m;i++)
for(int j=i;j<=m;j++)
for(int k=0;k<len;k++)
c[j-i][k]=(c[j-i][k]+1ll*b[j][k]*a[i][k]%MOD)%MOD;
// for(int k=0;k<len;k++)
// for(int i=bit[k];i<=m;i++)
// {
// c[bit[k]][k]=c[bit[k]][k]+1ll*b[i][k]*a[i-bit[k]][k]%MOD;
// }
for(int i=0;i<=m;i++)
{
fwt(c[i],len,0);
for(int j=0;j<len;j++)
if(c[i][j]<0)
c[i][j]+=MOD;
}
long long ans=0,tmp=1;
for(int i=0;i<len;i++)
{
ans+=c[bit[i]][i]*tmp%MOD;
tmp=tmp*1526%MOD;
}
ans%=MOD;
printf("%lld\n",ans);
return 0;
}