题目大意
求有多少对i<j<k
满足a[j]-a[i]=a[k]-a[j]
分块FFT
口胡里写过了。
枚举j,我们可以得到两边的生成函数。
只要卷积起来看2*a[j]项的系数就可以统计了。
这样做显然不行。
考虑分块。
对于i或k在块内的情况,用枚举来暴力统计。
而对于i与k均不在块内的情况,用FFT。
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
typedef double db;
const int maxn=120000+10,B=2000;
const db pi=acos(-1);
struct node{
db x,y;
friend node operator +(node a,node b){
node c;
c.x=a.x+b.x;c.y=a.y+b.y;
return c;
}
friend node operator -(node a,node b){
node c;
c.x=a.x-b.x;c.y=a.y-b.y;
return c;
}
friend node operator *(node a,node b){
node c;
c.x=a.x*b.x-a.y*b.y;c.y=a.x*b.y+a.y*b.x;
return c;
}
};
node tt[maxn],d[maxn],e[maxn],a[maxn],b[maxn],c[maxn],w[maxn];
int v[maxn],rev[maxn];
int i,j,k,l,r,t,n,m,wdc,mx,tot,len;
ll ans;
db ce;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void prepare(){
len=1;
while (len<mx*2) len*=2;
ce=log(len)/log(2);
fo(i,0,len-1){
int p=0;
for (int j=0,tp=i;j<ce;j++,tp/=2) p=(p<<1)+(tp%2);
rev[i]=p;
}
w[0].x=1;w[0].y=0;
w[1].x=cos(2*pi/len);w[1].y=sin(2*pi/len);
fo(i,2,len) w[i]=w[i-1]*w[1];
}
void DFT(node *a,int sig){
int i;
fo(i,0,len-1) tt[rev[i]]=a[i];
for (int m=2;m<=len;m*=2){
int half=m/2,bei=len/m;
fo(i,0,half-1){
node wi=sig>0?w[i*bei]:w[len-i*bei];
for (int j=i;j<len;j+=m){
node u=tt[j],v=tt[j+half]*wi;
tt[j]=u+v;
tt[j+half]=u-v;
}
}
}
if (sig==-1)
fo(i,0,len-1) tt[i].x/=len;
fo(i,0,len-1) a[i]=tt[i];
}
void FFT(node *a,node *b){
int i;
fo(i,0,len-1) c[i]=a[i],e[i]=b[i];
DFT(c,1);DFT(e,1);
fo(i,0,len-1) c[i]=c[i]*e[i];
DFT(c,-1);
}
int main(){
freopen("data.in","r",stdin);
n=read();
fo(i,1,n) v[i]=read(),b[v[i]].x++,mx=max(mx,v[i]);
prepare();
fo(wdc,1,(n-1)/B+1){
l=(wdc-1)*B+1;r=min(wdc*B,n);
fo(i,l,r) b[v[i]].x--;
FFT(a,b);
fo(i,l,r)
ans+=(ll)round(c[2*v[i]].x);
fo(i,l,r) a[v[i]].x++;
}
fo(i,1,n) a[v[i]].x=b[v[i]].x=0;
fo(i,1,n) b[v[i]].x++;
fo(wdc,1,(n-1)/B+1){
l=(wdc-1)*B+1;r=min(wdc*B,n);
fo(i,l,r) b[v[i]].x--;
fo(j,l,r)
fo(i,l,j-1)
if (2*v[j]-v[i]>=0) ans+=(ll)round(b[2*v[j]-v[i]].x);
fo(j,l,r){
if (j>l) a[v[j-1]].x++;
fo(k,j+1,r)
if (2*v[j]-v[k]>=0) ans+=(ll)round(a[2*v[j]-v[k]].x);
}
a[v[r]].x++;
}
printf("%lld\n",ans);
}


被折叠的 条评论
为什么被折叠?



