G. Lucky Tickets
time limit per test
5 seconds
memory limit per test
256 megabytes
input
standard input
output
standard output
All bus tickets in Berland have their numbers. A number consists of $n$ digits ($n$ is even). Only $k$ decimal digits $d_1, d_2, \dots, d_k$ can be used to form ticket numbers. If $0$ is among these digits, then numbers may have leading zeroes. For example, if $n = 4$ and only digits $0$ and $4$ can be used, then $0000$, $4004$, $4440$ are valid ticket numbers, and $0002$, $00$, $44443$ are not.
A ticket is lucky if the sum of the first $n/2$ digits is equal to the sum of the remaining $n/2$ digits.
Calculate the number of different lucky tickets in Berland. Since the answer may be big, print it modulo $998244353$.
Input
The first line contains two integers $n$ and $k$ ($2 \le n \le 2 \cdot 10^5$, $1 \le k \le 10$) — the number of digits in each ticket number, and the number of different decimal digits that may be used. $n$ is even.
The second line contains a sequence of pairwise distinct integers $d_1, d_2, \dots, d_k$ ($0 \le d_i \le 9$) — the digits that may be used in ticket numbers. The digits are given in arbitrary order.
Output
Print the number of lucky ticket numbers, taken modulo $998244353$.
Examples
input
Copy
4 2
1 8
output
Copy
6
input
Copy
20 1
6
output
Copy
1
input
Copy
10 5
6 1 4 0 3
output
Copy
569725
input
Copy
1000 7
5 4 0 1 8 3 2
output
Copy
460571165
Note
In the first example there are $6$ lucky ticket numbers: $1111$, $1818$, $1881$, $8118$, $8181$ and $8888$.
There is only one ticket number in the second example; it consists of $20$ digits $6$. This ticket number is lucky, so the answer is $1$.
分析:容易想到用 $dp[i][j]$ 表示前 $i$ 个数位凑成数字和 $j$ 的方案数,最后答案就是 $\sum_{j} dp[n/2][j]^2$。
转移方程 $dp[i][j] = \sum_{k=0}^{9} dp[i-1][j-k] \cdot d[k]$,其中 $d[k]$ 表示数字 $k$ 是否可用(可用为 1,否则为 0)。
容易得到 $dp[i] = dp[i-1] * d$(多项式卷积,$dp[0] = x^0$),即 $dp[n/2] = d^{n/2}$,那么结果就是 $d$ 与自身卷 $n/2$ 次,直接用 NTT + 快速幂优化即可。
#include "bits/stdc++.h"
using namespace std;
// NTT-friendly prime modulus (998244353 = 119 * 2^23 + 1, primitive root 3).
const int mod = 998244353;

// Modular exponentiation by repeated squaring.
//
// a : base; any long long is accepted (negative or >= mod) and is reduced
//     into [0, mod) before squaring.
// n : exponent; must be non-negative.
// Returns a^n mod `mod`, always in [0, mod).
long long qk(long long a, long long n) {
    // Reduce first: `a * a` below only fits in 64 bits while a < ~3.04e9,
    // and a negative base would otherwise produce a negative residue.
    a %= mod;
    if (a < 0) a += mod;

    long long ans = 1;
    while (n) {
        if (n & 1) ans = ans * a % mod;
        n >>= 1;
        a = a * a % mod;
    }
    return ans;
}
// In-place iterative number-theoretic transform modulo `mod`, using 3 as the
// generator for the roots of unity.
//
// A    : array of n coefficients; entries may be any long long and are first
//        reduced into [0, mod).
// n    : transform length; must be a power of two dividing mod - 1.
// type : 1 for the forward transform, -1 for the inverse transform (the
//        inverse additionally multiplies every entry by n^{-1}).
// On return every A[i] lies in [0, mod).
void ntt(long long A[], int n, int type) {
    // Normalise the input so every product below stays within 64 bits and no
    // coefficient is narrowed or left negative.
    for (int i = 0; i < n; ++i) {
        A[i] %= mod;
        if (A[i] < 0) A[i] += mod;
    }

    // Bit-reversal permutation: j steps through the indices in bit-reversed
    // order, so each pair (i, rev(i)) is swapped exactly once.
    for (int i = 0, j = 0; i < n; ++i) {
        if (i < j) swap(A[i], A[j]);
        for (int l = (n >> 1); (j ^= l) < l; l >>= 1);
    }

    // Butterflies over blocks of size s = 2, 4, ..., n.
    for (int s = 2; s <= n; s <<= 1) {
        const int t = (s >> 1);
        // Principal s-th root of unity, or its inverse for the inverse pass.
        const long long u = (type == 1) ? qk(3, (mod - 1) / s)
                                        : qk(3, (mod - 1) - (mod - 1) / s);
        for (int i = 0; i < n; i += s) {
            long long p = 1;
            for (int j = 0; j < t; ++j, p = p * u % mod) {
                const long long x = A[i + j];
                const long long y = A[i + j + t] * p % mod;
                A[i + j] = (x + y) % mod;
                A[i + j + t] = (x + mod - y) % mod;
            }
        }
    }

    if (type == -1) {
        const long long invn = qk(n, mod - 2);  // n^{-1} via Fermat
        for (int i = 0; i < n; ++i) {
            A[i] = A[i] * invn % mod;
        }
    }
}
long long f[2000004];
int main() {
int n, k, d, maxi = 0;
memset(f, 0, sizeof(f));
cin >> n >> k;
for (int i = 0; i < k; ++i) {
cin >> d;
maxi = max(maxi, d);
f[d] = 1;
}
int len = 1;
while (len <= ((n / 2) * maxi))
{
len <<= 1;
}
ntt(f, len, 1);
for (int i = 0; i < len; ++i) {
f[i] = qk(f[i], n / 2);
}
ntt(f, len, -1);
long long ans = 0;
for (int i = 0; i < len; ++i) {
ans = (ans + 1LL * f[i] * f[i] % mod) % mod;
}
cout << ans;
}