题意:给定a,b两个数组,求最优的k使得sigma((a[i]-b[(i+k)%n])^2){0<=i<n}最小。
分析:我们将表达式拆开会发现我们只要求-2* a[i]*b[(i+k)%n]最小。这是一个循环的乘积和,我们将所有的情况写出来会是一个n*n的矩阵,其中第k行是a[0]*b[k-1]+a[1]*b[k]...a[n-k]*b[0]+..+a[n-1]*b[k-2]。如果我们将b数组逆序一下就会发现这一行是(a[0]*b[n-k]+a[1]*b[n-k-1]+a[2]*b[n-k-2]+..+a[n-k]*b[0])+(a[n-k+1]*b[n-1]+a[n-k+2]*b[n-2]+..+a[n-1]*[n-k+1])其实这就是FFT答案中的C[n-k]+C[2*n-k],那么其实答案就是max(C[i]+C[i+n])啦。但是这题裸fft会有精度问题因为最大的项可以达到10^16。然后我就去淘了下鸟神的ntt板子,然后换了下费马素数和原根就好了,ntt的素数可以点这里。
补充:过题的方式多种多样,多积累。本来直接用fft是会有精度误差,但是如果我们不直接用fft确定答案而是先用有误差的fft求出最优答案的那个k在哪,那么接下来我们O(n)再求一次精确解就行啦。要构出卡那点误差的数据也很难。详见代码1。
代码0:
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
typedef long long LL ;
#define clr(a,x) memset ( a , x , sizeof a )
const int MAXN = 300010 ;
const int MAXM = 2000005 ;
const LL mod = ( 1LL << 55 ) * 5 + 1 ;
const LL g = 6 ;
int nn;
LL x1[MAXN] , x2[MAXN] ;
LL mul ( LL x , LL y ) {
return ( x * y - ( long long ) ( x / ( long double ) mod * y + 1e-3 ) * mod + mod ) % mod ;
}
LL power ( LL a , LL b ) {
LL res = 1 , tmp = a ;
while ( b ) {
if ( b & 1 ) res = mul ( res , tmp ) ;
tmp = mul ( tmp , tmp ) ;
b >>= 1 ;
}
return res ;
}
void DFT ( LL y[] , int n , bool rev ) {
for ( int i = 1 , j , t , k ; i < n ; ++ i ) {
for ( k = n >> 1 , t = i , j = 0 ; k ; k >>= 1 , t >>= 1 ) {
j = j << 1 | t & 1 ;
}
if ( i < j ) swap ( y[i] , y[j] ) ;
}
for ( int s = 2 , ds = 1 ; s <= n ; ds = s , s <<= 1 ) {
LL wn = power ( g , ( mod - 1 ) / s ) ;
if ( !rev ) wn = power ( wn , mod - 2 ) ;
for ( int k = 0 ; k < n ; k += s ) {
LL w = 1 , t ;
for ( int i = k ; i < k + ds ; ++ i , w = mul ( w , wn ) ) {
y[i + ds] = ( y[i] - ( t = mul ( y[i + ds] , w ) ) + mod ) % mod ;
y[i] = ( y[i] + t ) % mod ;
}
}
}
}
LL FFT ( LL x1[] , LL x2[] , int n ) {
DFT ( x1 , n , 1 ) ;
DFT ( x2 , n , 1 ) ;
for ( int i = 0 ; i < n ; ++ i ) x1[i] = mul ( x1[i] , x2[i] ) ;
DFT ( x1 , n , 0 ) ;
LL vn = power ( n , mod - 2 ) ;
for ( int i = 0 ; i < n ; ++ i ) x1[i] = mul ( x1[i] , vn ) ;
LL ret=0;x1[2*nn-1]=0ll;
for ( int i = 0; i < nn; i++ ) ret=max( ret , x1[i] + x1[i+nn] );
return ret;
}
void solve () {
int i,n,len=1;LL ans=0;
scanf("%d", &n);nn=n;
for (i=n-1;i>=0;i--) scanf("%lld", &x1[i]),ans+=x1[i]*x1[i];
for (i=0;i<n;i++) scanf("%lld", &x2[i]),ans+=x2[i]*x2[i];
while (len<2*n) len<<=1;
for (i=n;i<len;i++) x1[i]=x2[i]=0ll;
printf("%lld\n", ans-2*FFT(x1,x2,len));
}
int main()
{
int T;
scanf("%d", &T);
while (T--) solve();
return 0;
}
代码1:
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<vector>
#include<string>
#include<stdio.h>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int N=60010;
const int M=50010;
const int mod=1000000007;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
typedef long long ll;
const ll MOD=1000000007;
const int INF=1000000010;
const ll MAX=1ll<<55;
const double eps=1e-5;
const double inf=~0u>>1;
const double pi=acos(-1.0);
typedef double db;
typedef unsigned int uint;
typedef unsigned long long ull;
ll x[N],y[N];
struct Complex{
db r,i;
Complex() {}
Complex(db r,db i):r(r),i(i) {}
Complex operator + (const Complex &t) const {
return Complex(r+t.r,i+t.i);
}
Complex operator - (const Complex &t) const {
return Complex(r-t.r,i-t.i);
}
Complex operator * (const Complex &t) const {
return Complex(r*t.r-i*t.i,r*t.i+i*t.r);
}
}a[3*N],b[3*N];
void FFT(Complex x[],int n,int rev) {
int i,j,k,t,ds;
Complex w,u,wn;
for (i=1;i<n;i++) {
for (j=0,t=i,k=n>>1;k;k>>=1,t>>=1) j=j<<1|t&1;
if (i<j) swap(x[i],x[j]);
}
for (i=2,ds=1;i<=n;ds=i,i<<=1) {
w=Complex(1,0);wn=Complex(cos(rev*2*pi/i),sin(rev*2*pi/i));
for (j=0;j<ds;j++,w=w*wn)
for (k=j;k<n;k+=i) {
u=w*x[k+ds];x[k+ds]=x[k]-u;x[k]=x[k]+u;
}
}
if (rev==-1) for (i=0;i<n;i++) x[i].r/=n;
}
int get_k(int n) {
int i,k=0,len=1;ll mx=0;
while (len<2*n) len<<=1;
for (i=0;i<n;i++) a[i]=Complex(x[i],0),b[i]=Complex(y[i],0);
for (i=n;i<len;i++) a[i]=Complex(0,0),b[i]=Complex(0,0);
FFT(a,len,1);FFT(b,len,1);
for (i=0;i<len;i++) a[i]=a[i]*b[i];
FFT(a,len,-1);
for (i=0;i<n;i++)
if (a[i].r+a[i+n].r>mx) {
mx=a[i].r+a[i+n].r;k=i;
}
return n-k-1;
}
void solve() {
int i,k,n,len=1;
ll ans=0;
scanf("%d", &n);
for (i=0;i<n;i++) scanf("%lld", &x[i]),ans+=x[i]*x[i];
for (i=n-1;i>=0;i--) scanf("%lld", &y[i]),ans+=y[i]*y[i];
k=get_k(n);
for (i=0;i<n;i++) ans-=2*x[i]*y[n-((i+k)%n)-1];
printf("%lld\n", ans);
}
int main()
{
int T;
scanf("%d", &T);
while (T--) solve();
return 0;
}