题目描述
求长度为 $n$ 的序列,每个数都是 $|S|$ 中的某一个,所有数的乘积模 $m$ 等于 $x$ 的序列数目模1004535809的值。
输入
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。
第二行,|S|个整数,表示集合S中的所有元素。
1<=N<=10^9,3<=M<=8000,M为质数
1<=x<=M-1,输入数据保证集合S中元素不重复
输出
一行,一个整数,表示你求出的种类数mod 1004535809的值。
样例输入
4 3 1 2
1 2
样例输出
8
题解
原根+NTT
如果条件是和模 $m$ 等于 $x$ ,那么很明显就是一道NTT裸题。维护S集合的生成函数在模 $x^m$ 意义下的 $n$ 次幂即可。
然而本题的条件是乘积。可以求出 $m$ 的原根,对每个数取指标,那么原数相乘就变为指标相加,使用NTT快速幂即可。
求原根的过程可以直接暴力。
注意 $|S|$ 集合中的数可能有0,0是没有指标的。由于 $x\neq 0$ ,因此出现0时无意义,直接忽略这个数即可。
时间复杂度 $O(m\log^2n)$
#include <cstdio>
#include <algorithm>
#define N 16410
#define mod 1004535809
using namespace std;
typedef long long ll;
int m , s[N >> 1] , v[15] , tot , ind[N >> 1];
ll a[N] , ans[N];
inline ll pow(ll x , int y , ll m)
{
ll ans = 1;
while(y)
{
if(y & 1) ans = ans * x % m;
x = x * x % m , y >>= 1;
}
return ans;
}
int getroot()
{
int i , j , t = m - 1;
for(i = 2 ; i * i <= t ; i ++ )
{
if(t % i == 0)
{
v[++tot] = i;
while(t % i == 0) t /= i;
}
}
if(t != 1) v[++tot] = t;
for(i = 2 ; i < m ; i ++ )
{
for(j = 1 ; j <= tot ; j ++ )
if(pow(i , (m - 1) / v[j] , m) == 1)
break;
if(j > tot) return i;
}
return 0;
}
void ntt(ll *a , int n , int flag)
{
int i , j , k;
for(k = i = 0 ; i < n ; i ++ )
{
if(i > k) swap(a[i] , a[k]);
for(j = (n >> 1) ; (k ^= j) < j ; j >>= 1);
}
for(k = 2 ; k <= n ; k <<= 1)
{
ll wn = pow(3 , (mod - 1) / k , mod);
if(flag == -1) wn = pow(wn , mod - 2 , mod);
for(i = 0 ; i < n ; i += k)
{
ll w = 1 , t;
for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn % mod)
t = w * a[j + (k >> 1)] % mod , a[j + (k >> 1)] = (a[j] - t + mod) % mod , a[j] = (a[j] + t) % mod;
}
}
if(flag == -1)
{
k = pow(n , mod - 2 , mod);
for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * k % mod;
for(i = m - 1 ; i < n ; i ++ ) a[i % (m - 1)] = (a[i % (m - 1)] + a[i]) % mod , a[i] = 0;
}
}
void Pow(int y , int n)
{
int i;
ans[0] = 1;
while(y)
{
ntt(a , n , 1);
if(y & 1)
{
ntt(ans , n , 1);
for(i = 0 ; i < n ; i ++ ) ans[i] = ans[i] * a[i] % mod;
ntt(ans , n , -1);
}
for(i = 0 ; i < n ; i ++ ) a[i] = a[i] * a[i] % mod;
ntt(a , n , -1);
y >>= 1;
}
}
int main()
{
int n , x , k , i , r , t , len = 1;
scanf("%d%d%d%d" , &n , &m , &x , &k);
for(i = 1 ; i <= k ; i ++ ) scanf("%d" , &s[i]);
r = getroot();
for(t = 1 , i = 0 ; i < m - 1 ; i ++ , t = t * r % m) ind[t] = i;
for(i = 1 ; i <= k ; i ++ )
if(s[i])
a[ind[s[i]]] ++ ;
while(len <= 2 * (m - 2)) len <<= 1;
Pow(n , len);
printf("%lld\n" , ans[ind[x]]);
return 0;
}