合影
【题目名称】有 n 个人,第 i 个人的外貌可以用权值 ai 表示。两个人可以照一张合影,i 与 j 合影的美观度是 ai^aj。同一对人只能照一次,要照 m 张,求最大美观度和。
【输入格式】
第一行两个数 n、m。接下来一行 n 个数,第 i 个数表示 ai。
【输出格式】
一行一个数,和的最大值。
【输入样例】
3 2
6 7 8
【输出样例】
29
【数据范围】
30%的数据,1≤n、m≤1000
60%的数据,1≤n、m≤50000
100%的数据,1≤n≤50000,1≤ai≤10^9
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define N 800003
#define LL long long
using namespace std;
int n,m,a[N],ch[N][3],size[N],sum[N][41],tot;
int num[50003][100],num1[100],isend[N],maxn;
int deep[N],ki[N],q[N];
LL ok,mark,base[100];
bool pd(int x)
{
int t=0; int xx=x;
memset(num1,0,sizeof(num1));
while (xx)
{
t++;
if (xx&1) num1[t]=1;
else num1[t]=0;
xx>>=1;
}
int ans[100];
memset(ans,0,sizeof(ans));
int tot=0;
for (int i=1;i<=n;i++)
{
int now=0;
for (int j=32;j>=1;j--)
{
int xx;
if (num1[j]) xx=num[i][j]^1;
else
{
xx=num[i][j];
if (ch[now][xx^1])
tot+=size[ch[now][xx^1]];
}
if (!ch[now][xx]) break;
now=ch[now][xx];
}
if (isend[now]) tot+=size[now];
}
if ((tot/2)>=m) {
mark=max(mark,(LL)x);
return true;
}
else return false;
}
LL calc(int x)
{
int t=0; int xx=x;
memset(num1,0,sizeof(num1));
while (xx)
{
t++;
if (xx&1) num1[t]=1;
else num1[t]=0;
xx>>=1;
}
int ans[100];
memset(ans,0,sizeof(ans));
int tot=0;
for (int i=1;i<=n;i++)
{
int now=0;
for (int j=32;j>=1;j--)
{
int xx;
if (num1[j]) xx=num[i][j]^1;
else
{
xx=num[i][j];
if (ch[now][xx^1])
{
int t=size[ch[now][xx^1]];
tot+=t;
for (int k=1;k<=32;k++)
if (num[i][k]^1)
ans[k]+=sum[ch[now][xx^1]][k];
else
ans[k]+=(t-sum[ch[now][xx^1]][k]);
}
}
if (!ch[now][xx]) break;
now=ch[now][xx];
}
if (isend[now])
{
for (int k=1;k<=32;k++)
if (num[i][k]^1)
ans[k]+=sum[now][k];
else
ans[k]+=(size[now]-sum[now][k]);
tot+=size[now];
}
}
LL s=0;
for (int i=1;i<=32;i++)
{
LL xx=(LL)ans[i]*(LL)base[i-1];
s+=(LL)xx;
}
s-=(LL)(tot-m*2)*(LL)x;
return s/2;
}
int main()
{
freopen("photo.in","r",stdin);
freopen("photo.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&a[i]);
base[0]=1;
for (int i=1;i<=32;i++) base[i]=base[i-1]*2;
for (int i=1;i<=n;i++)
{
int x=a[i]; int t=0;
while (x)
{
t++; maxn=max(maxn,t);
if (x&1) num[i][t]=1;
else num[i][t]=0;
x>>=1;
}
int now=0; deep[now]=0; int cnt=0;
for (int j=32;j>=1;j--)
{
int x=num[i][j];
if (!ch[now][x]) ch[now][x]=++tot;
size[now]++;
q[++cnt]=now;
now=ch[now][x];
}
q[++cnt]=now; size[now]++;
for (int j=1;j<=cnt;j++)
{
for (int k=1;k<=32;k++)
if (num[i][k])
sum[q[j]][k]++;
}
isend[now]=1;
}
int l=0,r=base[maxn];
while (l<=r)
{
int mid=(l+r)/2;
if (pd(mid)) l=mid+1;
else r=mid-1;
}
//cout<<mark<<endl;
printf("%I64d\n",calc(mark));
}