FFT 的几个常数优化

之前我在洛谷做 P1919(快速高精度乘法),发现我的代码每个点要几百毫秒,而大佬们每个点几十毫秒就能轻松AC。

于是,我决定仔细研究一下 FFT 的常数优化

1. dif 和 dit 混合使用

FFT 有两种方法,第一种是时间抽取(DIT),需要在计算前进行下标二进制逆序交换,第二种是频率抽取(DIF),需要在计算后进行下标二进制逆序交换。于是我们先做 DIF 的正变换,再做 DIT 的逆变换,就省略了二进制逆序交换的步骤,如图。
dif 和 dit 混合使用

其他部分与普通FFT差不多,只是混用了两种FFT函数。

void fft_dif(complex<double> y[], int len) {
	for (int h=len; h>=2; h>>=1) {
		complex<double> wn(cos(PI_2/h), sin(PI_2/h));
		for (int j=0; j<len; j+=h) {
			complex<double> w(1,0);
			for (int k=j; k<j+(h>>1); ++k) {
				complex<double> u=y[k], t=y[k+(h>>1)];
				y[k]=u+t, y[k+(h>>1)]=w*(u-t), w*=wn;
			}
		}
	}
}
void fft_dit(complex<double> y[], int len) {
	for (int h=2; h<=len; h<<=1) {
		complex<double> wn(cos(PI_2/h), sin(PI_2/h));
		for (int j=0; j<len; j+=h) {
			complex<double> w(1,0);
			for (int k=j; k<j+(h>>1); ++k) {
				complex<double> u=y[k], t=w*y[k+(h>>1)];
				y[k]=u+t, y[k+(h>>1)]=u-t, w*=wn;
			}
		}
	}
}

2. 三次变两次

上面需要两次正变换和一次逆变换。太慢了,我们可以将 a 放在实部,b 放在虚部,一起进行一次正变换。接着将点值平方,跑一遍逆变换,答案等于虚部的一半。

原理如下: ( a + b i ) 2 = a 2 − b 2 + 2 a b i (a+bi)^2 = a^2-b^2+2abi (a+bi)2=a2b2+2abi,而我们要的是 a b ab ab,正好是虚部的一半。

3. 三次变 1.5 次

此方法和 (2) 只能选一个。

将多项式的系数,交替放在实部和虚部,进行相乘。这样 n 就变小了一倍,相当于三次变 1.5 次。实现很方便,将 double* 强制转成 complex<double>* 就可以了。

FFT中,实部和虚部的处理是相互独立的。这样操作,相当于对 a 0 , a 2 , a 4 . . . a_0,a_2,a_4 ... a0,a2,a4... 进行了一次变换,对 a 1 , a 3 , a 5 . . . a_1,a_3,a_5 ... a1,a3,a5... 进行了一次变换,最终简单合并一下就行了。

这是我从一位大佬那里看到的代码:

#include <bits/stdc++.h>

using f64 = double;
using cpx = std::complex<f64>;
using u64 = unsigned long long;

constexpr int N = 1 << 20;
constexpr int bceil(int x){
    return (x == 1 || x == 0) ? x : (1 << (std::__lg(x - 1) + 1)); 
}

namespace f_f_t
{
    const f64 Pi_2 = acos(-1.0) / 2;
    cpx w[N >> 3];
    int init_l;
    void init(int l)
    {
        if (l <= init_l){
            return;
        }
        int t = std::__lg(l - 1);
        l = 1 << t, *w = cpx(1.0, 0.0), init_l = l << 1;
        for (int i = 1; i < l; i <<= 1){
            w[i] = std::polar(1.0, Pi_2 / i);
        }
        for (int i = 1; i < l; ++i){
            w[i] = w[i & (i - 1)] * w[i & -i];
        }
    }
    void dif(cpx *f, int L)
    {
        for (int l = L >> 1, r = L; l; l >>= 1, r >>= 1){
            for (cpx *j = f, *o = w; j != f + L; j += r, ++o){
                for (cpx *k = j; k != j + l; ++k){
                    cpx x = *k, y = k[l] * *o;
                    *k = x + y, k[l] = x - y;
                }
            }
        }
    }
    void dit(cpx *f, int L)
    {
        for (int l = 1, r = 2; l < L; l <<= 1, r <<= 1){
            for (cpx *j = f, *o = w; j != f + L; j += r, ++o){
                for (cpx *k = j; k != j + l; ++k){
                    cpx x = *k, y = k[l];
                    *k = x + y, k[l] = (x - y) * std::conj(*o);
                }
            }
        }
    }
    void Conv(f64 *f, int lim, f64 *g){
        cpx *F = (cpx*)f, *G = (cpx*)g;
        int l = lim >> 1;
        init(l), dif(F, l), dif(G, l);
        f64 fx = 2.0 / lim, fx2 = 0.5 / lim;
        F[0] = (F[0] * G[0] + 2 * F[0].imag() * G[0].imag()) * fx, F[1] = F[1] * G[1] * fx;
        for (int k = 2, m = 3; k < l; k <<= 1, m <<= 1){
            for (int i = k, j = i + k - 1; i < m; ++i, --j){
                cpx oi = (F[i] + std::conj(F[j])), hi = (F[i] - std::conj(F[j]));
                cpx Oi = (G[i] + std::conj(G[j])), Hi = (G[i] - std::conj(G[j]));
                cpx r0 = oi * Oi - hi * Hi * ((i & 1) ? -w[i >> 1] : w[i >> 1]), r1 = Oi * hi + oi * Hi;
                F[i] = (r0 + r1) * fx2, F[j] = std::conj(r0 - r1) * fx2;
            }
        }
        dit(F, l);
    }
}

f64 F[N >> 1], G[N >> 1];

char buf[N << 1];

void solve(){
    int tot = fread(buf, 1, sizeof(buf), stdin);
    char *bga = buf, *eda = buf, *bgb, *edb = buf + tot;
    while(!isdigit(*(edb - 1))){--edb;}
    for(; (((*((u64*)eda)) + 0x5f5f5f5f5f5f5f5f) & 0x8080808080808080) == 0x8080808080808080; eda+=8){}
    for(; *eda > 32; ){++eda;}
    for(bgb = eda; !isdigit(*bgb); ++bgb){}
    auto radX = [](char* bg, char* ed, f64* out){
        char *edr = bg + 4, *pos = ed;
        for(; pos > edr; pos -= 4){
            *out++ = *(pos - 4) * 1000 + *(pos - 3) * 100 + *(pos - 2) * 10 + *(pos - 1) - 53328;
        }
        return *out++ = std::stod(std::string(bg, pos)), out;
    };
    int n = radX(bga, eda, F) - F, m = radX(bgb, edb, G) - G, lim = std::max(8, bceil(n + m - 1));
    std::fill(F + n, F + lim, 0.0), std::fill(G + m, G + lim, 0.0), f_f_t::Conv(F, lim, G);
    {
        struct ict{
            int num[10000];//小端
            ict(){
                int j = 0;
                for(int e0 = (48 << 0); e0 < (58 << 0); e0 += (1 << 0)){
                    for(int e1 = (48 << 8); e1 < (58 << 8); e1 += (1 << 8)){
                        for(int e2 = (48 << 16); e2 < (58 << 16); e2 += (1 << 16)){
                            for(int e3 = (48 << 24); e3 < (58 << 24); e3 += (1 << 24)){
                                num[j] = e0 ^ e1 ^ e2 ^ e3, ++j;
                            }
                        }
                    }
                }
            }
        }ot;
        int o = (n + m - 2), *ed = (int*)buf + o, *c = ed;
        u64 u = 0;
        for(int p = 0; p < o; ++p){
            u += u64(F[p] + 0.5), *--c = ot.num[u % 10000u], u /= 10000u;
        }
        fprintf(stdout, "%llu", u + u64(F[o] + 0.5)), fwrite(c, sizeof(int), ed - c, stdout);
    }
}

int main(){
    std::cin.tie(nullptr) -> sync_with_stdio(false);
    solve();
    return 0;
}
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值