题意:给定三个长度 n 的数组 a , b , c ,求有多少个(i,j,k)满足 ai , bi , ci 三个数构成不严格三角形(可以是直线)
分析:FFT裸题(没学过的同志走这边:多项式与快速傅立叶变换_ZLH_HHHH的博客-优快云博客,我也才稍微整明白一些,这一篇蛮适合入门的)
正难则反,考虑先枚举出 A + B >= C 的数量再去掉不合法的,则不合法的情况有三种
①:A > C , B <= C;
②:A <= C , B > C;
③:A > C , B > C;
还有就是 2 个相同的 和 3 个相同的去重
A,B,C 的值域皆为 [1,m=100000] ,则枚举 A +B 的情况显然可以用 FFT 加速,复杂度为 O(m*logm) ,但考虑题目有 T <=100 组数据,FFT 复杂度拉满不太理想;不过题目提醒 n > 1000 的数据最多 20 组 ,显然,n > 1000 的时候用 FFT ,n <= 1000 的情况下则可以 O(n^2) 枚举其中两组求出第三个数组中合法的数字数量;
代码:(copy了一遍官方题解,蒟蒻一枚)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const double PI=acos(-1.0);
struct Complex
{
double r,i;
Complex(double _r=0,double _i=0)
{
r=_r,i=_i;
}
Complex operator + (const Complex &b)
{
return Complex(r+b.r,i+b.i);
}
Complex operator - (const Complex &b)
{
return Complex(r-b.r,i-b.i);
}
Complex operator * (const Complex &b)
{
return Complex(r*b.r-i*b.i,r*b.i+i*b.r);
}
};
const int N = 1e5+10;
struct node
{
int v,id;
}e[N*3];
bool cmp(node a,node b){return a.v<b.v;}
Complex x1[N*4];
Complex x2[N*4];
int a[3][N];
ll num[N*4];
ll sum[N*4];
void change(Complex y[],int len)
{
int i,j,k;
for(i=1,j=len/2;i<len-1;i++)
{
if(i<j) swap(y[i],y[j]);
k=len/2;
while(j>=k)
{
j-=k;
k/=2;
}
if(j<k) j+=k;
}
}
void fft(Complex y[],int len,int on)
{
change(y,len);
for(int h=2;h<=len;h<<=1)
{
Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h)
{
Complex w(1,0);
for(int k=j;k< j+h/2;k++)
{
Complex u=y[k];
Complex t=w*y[k+h/2];
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1)
for(int i=0;i<len;i++)
y[i].r/=len;
}
ll pre_fft(int id,int n)
{
int id1=(id+1)%3;
int id2=(id+2)%3;
int len1=0;
for(int i=0;i<n;i++)
{
num[a[id1][i]]++;
len1=max(len1,a[id1][i]+1);
len1=max(len1,a[id2][i]+1);
}
int len=1;
while(len<len1*2) len<<=1;
for(int i=0;i<len1;i++) x1[i]=Complex(num[i],0);
for(int i=len1;i<len;i++) x1[i]=Complex(0,0);
fft(x1,len,1);
for(int i=0;i<n;i++) num[a[id1][i]]--;
for(int i=0;i<n;i++) num[a[id2][i]]++;
for(int i=0;i<len1;i++) x2[i]=Complex(num[i],0);
for(int i=len1;i<len;i++) x2[i]=Complex(0,0);
fft(x2,len,1);
for(int i=0;i<len;i++) x1[i]=x1[i]*x2[i];
fft(x1,len,-1);
for(int i=0;i<len;i++) num[i]=(ll)(x1[i].r+0.5);
sum[0]=0;
for(int i=1;i<=2*len1;i++) sum[i]=sum[i-1]+num[i];
for(int i=0;i<len;i++) num[i]=0;
ll ans=0;
for(int i=0;i<n;i++)
{
ans+=sum[2*len1]-sum[a[id][i]-1];
}
return ans;
}
ll solve2(int n)
{
ll ans=0;
ans+=pre_fft(0,n);
ans+=pre_fft(1,n);
ans+=pre_fft(2,n);
for(int i=0;i<n;i++)
{
e[i].v=a[0][i];
e[i].id=0;
}
for(int i=n;i<2*n;i++)
{
e[i].v=a[1][i-n];
e[i].id=1;
}
for(int i=2*n;i<3*n;i++)
{
e[i].v=a[2][i-2*n];
e[i].id=2;
}
sort(e,e+3*n,cmp);
int cnt[3]={0};
for(int i=0;i<3*n;i++)
{
int id=e[i].id;
int id1=(id+1)%3;
int id2=(id+2)%3;
ans-=(ll)(n-cnt[id1])*cnt[id2];
ans-=(ll)cnt[id1]*(n-cnt[id2]);
ans-=(ll)(n-cnt[id1])*(n-cnt[id2]);
cnt[id]++;
}
return ans;
}
ll solve1(int n) //n<2000
{
sort(a[0],a[0]+n);
sort(a[1],a[1]+n);
sort(a[2],a[2]+n);
for(int i=0;i<n;i++)
{
e[i].v=a[0][i];
e[i].id=0;
}
for(int i=n;i<2*n;i++)
{
e[i].v=a[1][i-n];
e[i].id=1;
}
for(int i=2*n;i<3*n;i++)
{
e[i].v=a[2][i-2*n];
e[i].id=2;
}
sort(e,e+3*n,cmp);
int cnt[3]={0};
ll ans=0;
for(int i=0;i<3*n;i++)
{
/*
把所有数从小到大排完序后,假设当前数为 C 为长边,则短边 A,B 只会在前面
,选择从小到大枚举 A , 则合法 B 的数量会相应变多
*/
int id=e[i].id;
int id1=(id+1)%3;
int id2=(id+2)%3;
int tmp=cnt[id2];
for(int j=0;j<cnt[id1];j++)
{
while(tmp>0&&a[id1][j]+a[id2][tmp-1]>=e[i].v) tmp--;
ans+=cnt[id2]-tmp;
}
cnt[id]++;
}
return ans;
}
int main()
{
int T;
scanf("%d",&T);
int n;
for(int cas=1;cas<=T;cas++)
{
scanf("%d",&n);
for(int i=0;i<3;i++)
for(int j=0;j<n;j++)
scanf("%d",&a[i][j]);
ll ans;
if(n<=1000) ans=solve1(n);
else ans=solve2(n);
printf("Case #%d: %lld\n",cas,ans);
}
return 0;
}