Description
有n个格子,现在用m种颜色按顺序染m次,每次可以染一段区间(如果区间内有别的颜色将会被这种颜色覆盖),问最终所有格子都有颜色的情况下,不同的颜色序列有多少种。
Solution
最终序列肯定是一段一段的颜色,其实每次染色相当于从原有的颜色段中插入一段颜色。
设
f
i
,
j
f_{i,j}
fi,j表示前
i
i
i次染色,颜色段长度为
j
j
j的方案数,容易得到转移就是:
f
i
,
j
=
f
i
−
1
,
j
+
∑
k
=
0
j
−
1
(
k
+
1
)
f
i
−
1
,
k
f_{i,j}=f_{i-1,j}+\sum_{k=0}^{j-1}(k+1)f_{i-1,k}
fi,j=fi−1,j+k=0∑j−1(k+1)fi−1,k
把
f
i
,
j
f_{i,j}
fi,j看成
f
i
(
x
)
[
x
j
]
f_i(x)[x^j]
fi(x)[xj],转移可以看成
f
i
(
x
)
=
f
i
−
1
(
x
)
(
i
x
+
1
)
f_i(x)=f_{i-1}(x)(ix+1)
fi(x)=fi−1(x)(ix+1),最后再乘上个组合数,答案可以写成(这里m颜色必须染):
∑
i
=
1
m
C
m
−
1
i
−
1
(
[
x
i
]
x
∏
j
=
2
n
(
j
x
+
1
)
)
\sum_{i=1}^mC_{m-1}^{i-1}([x^i]x\prod_{j=2}^n(jx+1))
i=1∑mCm−1i−1([xi]xj=2∏n(jx+1))
分治NTT会T,考虑倍增:
设
f
m
(
x
)
=
∏
i
=
1
m
(
(
i
+
1
)
x
+
1
)
f_m(x)=\prod\limits_{i=1}^m((i+1)x+1)
fm(x)=i=1∏m((i+1)x+1),
f
m
′
(
x
)
=
∏
i
=
m
+
1
2
m
(
(
i
+
1
)
x
+
1
)
f_m'(x)=\prod\limits_{i=m+1}^{2m}((i+1)x+1)
fm′(x)=i=m+1∏2m((i+1)x+1)
那么
f
2
m
(
x
)
=
f
m
(
x
)
f
m
′
(
x
)
f_{2m}(x)=f_m(x)f_m'(x)
f2m(x)=fm(x)fm′(x)
考虑求
f
m
′
(
x
)
f_m'(x)
fm′(x),改写一下变成
f
m
′
(
x
)
=
∏
i
=
1
m
(
(
i
+
1
)
x
+
m
x
+
1
)
f_m'(x)=\prod\limits_{i=1}^m((i+1)x+mx+1)
fm′(x)=i=1∏m((i+1)x+mx+1)
设
a
i
=
f
m
(
x
)
[
x
i
]
a_i=f_m(x)[x^i]
ai=fm(x)[xi],
b
i
=
f
m
′
(
x
)
[
x
i
]
b_i=f_m'(x)[x^i]
bi=fm′(x)[xi],枚举有
f
m
(
x
)
f_m(x)
fm(x)有多少个
1
1
1变成了
m
m
m,可以得到:
b
i
+
j
=
C
m
−
i
j
m
j
a
i
b_{i+j}=C_{m-i}^jm^ja_i
bi+j=Cm−ijmjai
把组合数拆开,多项式乘法即可。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
using namespace std;
typedef long long ll;
const int N=1e6+10,M=3e6+10,mo=998244353;
int qpow(int x,int y){
int s=1;
for(;y;y>>=1,x=(ll)x*x%mo) if(y&1) s=(ll)s*x%mo;
return s;
}
int fn;
int jc[N],ny[N];
int rev[M];
int pl(int x,int y){
return x+y>=mo?x+y-mo:x+y;
}
void NTT(int *a,int sig){
fo(i,1,fn-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int m=2;m<=fn;m<<=1){
int half=m>>1,w0=qpow(3,(mo-1)/m);
if(sig<0) w0=qpow(w0,mo-2);
for(int i=0;i<fn;i+=m)
for(int j=i,w=1;j<i+half;++j,w=(ll)w*w0%mo){
int u=a[j],v=(ll)a[j+half]*w%mo;
a[j]=pl(u,v),a[j+half]=pl(u,mo-v);
}
}
if(sig<0){
int nf=qpow(fn,mo-2);
fo(i,0,fn-1) a[i]=(ll)a[i]*nf%mo;
}
}
void mul(int *a,int *b,int ln,int nd){
int cnt=0;
for(fn=1;fn<=(ln<<1);fn<<=1) ++cnt;
fo(i,1,fn-1) rev[i]=rev[i>>1]>>1|(i&1)<<(cnt-1);
fo(i,ln+1,fn-1) a[i]=b[i]=0;
NTT(a,1),NTT(b,1);
fo(i,0,fn-1) a[i]=(ll)a[i]*b[i]%mo;
NTT(a,-1);
fo(i,nd+1,fn-1) a[i]=0;
}
int b[M],c[M],d[M];
void solve(int *a,int n){
if(n==1){
a[0]=1,a[1]=2;
return;
}
int m=n>>1;
solve(a,m);
fo(i,0,m) b[i]=(ll)a[i]*jc[m-i]%mo,c[i]=(ll)qpow(m,i)*ny[i]%mo;
mul(b,c,m,m);
fo(i,0,m) b[i]=(ll)b[i]*ny[m-i]%mo;
mul(a,b,m,m<<1);
if(n&1){
c[0]=0;
fo(i,0,m<<1) c[i+1]=(ll)a[i]*(n+1)%mo;
fo(i,0,n) a[i]=pl(a[i],c[i]);
}
}
int C(int m,int n){
return (ll)jc[m]*ny[n]%mo*ny[m-n]%mo;
}
int a[M];
int main()
{
freopen("color.in","r",stdin);
freopen("color.out","w",stdout);
int n,m,mx;
scanf("%d %d",&n,&m),mx=max(n,m);
jc[0]=1;
fo(i,1,mx) jc[i]=(ll)jc[i-1]*i%mo;
ny[mx]=qpow(jc[mx],mo-2);
fd(i,mx,1) ny[i-1]=(ll)ny[i]*i%mo;
solve(a,n-1);
int ans=0;
fo(i,0,m-1) ans=pl(ans,(ll)a[i]*C(m-1,i)%mo);
printf("%d",ans);
}