之前我在洛谷做 P1919(快速高精度乘法),发现我的代码每个点要几百毫秒,而大佬们每个点几十毫秒就能轻松AC。
于是,我决定仔细研究一下 FFT 的常数优化
1. dif 和 dit 混合使用
FFT 有两种方法,第一种是时间抽取(DIT),需要在计算前进行下标二进制逆序交换,第二种是频率抽取(DIF),需要在计算后进行下标二进制逆序交换。于是我们先做 DIF 的正变换,再做 DIT 的逆变换,就省略了二进制逆序交换的步骤,如图。
其他部分与普通FFT差不多,只是混用了两种FFT函数。
void fft_dif(complex<double> y[], int len) {
for (int h=len; h>=2; h>>=1) {
complex<double> wn(cos(PI_2/h), sin(PI_2/h));
for (int j=0; j<len; j+=h) {
complex<double> w(1,0);
for (int k=j; k<j+(h>>1); ++k) {
complex<double> u=y[k], t=y[k+(h>>1)];
y[k]=u+t, y[k+(h>>1)]=w*(u-t), w*=wn;
}
}
}
}
void fft_dit(complex<double> y[], int len) {
for (int h=2; h<=len; h<<=1) {
complex<double> wn(cos(PI_2/h), sin(PI_2/h));
for (int j=0; j<len; j+=h) {
complex<double> w(1,0);
for (int k=j; k<j+(h>>1); ++k) {
complex<double> u=y[k], t=w*y[k+(h>>1)];
y[k]=u+t, y[k+(h>>1)]=u-t, w*=wn;
}
}
}
}
2. 三次变两次
上面需要两次正变换和一次逆变换。太慢了,我们可以将 a 放在实部,b 放在虚部,一起进行一次正变换。接着将点值平方,跑一遍逆变换,答案等于虚部的一半。
原理如下: ( a + b i ) 2 = a 2 − b 2 + 2 a b i (a+bi)^2 = a^2-b^2+2abi (a+bi)2=a2−b2+2abi,而我们要的是 a b ab ab,正好是虚部的一半。
3. 三次变 1.5 次
此方法和 (2) 只能选一个。
将多项式的系数,交替放在实部和虚部,进行相乘。这样 n 就变小了一倍,相当于三次变 1.5 次。实现很方便,将 double*
强制转成 complex<double>*
就可以了。
FFT中,实部和虚部的处理是相互独立的。这样操作,相当于对 a 0 , a 2 , a 4 . . . a_0,a_2,a_4 ... a0,a2,a4... 进行了一次变换,对 a 1 , a 3 , a 5 . . . a_1,a_3,a_5 ... a1,a3,a5... 进行了一次变换,最终简单合并一下就行了。
这是我从一位大佬那里看到的代码:
#include <bits/stdc++.h>
using f64 = double;
using cpx = std::complex<f64>;
using u64 = unsigned long long;
constexpr int N = 1 << 20;
constexpr int bceil(int x){
return (x == 1 || x == 0) ? x : (1 << (std::__lg(x - 1) + 1));
}
namespace f_f_t
{
const f64 Pi_2 = acos(-1.0) / 2;
cpx w[N >> 3];
int init_l;
void init(int l)
{
if (l <= init_l){
return;
}
int t = std::__lg(l - 1);
l = 1 << t, *w = cpx(1.0, 0.0), init_l = l << 1;
for (int i = 1; i < l; i <<= 1){
w[i] = std::polar(1.0, Pi_2 / i);
}
for (int i = 1; i < l; ++i){
w[i] = w[i & (i - 1)] * w[i & -i];
}
}
void dif(cpx *f, int L)
{
for (int l = L >> 1, r = L; l; l >>= 1, r >>= 1){
for (cpx *j = f, *o = w; j != f + L; j += r, ++o){
for (cpx *k = j; k != j + l; ++k){
cpx x = *k, y = k[l] * *o;
*k = x + y, k[l] = x - y;
}
}
}
}
void dit(cpx *f, int L)
{
for (int l = 1, r = 2; l < L; l <<= 1, r <<= 1){
for (cpx *j = f, *o = w; j != f + L; j += r, ++o){
for (cpx *k = j; k != j + l; ++k){
cpx x = *k, y = k[l];
*k = x + y, k[l] = (x - y) * std::conj(*o);
}
}
}
}
void Conv(f64 *f, int lim, f64 *g){
cpx *F = (cpx*)f, *G = (cpx*)g;
int l = lim >> 1;
init(l), dif(F, l), dif(G, l);
f64 fx = 2.0 / lim, fx2 = 0.5 / lim;
F[0] = (F[0] * G[0] + 2 * F[0].imag() * G[0].imag()) * fx, F[1] = F[1] * G[1] * fx;
for (int k = 2, m = 3; k < l; k <<= 1, m <<= 1){
for (int i = k, j = i + k - 1; i < m; ++i, --j){
cpx oi = (F[i] + std::conj(F[j])), hi = (F[i] - std::conj(F[j]));
cpx Oi = (G[i] + std::conj(G[j])), Hi = (G[i] - std::conj(G[j]));
cpx r0 = oi * Oi - hi * Hi * ((i & 1) ? -w[i >> 1] : w[i >> 1]), r1 = Oi * hi + oi * Hi;
F[i] = (r0 + r1) * fx2, F[j] = std::conj(r0 - r1) * fx2;
}
}
dit(F, l);
}
}
f64 F[N >> 1], G[N >> 1];
char buf[N << 1];
void solve(){
int tot = fread(buf, 1, sizeof(buf), stdin);
char *bga = buf, *eda = buf, *bgb, *edb = buf + tot;
while(!isdigit(*(edb - 1))){--edb;}
for(; (((*((u64*)eda)) + 0x5f5f5f5f5f5f5f5f) & 0x8080808080808080) == 0x8080808080808080; eda+=8){}
for(; *eda > 32; ){++eda;}
for(bgb = eda; !isdigit(*bgb); ++bgb){}
auto radX = [](char* bg, char* ed, f64* out){
char *edr = bg + 4, *pos = ed;
for(; pos > edr; pos -= 4){
*out++ = *(pos - 4) * 1000 + *(pos - 3) * 100 + *(pos - 2) * 10 + *(pos - 1) - 53328;
}
return *out++ = std::stod(std::string(bg, pos)), out;
};
int n = radX(bga, eda, F) - F, m = radX(bgb, edb, G) - G, lim = std::max(8, bceil(n + m - 1));
std::fill(F + n, F + lim, 0.0), std::fill(G + m, G + lim, 0.0), f_f_t::Conv(F, lim, G);
{
struct ict{
int num[10000];//小端
ict(){
int j = 0;
for(int e0 = (48 << 0); e0 < (58 << 0); e0 += (1 << 0)){
for(int e1 = (48 << 8); e1 < (58 << 8); e1 += (1 << 8)){
for(int e2 = (48 << 16); e2 < (58 << 16); e2 += (1 << 16)){
for(int e3 = (48 << 24); e3 < (58 << 24); e3 += (1 << 24)){
num[j] = e0 ^ e1 ^ e2 ^ e3, ++j;
}
}
}
}
}
}ot;
int o = (n + m - 2), *ed = (int*)buf + o, *c = ed;
u64 u = 0;
for(int p = 0; p < o; ++p){
u += u64(F[p] + 0.5), *--c = ot.num[u % 10000u], u /= 10000u;
}
fprintf(stdout, "%llu", u + u64(F[o] + 0.5)), fwrite(c, sizeof(int), ed - c, stdout);
}
}
int main(){
std::cin.tie(nullptr) -> sync_with_stdio(false);
solve();
return 0;
}