G. Lucky Tickets
time limit per test5 seconds
memory limit per test256 megabytes
inputstandard input
outputstandard output
All bus tickets in Berland have their numbers. A number consists of n digits (n is even). Only k decimal digits d1,d2,…,dk can be used to form ticket numbers. If 0 is among these digits, then numbers may have leading zeroes. For example, if n=4 and only digits 0 and 4 can be used, then 0000, 4004, 4440 are valid ticket numbers, and 0002, 00, 44443 are not.
A ticket is lucky if the sum of first n/2 digits is equal to the sum of remaining n/2 digits.
Calculate the number of different lucky tickets in Berland. Since the answer may be big, print it modulo 998244353.
Input
The first line contains two integers n and k (2≤n≤2⋅105,1≤k≤10) — the number of digits in each ticket number, and the number of different decimal digits that may be used. n is even.
The second line contains a sequence of pairwise distinct integers d1,d2,…,dk (0≤di≤9) — the digits that may be used in ticket numbers. The digits are given in arbitrary order.
Output
Print the number of lucky ticket numbers, taken modulo 998244353.
Examples
inputCopy
4 2
1 8
outputCopy
6
inputCopy
20 1
6
outputCopy
1
inputCopy
10 5
6 1 4 0 3
outputCopy
569725
inputCopy
1000 7
5 4 0 1 8 3 2
outputCopy
460571165
Note
In the first example there are 6 lucky ticket numbers: 1111, 1818, 1881, 8118, 8181 and 8888.
There is only one ticket number in the second example, it consists of 20 digits 6. This ticket number is lucky, so the answer is 1.
最近因为在考期,所以写题有些懈怠,考后继续加油。另外,最近计划参加ACM的新人大佬非常多,不得不说有些担心明年的选拔了。
此题出自昨天CF的EDU Round,前面四题都1A,大概还有一个多小时,写完这题就rating就可以上黄色了,然而显然上天不愿意给我这个机会QwQ。
题目很容易理解,也很容易想到多项式乘法。因为问题核心是要求n/2位数来表示某个数字的方法数,dp的方式就是不断乘以一个不超过十位的,由0和1组成的数。那么用快速幂或者分治的形式分解问题。两个大数相乘时采用FFT,那么效率就是nlognlogn,似乎勉强能过。
很开心地写了一发。因为比赛时间问题,我写了更清晰但是低效的递归分治法,并且采用vector传参(反正CF不用担心爆栈,随便乱玩)。
#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
using namespace std;
using db=long double;
using LL=long long;
using cp=complex<db>;
const int maxn=(1<<21)+5;
const db PI=acos(-1.0L);
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);
void get_rev(int bit)
{
for(int i=0;i<(1<<bit);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}
void FFT(vector<cp> &a, int n, int dft)
{
cp x,y;
for(int i=0;i<n;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int stp=1;stp<n;stp<<=1)
{
cp wn=exp(cp(0,dft*PI/stp));
for(int j=0;j<n;j+=stp<<1)
{
cp wnk(1,0);
for(int k=j;k<j+stp;k++)
{
x=a[k];
y=wnk*a[k+stp];
a[k]=x+y;
a[k+stp]=x-y;
wnk*=wn;
}
}
}
if(dft==-1)
for(int i=0;i<n;i++)
a[i]/=n;
}
vector<int> mul(vector<int> x, vector<int> y)
{
int s=2,l1=x.size(),l2=y.size(),bit;
for(bit=1;(1<<bit)<l1+l2-1;bit++)
s<<=1;
get_rev(bit);
vector<int> op(s);
vector<cp> a(s),b(s);
for(int i=0;i<l1;i++)
a[i]=(db)x[i];
for(int i=0;i<l2;i++)
b[i]=(db)y[i];
FFT(a,s,1);
FFT(b,s,1);
for(int i=0;i<s;i++)
a[i]*=b[i];
FFT(a,s,-1);
for(int i=0;i<s;i++)
op[i]=(LL)(a[i].real()+0.5)%mo;
while(!op.empty()&&op.back()==0)
op.pop_back();
return op;
}
vector<int> solve(int x)
{
if(x==0)
return vector<int>(1,1);
if(x==1)
return h;
vector<int> a=solve(x/2);
if(x&1)
return mul(mul(a,a),h);
else
return mul(a,a);
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1,tmp;i<=k;i++)
scanf("%d",&tmp),h[tmp]=1;
vector<int> ans=solve(n/2);
for(int i:ans)
sum=(sum+(LL)i*i%mo)%mo;
printf("%d",sum);
return 0;
}
然而出现了两个问题,第一是最后一个样例过不了,这是因为浮点数精度的问题,因为FFT过程中,答案每一位最大值是(1E91E9位数)级别的,运算过程中a[i]最大值是1E91E9位数*位数级别的,位数最多1E6,那么总共可以达到1E30,即使用long double也力不从心。(妈耶,一开始还以为自己模板写错了,查了半天错)
通过printf("%.20Lf",3.14159265358979323846264338327950288L);输出可知long double大概18-19位精度。
第二个问题就是,FFT效率不高,极限数据会TLE。
第二天看了看大佬们的代码,查了查,才知道有一种NTT的东西(快速数论变换)通过整数代替复数,原根代替FFT中的单位根,逆元代替除法得到答案。这里的原根指的就是数论的原根,和复数单位根有相同的性质。
这样的话只要在变换时不断取模就不会出现精度问题,而且效率快得不是一星半点。
参考文章:
https://blog.youkuaiyun.com/enjoy_pascal/article/details/81771910
基础NTT(求FFT取模,避免了精度问题)
https://www.cnblogs.com/fenghaoran/p/7107608.html
https://blog.youkuaiyun.com/qq_35950004/article/details/79477797 任意模数NTT(利用中国剩余定理,可以计算出真值,效果完全等价于FFT,同时避免了精度问题)
任意模数的NTT还可以用另一种把多项式的每一项系数拆成AM+B的取巧方法,参见
https://www.cnblogs.com/xzyxzy/p/9263480.html
模数998244353是素数,而且是NTT素数(为啥必须是NTT素数,看过FFT的都知道了)即(P-1)有超过序列长度的2的正整数幂因子的质数,其中一个原根是3,如果是其他NTT素数的话,可以暴力法求原根。如果不是NTT素数,但是能分解出NTT素数因子,求每一个素因子的模数的答案,然后综合即可,除此之外只能用任意模数NTT的方法来做了。
#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
#define root 3
using namespace std;
using LL=long long;
const int maxn=(1<<21)+5;
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);
int quick_power(int a, int b)
{
int res=1,base=a;
while(b)
{
if(b&1)
res=(LL)res*base%mo;
base=(LL)base*base%mo;
b>>=1;
}
return res;
}
void get_rev(int bit)
{
for(int i=0;i<(1<<bit);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}
void FFT(vector<int> &a, int n, int dft)
{
int x,y;
for(int i=0;i<n;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int stp=1;stp<n;stp<<=1)
{
int wn=quick_power(root,(mo-1)/(stp*2));
if(dft==-1)
wn=quick_power(wn,mo-2);
for(int j=0;j<n;j+=stp<<1)
{
int wnk=1;
for(int k=j;k<j+stp;k++)
{
x=a[k];
y=(LL)wnk*a[k+stp]%mo;
a[k]=(x+y)%mo;
a[k+stp]=(x-y+mo)%mo;
wnk=(LL)wnk*wn%mo;
}
}
}
if(dft==-1)
{
int t=quick_power(n,mo-2);
for(int i=0;i<n;i++)
a[i]=(LL)a[i]*t%mo;
}
}
vector<int> mul(vector<int> x, vector<int> y)
{
int s=2,l1=x.size(),l2=y.size(),bit;
for(bit=1;(1<<bit)<l1+l2-1;bit++)
s<<=1;
get_rev(bit);
x.resize(s),y.resize(s);
FFT(x,s,1);
FFT(y,s,1);
for(int i=0;i<s;i++)
x[i]=(LL)x[i]*y[i]%mo;
FFT(x,s,-1);
while(!x.empty()&&x.back()==0)
x.pop_back();
return x;
}
vector<int> solve(int x)
{
if(x==0)
return vector<int>(1,1);
if(x==1)
return h;
vector<int> a=solve(x/2);
if(x&1)
return mul(mul(a,a),h);
else
return mul(a,a);
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1,tmp;i<=k;i++)
scanf("%d",&tmp),h[tmp]=1;
vector<int> ans=solve(n/2);
for(int i:ans)
sum=(sum+(LL)i*i%mo)%mo;
printf("%d",sum);
return 0;
}
最后才意识到一件很傻的事情,FFT不只是可以干两个多项式的乘法,多项式快速幂也是OK的,DFT之后每答案取p次方再IDFT就是答案,根本不用分治,自己思维之前有些僵化了。附上最终的代码吧,跑得还挺快的。
#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
#define root 3
using namespace std;
using LL=long long;
const int maxn=(1<<21)+5;
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);
int quick_power(int a, int b)
{
int res=1,base=a;
while(b)
{
if(b&1)
res=(LL)res*base%mo;
base=(LL)base*base%mo;
b>>=1;
}
return res;
}
void get_rev(int bit)
{
for(int i=0;i<(1<<bit);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}
void FFT(vector<int> &a, int n, int dft)
{
int x,y;
for(int i=0;i<n;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int stp=1;stp<n;stp<<=1)
{
int wn=quick_power(root,(mo-1)/(stp*2));
if(dft==-1)
wn=quick_power(wn,mo-2);
for(int j=0;j<n;j+=stp<<1)
{
int wnk=1;
for(int k=j;k<j+stp;k++)
{
x=a[k];
y=(LL)wnk*a[k+stp]%mo;
a[k]=(x+y)%mo;
a[k+stp]=(x-y+mo)%mo;
wnk=(LL)wnk*wn%mo;
}
}
}
if(dft==-1)
{
int t=quick_power(n,mo-2);
for(int i=0;i<n;i++)
a[i]=(LL)a[i]*t%mo;
}
}
void ntt_pow(vector<int> &x, int p)
{
int s=2,l1=x.size(),bit;
for(bit=1;(1<<bit)<p*l1-p+1;bit++)
s<<=1;
get_rev(bit);
x.resize(s);
FFT(x,s,1);
for(int i=0;i<s;i++)
x[i]=quick_power(x[i],p);
FFT(x,s,-1);
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1,tmp;i<=k;i++)
scanf("%d",&tmp),h[tmp]=1;
ntt_pow(h,n/2);
for(int i:h)
sum=(sum+(LL)i*i%mo)%mo;
printf("%d",sum);
return 0;
}