链接:H Hash Function
题意:
给定一个长度为 n ( n ≤ 500000 n\leq 500000 n≤500000) 的数组 a,数组中元素互不相同 , 现要求出一个最小的 x , 使得所有的数模 x后仍互不相同。
思路:
- 根据同余定理,这个数 x 肯定不能是数组a中任意两个数的差值和他们的因子。所以我们要求出数组中所有数两两之间的差值,如果直接求肯定要o( n 2 n^2 n2),可以用 FFT 加速这个过程。
- FFT本身是求出两个多项式相乘后的系数,那么怎么用它来求数组a中两两的差有哪些呢。假设我们现在有两个多项式 s1 , s2。如果数组 a 中存在数字 5 和 3, 我们让多项式 s1中 x 5 x^5 x5的系数为1 ,s2中 x − 3 x^{-3} x−3的系数为1 ,两个多项式相乘后,我们会发现 x 2 x^2 x2的系数会为1,幂次 2 就是我们要求的差值。所以最后我们只需要知道哪些幂次出现过,就代表差值存在了。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 4e6 + 5;
const int num = 500002;
const ll inf = (1ll << 60);
const ll mod = 998244353;
const double Pi = acos(-1.0);
int n, m, r[maxn];
int vis[maxn];
struct Complex{
double x, y;
Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
Complex operator+(Complex b){
return Complex(x + b.x, y + b.y);
}
Complex operator-(Complex b){
return Complex(x - b.x, y - b.y);
}
Complex operator*(Complex b){
return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
}
};
// 最高n,m次, 0次为常数
Complex a[maxn], b[maxn];
Complex c[maxn];
void fft(Complex *cm, int cnum, int tag){
for(int i = 0 ; i <= cnum - 1; i ++){
if (i < r[i]){
swap(cm[i], cm[r[i]]);
}
}
for (ll mid = 1; mid < cnum; mid <<= 1){
Complex wk = Complex(cos(2 * Pi / (2 * mid)), tag * sin(2 * Pi / (2 * mid)));
for (ll j = 0; j < cnum; j += 2 * mid) //枚举 cnum/2*mid个全长段
{
Complex w(1, 0);
for (ll k = 0; k < mid; k++) //每段里面进行fft,不是<=,因为只有一半,不能超出
{
Complex buf = w * cm[j + k + mid];
cm[j + k + mid] = cm[j + k] - buf; //在这一步cm[j + k] = cm[j + k] + buf上,否则cm[j+k]已被更改
cm[j + k] = cm[j + k] + buf;
w = w * wk;
}
}
}
}
int fft_init(int n){ //n是最高次幂
int maxx = 1, bits = 0;
while ((maxx) <= n) {
maxx <<= 1;
bits++;
}
for(int i = 0; i <= maxx - 1; i ++){
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1)); //求二进制反转结果
}
return maxx;
}
int main(){
scanf("%d",&n);
for(int i = 1 , d; i <= n; i ++){
scanf("%d",&d);
a[num + d].x ++;
b[num - d].x ++;
}
int maxx = fft_init(4 * num);
fft(a, maxx, 1);
fft(b, maxx, 1);
for(int i = 0; i <= maxx; i ++){
c[i] = a[i] * b[i];
}
fft(c, maxx, -1);
for(int i = 1; i <= num; i ++){
vis[i] = (c[i + 2 * num].x / maxx + 0.5);
// printf ("%d %d\n",i , vis[i]);
}
for(int i = 1; i <= num; i ++){
if(vis[i]) continue;
for(int j = i; j <= num; j += i){
if(vis[j]){
vis[i] = 1;
}
}
}
for(int i = n; i <= num; i ++){
if(!vis[i]){
printf ("%d\n",i);
return 0;
}
}
return 0;
}