题目大意
给定一个n×3的矩形,你要在一些格子上放东西,一个格子最多只能放一个。而且一个格子上放了东西会对四周有影响。
输入会给定一个3×3的01矩阵,表示当一个3×3的子矩阵中心放了东西时,哪些地方不能放东西。
譬如矩阵
表示一个东西上下左右都不能放东西。
请求出恰好放了m个东西的方案,答案对
1≤n≤2500
题目分析
首先我们可以写出一个很简单的状压dp:fi,s,j表示做到第i行,上一行的状态为
可以发现,如果我们把
考虑使用插值的方法求出这个多项式,这里采用傅里叶变换。我们做次数界次dp,每次把一个主次数界次单位根的次幂代入多项式然后计算。最后我们会得到次数界个值,其实就是答案多项式做了DFT的结果,最后做一次IDFT就好了。
可是中间的dp依然很慢,因为要做次数界(最多有3n)次。考虑将dp的转移用矩阵乘法快速幂来优化就好了。
时间复杂度O((23)33nlogn+3nlog3n)。
代码实现
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
const int P=998244353;
const int N=2505;
const int S=1<<3;
const int L=8192;
const int G=3;
int bitcnt[8]={0,1,1,2,1,2,2,3};
int omega[L+5],t[L+5],trs[L+5];
int mat[3][3],tmp[3][3];
bool legal[1<<6];
int f[N][S];
int a[L+5];
int n,m,len,wn;
struct matrix
{
int num[S][S];
int r,c;
matrix operator*(matrix const mat)const
{
matrix ret;ret.r=r,ret.c=mat.c,memset(ret.num,0,sizeof ret.num);
for (int i=0;i<ret.r;++i)
for (int j=0;j<ret.c;++j)
for (int k=0;k<c;++k)
(ret.num[i][j]+=1ll*num[i][k]*mat.num[k][j]%P)%=P;
return ret;
}
}zero,one,F;
matrix operator^(matrix x,int y)
{
matrix ret=zero;
for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;
return ret;
}
int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
return ret;
}
void DFT(int *a,int sig)
{
for (int i=0;i<len;++i) t[trs[i]]=a[i];
for (int l=2;l<=len;l<<=1)
for (int h=l>>1,p=len/l,i=0;i<h;++i)
for (int w=omega[sig>0?i*p:len-i*p],j=i;j<len;j+=l)
{
int u=t[j],v=1ll*t[j+h]*w%P;
t[j]=(u+v)%P,t[j+h]=(u-v+P)%P;
}
for (int i=0;i<len;++i) a[i]=t[i];
}
void NTT_pre()
{
for (len=1;len<=n*3;len<<=1);
wn=quick_power(G,(P-1)/len),omega[0]=1;
for (int i=1;i<=len;++i) omega[i]=1ll*omega[i-1]*wn%P;
for (int i=0;i<len;++i)
{
int ret=0;
for (int x=i,j=1;j<len;x>>=1,j<<=1) ret=(ret<<1)|(x&1);
trs[i]=ret;
}
}
void pre()
{
for (int sta=0;sta<1<<6;++sta)
{
for (int s=sta,i=0;i<2;++i)
for (int j=0;j<3;++j,s>>=1)
tmp[i][j]=s&1;
bool flag=1;
for (int i=0;flag&&i<3;++i)
for (int j=0;flag&&j<3;++j)
if (tmp[i][j])
for (int x=0;flag&&x<3;++x)
for (int y=0;flag&&y<3;++y)
if (!(x==1&&y==1)&&mat[x][y])
{
int u=x-1+i,v=y-1+j;
if (u>=0&&u<3&&v>=0&&v<3&&tmp[u][v]) flag=0;
}
legal[sta]=flag;
}
zero.r=zero.c=S;
for (int i=0;i<S;++i) zero.num[i][i]=1;
}
int dp(int w)
{
memset(f,0,sizeof f);
int pw[4];pw[0]=1;
for (int i=1;i<=3;++i) pw[i]=1ll*pw[i-1]*w%P;
F.r=1,F.c=S,memset(F.num,0,sizeof F.num);
for (int s=0;s<S;++s) (F.num[0][s]+=quick_power(w,bitcnt[s]))%=P;
one.r=one.c=S,memset(one.num,0,sizeof one.num);
for (int s=0;s<S;++s)
for (int s_=0;s_<S;++s_)
if (legal[s|(s_<<3)])
(one.num[s][s_]+=pw[bitcnt[s_]]%P)%=P;
F=F*(one^(n-1));
int ret=0;
for (int s=0;s<S;++s) (ret+=F.num[0][s])%=P;
return ret;
}
int main()
{
freopen("battle.in","r",stdin),freopen("battle.out","w",stdout);
scanf("%d%d",&n,&m);
if (m>3*n) printf("0\n");
{
for (int i=0;i<3;++i)
for (int j=0;j<3;++j)
scanf("%d",&mat[i][j]);
pre(),NTT_pre();
for (int i=0;i<len;++i) a[i]=dp(omega[i]);
DFT(a,-1);
for (int inv=quick_power(len,P-2),i=0;i<len;++i) a[i]=1ll*a[i]*inv%P;
printf("%d\n",a[m]);
}
fclose(stdin),fclose(stdout);
return 0;
}