MTT
MTT即为任意模数的NTT。
一般有两种情况需要用到任意模数的NTT:
一,模数是NTT模数,但是多项式长度超出了限制(比如模数是 998244353 998244353 998244353,而多项式长度和超过了 2 23 2^{23} 223);
二,模数不是上面提到的NTT模数,比如模数是 1000000007 1000000007 1000000007。
这里我们推荐换一个思路,考虑能不能把这些模数全都用上(模数乘积大于多项式相乘后的最大值即可),求出在这些模意义下的值分别是多少,最后通过中国剩余定理(CRT)来算出在给定模数的模意义下的值(选的质数为: 998244353 998244353 998244353 , 2281701377 2281701377 2281701377 , 1004535809 1004535809 1004535809)。
但是这些模数相乘后会爆long long(多项式相乘后的最大值一般不爆long long,侧面证明了这个方法普适性广,真爆long long了就换下一题吧 ),所以要想一些别的办法来CRT。
我们设最后的多项式某一个位置上实际的答案为 a n s ans ans,选取的三个质数分别为 p 1 p_1 p1, p 2 p_2 p2, p 3 p_3 p3。
我们先通过6次DFT,3次IDFT算出在模意义下的值:
a n s ≡ a 1 ( m o d p 1 ) ans \equiv a_1(\mod p_1) ans≡a1(modp1) a n s ≡ a 2 ( m o d p 2 ) ans \equiv a_2(\mod p_2) ans≡a2(modp2) a n s ≡ a 3 ( m o d p 3 ) ans \equiv a_3(\mod p_3) ans≡a3(modp3)。
根据CRT我们不难算出 a n s ≡ a 4 ( m o d p 1 p 2 ) ans \equiv a_4(\mod p_1p_2) ans≡a4(modp1p2)。
设 a n s = a 5 p 1 p 2 + a 4 ans = a_5p_1p_2 + a_4 ans=a5p1p2+a4,我们已知 a 4 a_4 a4,如果能求出 a 5 a_5 a5就能求出 a n s ans ans的值。
由于 a n s ≡ a 3 ( m o d p 3 ) ans \equiv a_3(\mod p_3) ans≡a3(modp3),
所以 a 5 p 1 p 2 ≡ a 3 − a 4 ( m o d p 3 ) a_5p_1p_2 \equiv a_3 - a_4(\mod p_3) a5p1p2≡a3−a4(modp3),
最后可得 a 5 ≡ ( a 3 − a 4 ) p 1 − 1 p 2 − 2 ( m o d p 3 ) a_5 \equiv (a_3 - a_4)p_1^{-1}p_2^{-2}(\mod p_3) a5≡(a3−a4)p1−1p2−2(modp3)。
现在 a 5 a_5 a5, p 1 p_1 p1, p 2 p_2 p2, a 4 a_4 a4都已经知道了,我们直接用 a n s = a 5 p 1 p 2 + a 4 ans = a_5p_1p_2 + a_4 ans=a5p1p2+a4算出答案即可(这里模的是题目输入的那个模数)。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
using namespace std;
const int N = 262157;
int n,m;
struct ntt {
ll mod , a[N] , b[N];
int len , rev[N];
ll g , gi;
inline ll qw(ll x,ll y) {
ll res = 1;
while ( y ) {
if( y & 1 ) {
res = ( res * x ) % mod;
}
x = ( x * x ) % mod;
y >>= 1;
}
return res;
}
inline void init(const int x) {
int bit = 0;
len = 1;
gi = qw( g , mod - 2 );
while ( len <= x ) {
len <<= 1;
++bit;
}
for ( int i = 0 ; i < len ; ++i ) {
rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << ( bit - 1 ) );
}
}
inline void NTT( ll *F , int on ) {
for ( int i = 0 ; i < len ; ++i ) {
if( i < rev[i] ) {
swap( F[i] , F[rev[i]] );
}
}
for ( int i = 2 ; i <= len ; i <<= 1 ) {
//枚举步长,从递归的下面往上走
ll gn = qw( on ? g : gi , ( mod - 1 ) / ( i ) );
for ( int j = 0 ; j <= len - 1 ; j += i ) {
//走一遍步长
ll gg = 1;
for ( int k = j ; k < j + i / 2 ; ++k ) {
//枚举每块区间内的每一个元素
ll u = F[k];
ll v = ( ( gg * F[k + i / 2] ) % mod + mod ) % mod;
F[k] = ( (u + v) % mod + mod ) % mod;
F[k + i / 2] = ( ( u - v ) % mod + mod ) % mod;
gg = ( ( gg * gn ) % mod + mod ) % mod;
}
}
}
if( on == 0 ) {
const ll inv = qw( (ll)len , mod - 2 );
for ( int i = 0 ; i < len ; ++i ) {
F[i] = ( ( F[i] * inv ) % mod + mod ) % mod;
}
}
return;
}
}num[3];
ll exgcd( ll a, ll b, ll &x , ll &y ) {
if( !b ) {
x = 1;
y = 0;
return a;
}
const ll d = exgcd( b , a % b , x , y );
ll t = x;
x = y;
y = t - ( a / b ) * y;
return d;
}
inline ll mul( ll a , ll b , ll mod ) {
//玄学快乘
ll ans = ( a * b - (ll)( (long double)a / mod * b + 1e-8 ) * mod );
return ans < 0 ? ans + mod : ans;
}
ll a[3],p[3];
ll crt() {
ll pp = 1 , sum = 0;
for ( int i = 1 ; i <= 2 ; ++i ) {
pp *= p[i];
}
for ( int i = 1 ; i <= 2 ; ++i ) {
const ll mm = pp / p[i];
ll x , y;
exgcd( p[i] , mm , x , y );
sum = ( sum + mul( mul( y , mm , pp ) , a[i] , pp ) )