题目大意
给定数组{An}\{A_n\}{An}和{Bn}\{B_n\}{Bn},求∑1≤i<j≤nmin(Ai⊕Aj,Bi⊕Bj)\sum_{1\le i<j\le n}\min(A_i\oplus A_j,B_i\oplus B_j)∑1≤i<j≤nmin(Ai⊕Aj,Bi⊕Bj),n≤250000n\le 250000n≤250000
题解
这题还是很妙的,要求的东西看上去没什么关联,所以想办法找点关联,发现Ai⊕Aj⊕Bi⊕BjA_i\oplus A_j\oplus B_i\oplus B_jAi⊕Aj⊕Bi⊕Bj的最高位111的位置是Ai⊕AjA_i\oplus A_jAi⊕Aj和Bi⊕BjB_i\oplus B_jBi⊕Bj最高位不同的位置
设Ci=Ai⊕BiC_i=A_i\oplus B_iCi=Ai⊕Bi,考虑根据上面那条性质分治,每次把在当前位数depdepdep为000的CiC_iCi放到一个集合中,为111的CiC_iCi放到另一个集合中,那么这两个集合之间Ai⊕AjA_i\oplus A_jAi⊕Aj和Bi⊕BjB_i\oplus B_jBi⊕Bj不同的最高位就是depdepdep,判断它们的大小关系只需要判断它们在depdepdep位上的大小关系,讨论一下算出贡献,再分治下去,分治的每一层算贡献的复杂度为O(nlogn)O(n\log n)O(nlogn),分治深度为O(logn)O(\log n)O(logn),所以总时间复杂度为O(nlog2n)O(n\log^2n)O(nlog2n)
code
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
void read(int &res)
{
res=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while('0'<=ch&&ch<='9') res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
}
const int N=3e5+100,Bl=19;
int n,a[N+10],b[N+10],p[N+10],p0[N+10],p1[N+10],c[2][Bl+10],A[2][2][Bl+10],B[2][2][Bl+10];
ll solve(int l,int r,int dep=Bl)
{
if(l>r) return 0;
if(dep<0)
{
ll res=0;
for(int i=0;i<=Bl;i++) c[0][i]=c[1][i]=0;
for(int i=l;i<=r;i++)
{
for(int j=0,k;j<=Bl;j++) k=((a[p[i]]&(1<<j))!=0),res+=1ll*c[k^1][j]*(1<<j);
for(int j=0,k;j<=Bl;j++) k=((a[p[i]]&(1<<j))!=0),c[k][j]++;
}
return res;
}
ll res=0;p0[0]=p1[0]=0;
for(int i=l;i<=r;i++)
{
if((a[p[i]]^b[p[i]])&(1<<dep)) p0[++p0[0]]=p[i];
else p1[++p1[0]]=p[i];
}
for(int j=0;j<=Bl;j++) for(int i=0;i<=1;i++) for(int k=0;k<=1;k++) A[i][k][j]=B[i][k][j]=0;
for(int i=1,k1;i<=p0[0];i++)
{
k1=((a[p0[i]]&(1<<dep))!=0);
for(int j=0,k;j<=Bl;j++)
{
k=((a[p0[i]]&(1<<j))!=0),A[k1][k][j]++,
k=((b[p0[i]]&(1<<j))!=0),B[k1][k][j]++;
}
}
for(int i=1,k1;i<=p1[0];i++)
{
k1=((a[p1[i]]&(1<<dep))!=0);
for(int j=0,k;j<=Bl;j++)
{
k=((a[p1[i]]&(1<<j))!=0);
res+=1ll*A[k1][k^1][j]*(1<<j);
k=((b[p1[i]]&(1<<j))!=0);
res+=1ll*B[k1^1][k^1][j]*(1<<j);
}
}
for(int i=1;i<=p0[0];i++) p[l+i-1]=p0[i];
for(int i=1;i<=p1[0];i++) p[l+p0[0]-1+i]=p1[i];
int mid=l+p0[0]-1;
res+=solve(l,mid,dep-1)+solve(mid+1,r,dep-1);
return res;
}
int main()
{
read(n);
for(int i=1;i<=n;i++) p[i]=i;
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1;i<=n;i++) read(b[i]);
printf("%lld",solve(1,n,Bl));
return 0;
}