数学板块学习之任意模数NTT

本文详细介绍了如何使用任意模数的NTT算法,通过三个特定模数进行快速傅里叶变换,并利用中国剩余定理进行结果合并,提供了一个完整的代码实现模板。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

任意模数NTT

对于模数不是NTT用的模数,可以用多个NTT模数做NTT然后用CRT(中国剩余定理合并)
一般我们用的三个NTT模数, 998244353 = 2 23 ∗ 119 + 1 , 1004535809 = 2 21 ∗ 479 + 1 , 469762049 = 2 26 ∗ 7 + 1 998244353=2^{23}*119+1,1004535809=2^{21}*479+1,469762049=2^{26}*7+1 998244353=223119+1,1004535809=221479+1,469762049=2267+1,这三个模数的原根都是3
但是三个模数的乘积过大,我们用CRT合并需要注意
合并过程:
我们得到的三个方程
x = a 1   ( m o d   m 1 ) x = a 2   ( m o d   m 2 ) x = a 3   ( m o d   m 3 ) \begin{aligned} x&=a_1\ (mod\ m_1)\\ x&=a_2\ (mod\ m_2)\\ x&=a_3\ (mod\ m_3) \end{aligned} xxx=a1 (mod m1)=a2 (mod m2)=a3 (mod m3)
先用CRT合并前两个为 x = A   ( m o d   M ) x=A\ (mod\ M) x=A (mod M)其中 M = m 1 m 2 M=m_1m_2 Mm1m2
设答案 a n s = k M + A ans=kM+A ans=kM+A
a n s ans ans满足第三个方程 a n s = k M + A = a 3   ( m o d   m 3 ) ans=kM+A=a_3\ (mod\ m_3) ans=kM+A=a3 (mod m3)
k = ( a 3 − A ) M − 1   ( m o d   m 3 ) k=(a3-A)M^{-1}\ (mod\ m_3) k=(a3A)M1 (mod m3)
已知 k k k代回原式即可求出 a n s ans ans
因为 k k k是在 ( m o d   m 3 ) (mod\ m_3) (mod m3)意义下求出的,我们只需要验证一下
假设 k = t m 3 + b k=tm3+b k=tm3+b
所以 a n s = ( t m 3 + b ) M + A = t m 3 M + b M + A ans=(tm_3+b)M+A=tm_3M+bM+A ans=(tm3+b)M+A=tm3M+bM+A,很明显是满足的

给个模板吧,模板来源,实在不想自己写了。。。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <cmath>
#include <string>
#include <cstring>
#include <map>
#include <set>
#include <math.h>
#include <unordered_map>
//#include <tr1/unordered_map>

using namespace std;
#define me(x,y) memset(x,y,sizeof x)
#define MIN(x,y) x < y ? x : y
#define MAX(x,y) x > y ? x : y

typedef long long ll;
typedef unsigned long long ull;

const int maxn = 300005;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9+7;
const double eps = 1e-06;
const double PI = acos(-1);

int mod;

namespace Math {
    inline int pw(int base, int p, const int mod) {
        static int res;
        for (res = 1; p; p >>= 1, base = static_cast<long long> (base) * base % mod) if (p & 1) res = static_cast<long long> (res) * base % mod;
        return res;
    }
    inline int inv(int x, const int mod) { return pw(x, mod - 2, mod); }
}
using namespace Math;
const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
const long long mod_1_2 = static_cast<long long> (mod1) * mod2;
const int inv_1 = inv(mod1, mod2), inv_2 = inv(mod_1_2 % mod3, mod3);
struct Int {
    int A, B, C;
    explicit inline Int() { }
    explicit inline Int(int __num) : A(__num), B(__num), C(__num) { }
    explicit inline Int(int __A, int __B, int __C) : A(__A), B(__B), C(__C) { }
    static inline Int reduce(const Int &x) {
        return Int(x.A + (x.A >> 31 & mod1), x.B + (x.B >> 31 & mod2), x.C + (x.C >> 31 & mod3));
    }
    inline friend Int operator + (const Int &lhs, const Int &rhs) {
        return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
    }
    inline friend Int operator - (const Int &lhs, const Int &rhs) {
        return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
    }
    inline friend Int operator * (const Int &lhs, const Int &rhs) {
        return Int(static_cast<long long> (lhs.A) * rhs.A % mod1, static_cast<long long> (lhs.B) * rhs.B % mod2, static_cast<long long> (lhs.C) * rhs.C % mod3);
    }
    inline int get() {
        long long x = static_cast<long long> (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
        return (static_cast<long long> (C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
    }
} ;

#define maxn 131072

namespace Poly {
#define N (maxn << 1)
    int lim, s, rev[N];
    Int Wn[N | 1];
    inline void init(int n) {
        s = -1, lim = 1; while (lim < n) lim <<= 1, ++s;
        for (register int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const Int t(pw(G, (mod1 - 1) / lim, mod1), pw(G, (mod2 - 1) / lim, mod2), pw(G, (mod3 - 1) / lim, mod3));
        *Wn = Int(1); for (register Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
    }
    inline void NTT(Int *A, const int op = 1) {
        for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (register int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;
            for (register int i = 0; i < lim; i += mid << 1) {
                for (register int j = 0; j < mid; ++j) {
                    const Int W = op ? Wn[t * j] : Wn[lim - t * j];
                    const Int X = A[i + j], Y = A[i + j + mid] * W;
                    A[i + j] = X + Y, A[i + j + mid] = X - Y;
                }
            }
        }
        if (!op) {
            const Int ilim(inv(lim, mod1), inv(lim, mod2), inv(lim, mod3));
            for (register Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
        }
    }
#undef N
}
using namespace Poly;

int n, m;
Int A[maxn << 1], B[maxn << 1];
int main() {
    scanf("%d%d%d", &n, &m, &mod); ++n, ++m;
    for (int i = 0, x; i < n; ++i) scanf("%d", &x), A[i] = Int(x % mod);
    for (int i = 0, x; i < m; ++i) scanf("%d", &x), B[i] = Int(x % mod);
    init(n + m);
    NTT(A), NTT(B);
    for (int i = 0; i < lim; ++i) A[i] = A[i] * B[i];
    NTT(A, 0);
    for (int i = 0; i < n + m - 1; ++i) {
        printf("%d%c", A[i].get(),i == n + m - 2 ? '\n' : ' ');
    }
    return 0;
}

/*

*/
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值