这道题是叉姐讲义的第一道题,题意要求对于所有S,三个编号互不相同的数之和等于S的方案数。
如果是没有互不相同这个要求,那么题目就非常简单,直接用FFT求出x的s次方的系数即可。
但是如果要求互不相同,那么就考虑需要用到容斥定理。
具体式子推导不太好写,讲义里写的很清楚明白~
因此我们需要事先统计出两个数和三个数相同能取的所有方案数,显然和为s + s的方案数等于大小为s的数的个数,三个数同理。
那么接下来就可以通过FFT来求出答案的个数。
下面是代码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 320005;
const double pi = acos(-1.0);
struct Complex{
double r,i;
Complex(){}
Complex(double r,double i):r(r),i(i){}
Complex operator + (const Complex &a)const{
return Complex(r + a.r,i + a.i);
}
Complex operator - (const Complex &a)const{
return Complex(r - a.r,i - a.i);
}
Complex operator * (const Complex &a)const{
return Complex(r * a.r - i * a.i,r * a.i + i * a.r);
}
Complex operator * (const double &a)const{
return Complex(r * a,i * a);
}
Complex operator / (const double &n)const{
return Complex(r / n,i / n);
}
void init(){r = i = 0;}
};
int cnt[N],tn,val[N];
Complex x1[N],x2[N],x3[N],tmp[N];
Complex r1[N],r2[N];
int rev(int x){
int res = 0;
for(int i = 0 ; i < tn ; i ++){
if(x & 1) res += 1 << tn - i - 1;
x >>= 1;
}
return res;
}
void fft(Complex A[],int n,int op){
for(int i = 0 ; i < n ; i ++) tmp[ rev(i) ] = A[i];
for(int i = 0 ; i < n ; i ++) A[i] = tmp[i];
for(int i = 1 ; (1 << i) <= n ; i ++){
int m = 1 << i;
for(int k = 0 ; k < n ; k += m){
Complex wn(cos(op * 2 * pi / m),sin(op * 2 * pi / m));
Complex w(1,0),u,t;
for(int j = 0 ; j < m / 2 ; j ++){
u = A[k + j];
t = w * A[k + j + m / 2];
A[k + j] = u + t;
A[k + j + m / 2] = u - t;
w = w * wn;
}
}
}
if(op == -1) for(int i = 0 ; i < n ; i ++) A[i] = A[i] / n;
}
void solve(int n){
int Min,Max,len;
for(int i = 1 ; i <= n ; i ++) scanf("%d",&val[i]);
sort(val + 1,val + n + 1);
Min = val[1];
Max = val[n] * 3;
memset(cnt,0,sizeof(cnt));
for(int i = 1 ; i <= n ; i ++) val[i] -= Min;
for(int i = 1 ; i <= n ; i ++) cnt[ val[i] ] ++;
tn = ceil(log(Max + 0.0) / log(2.0)) + 1;
len = 1 << tn;
for(int i = 0 ; i < len ; i ++){
x1[i] = Complex(cnt[i],0);
if(i % 2 == 0) x2[i] = Complex(cnt[i / 2],0);
else x2[i].init();
if(i % 3 == 0) x3[i] = Complex(cnt[i / 3],0);
else x3[i].init();
}
fft(x1,len,1);
fft(x2,len,1);
for(int i = 0 ; i < len ; i ++){
r1[i] = x1[i] * x1[i] * x1[i] - x1[i] * x2[i] * 3.0;
}
fft(r1,len,-1);
for(int i = 0 ; i < len ; i ++){
LL res = (LL(r1[i].r + 0.5) + 2 * LL(x3[i].r)) / 6;
if(res > 0)
printf("%d : %lld\n",i + Min * 3,res);
}
return;
}
int main()
{
int n;
while(scanf("%d",&n) != EOF) solve(n);
return 0;
}