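本题似乎就是 Codeforces 960G 原题。
答案为 s(n-1, A+B-2) · C(A+B-2, A-1) mod 998244353,其中 s 为无符号第一类斯特林数,即 x(x+1)…(x+n-2) 展开后 x^(A+B-2) 的系数。
思路参考 dalao 的博客(O(n log n) 做法),这里主要贴代码。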
一、O(n log n) 写法(倍增求上升幂 x(x+1)…(x+n-2) 的系数):
#include<cstdio>
#include<algorithm>
// Capacity shared by every global array in this program.
#define maxn 400005
using namespace std;
// NTT-friendly prime modulus and its primitive root (used to build the root table).
const int mod = 998244353, G = 3;
// n, A, B: input values read in main().
// w[0..wlen]: powers of the wlen-th root of unity G^((mod-1)/wlen), filled in main().
// wlen: smallest power of two >= n, set in main().
int n,A,B,w[maxn],wlen;
// After FAC_INV(N): fac[i] = i! and inv[i] = (i!)^{-1} modulo mod, for 0 <= i <= N.
long long fac[maxn],inv[maxn];
// Precompute factorials and inverse factorials modulo `mod` for indices 0..N.
// Afterwards fac[k] = k! and inv[k] = (k!)^{-1} (mod `mod`).
// The modular inverses of 1..N come from the linear recurrence
// inv(k) = -(mod / k) * inv(mod % k). A running prefix product then turns
// them into inverse factorials.
void FAC_INV(int N)
{
    fac[0] = 1;
    fac[1] = 1;
    inv[0] = 1;
    inv[1] = 1;
    for (int k = 2; k <= N; ++k) {
        fac[k] = fac[k - 1] * k % mod;
        inv[k] = (mod - mod / k) * inv[mod % k] % mod;  // mod % k < k: already known
    }
    for (int k = 2; k <= N; ++k)
        inv[k] = inv[k] * inv[k - 1] % mod;             // prefix product -> 1/k!
}
// Binomial coefficient C(n, m) modulo `mod`, read from the FAC_INV tables.
// No bounds check: callers must keep 0 <= m <= n within the precomputed range.
inline int C(int n,int m)
{
    const long long top = fac[n];
    return top * inv[m] % mod * inv[n - m] % mod;
}
// Binary exponentiation: returns a^b modulo `mod` (b >= 0).
inline int ksm(int a,int b){
    int acc = 1;
    while (b) {
        if (b & 1) acc = 1ll * acc * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return acc;
}
// Bit-reversal permutation for the current transform length (filled by solve()).
int r[maxn];
// In-place iterative NTT of a[0..len-1] modulo `mod`.
// flg == 1: forward transform. flg == -1: inverse transform, including the 1/len scaling.
// Twiddles are read from w[0..wlen]. len must be a power of two with len <= wlen,
// and r[] must already hold the bit-reversal permutation for len.
void ntt(int *a,int len,int flg)
{
    for (int idx = 0; idx < len; ++idx)
        if (idx < r[idx]) swap(a[idx], a[r[idx]]);
    for (int span = 2; span <= len; span <<= 1) {
        const int half = span / 2;
        const int stride = wlen / span;  // step through w[] for a span-th root of unity
        for (int base = 0; base < len; base += span) {
            int pos = 0;
            for (int k = base; k < base + half; ++k, pos += stride) {
                // inverse transform uses the conjugate root w[wlen - pos]
                const int root = (flg == 1) ? w[pos] : w[wlen - pos];
                const int lo = a[k];
                const int hi = 1ll * root * a[k + half] % mod;
                a[k] = (lo + hi) % mod;
                a[k + half] = (lo - hi + mod) % mod;
            }
        }
    }
    if (flg == -1) {
        const int scale = ksm(len, mod - 2);
        for (int idx = 0; idx < len; ++idx) a[idx] = 1ll * a[idx] * scale % mod;
    }
}
// Cyclic convolution of length len modulo `mod`: a <- a * b.
// Both arrays are transformed in place, so b is left holding its NTT (clobbered).
void multiply(int *a,int *b,int len)
{
    ntt(a, len, 1);
    ntt(b, len, 1);
    for (int *pa = a, *pb = b; pa != a + len; ++pa, ++pb)
        *pa = 1ll * *pa * *pb % mod;
    ntt(a, len, -1);
}
// a: current polynomial (a[i] = coefficient of x^i); b, tmp: scratch buffers.
int a[maxn],b[maxn],tmp[maxn];
// Builds in a[] the coefficients of the rising factorial x(x+1)...(x+n-1),
// i.e. a[k] = s(n, k) (unsigned Stirling numbers of the first kind).
// n odd : solve(n-1), then multiply by (x + n - 1) in place.
// n even: f = solve(n/2); compute f(x + n/2) by a Taylor shift (one
//         convolution), then multiply f(x) * f(x + n/2).
// Relies on a[] being zero above the current degree (the global starts zeroed).
void solve(int n)
{
if(n==1) {a[1]=1;return;}
if(n&1){
solve(n-1);
// multiply by (x + (n-1)): new a[i] = a[i-1] + (n-1) * a[i], highest index first
for(int i=n;i>=1;i--) a[i]=(a[i-1]+1ll*a[i]*(n-1))%mod;
}
else{
solve(n>>1);
// l = degree of f; len = smallest power of two >= 2*(l+1)
int l=n>>1,len=1,pw=1;while(len<(l+1)<<1) len<<=1;
for(int i=0;i<len;i++) r[i]=(r[i>>1]>>1)|(i&1?len>>1:0);
// tmp[i] = i! * a[i], b[i] = l^i / i!
for(int i=1;i<=l;i++) pw=1ll*pw*l%mod,tmp[i]=fac[i]*a[i]%mod,b[i]=pw*inv[i]%mod;
tmp[0]=fac[0]*a[0],b[0]=inv[0];
for(int i=l+1;i<len;i++) tmp[i]=b[i]=0;
// reversing b turns sum_i tmp[i] * l^(i-j)/(i-j)! into an ordinary convolution
reverse(b,b+l+1);
multiply(b,tmp,len);
// x^j coefficient of f(x+l) = conv[j+l] / j!
for(int i=0;i<=l;i++) b[i]=b[i+l]*inv[i]%mod;
for(int i=l+1;i<len;i++) b[i]=0;
multiply(a,b,len);
}
}
// Reads n, A, B and prints s(n-1, A+B-2) * C(A+B-2, A-1) mod 998244353,
// where s is the unsigned Stirling number of the first kind. This appears to
// be Codeforces 960G (permutations by prefix/suffix maxima count).
int main()
{
    freopen("permutation.in","r",stdin);
    freopen("permutation.out","w",stdout);
    scanf("%d%d%d",&n,&A,&B);
    if(n==1) return printf("%d",A==1&&B==1?1:0),0;
    if(!A||!B) return puts("0"),0;
    // s(n-1, k) vanishes for k > n-1. Answering here also keeps the a[] and fac[]
    // indices below inside their computed / allocated ranges.
    if(A+B-2>n-1) return puts("0"),0;
    FAC_INV(n);
    wlen=w[0]=1;while(wlen<n) wlen<<=1;
    for(int i=1,j=ksm(G,(mod-1)/wlen);i<=wlen;i++) w[i]=1ll*w[i-1]*j%mod;
    solve(n-1);
    // The expression is a long long, so it must be printed with %lld (%d is UB).
    printf("%lld",1ll*a[A+B-2]*C(A+B-2,A-1)%mod);
    return 0;
}
二、递归分治 NTT(对 ∏(1 + i·x) 分治相乘,O(n log² n)):
(膜 F6 dalao)
#include<algorithm>
#include<cstring>
#include<cctype>
#include<cstdio>
#include<vector>
// Loop helpers: rep = ascending inclusive range, repd = descending inclusive range.
#define rep(i,x,y) for(int i=x; i<=y; ++i)
#define repd(i,x,y) for(int i=x; i>=y; --i)
// Midpoint of the current [l, r] range inside solve().
#define mid (l+r>>1)
#define pb push_back
using namespace std;
typedef long long LL;
// N: capacity of dat[] (indexed by factor index); M: largest transform length (2^18);
// mod: NTT prime.
const int N=200005,M=262144,mod=998244353;
// n, a, b: input values; len: current transform length; bin: bit-reversal permutation.
int n,a,b,len,bin[M];
// Wn[k][...]: twiddle table for butterfly level k (filled by pre());
// A, B: transform buffers; ans: final answer.
LL Wn[18][M],A[M],B[M],ans;
// dat[l]: coefficients of the product polynomial of the segment starting at l (see solve()).
vector <LL> dat[N];
// Modular exponentiation by repeated squaring: returns a^x mod `mod` for x >= 0.
LL getmi(LL a,LL x)
{
    LL acc = 1;
    for (; x; x >>= 1) {
        if (x & 1) acc = acc * a % mod;
        a = a * a % mod;
    }
    return acc;
}
// In-place iterative NTT modulo 998244353 on a[0..len-1].
// tp == 1: forward. tp == -1: inverse; it walks the negative-offset half of
// Wn[cnt] and scales by 1/len.
// Results are not normalised: (x - y) % mod can leave values in (-mod, 0),
// so callers fix the sign at the end.
// Requires bin[] to hold the bit-reversal permutation for len, and pre() to have run.
void FFT(LL a[],int len,int tp)
{
rep(i,0,len-1) if(i<bin[i]) swap(a[i],a[bin[i]]);
for(int i=1,cnt=0; i<len; ++cnt,i<<=1)
{
for(int j=0; j<len; j+=i<<1)
{
LL w=0,x,y;
rep(k,0,i-1)
{
// Wn[cnt][i + w] = g^w for a root g of order 2i; w steps by +1 (forward) or -1 (inverse).
// y is left unreduced (|y| < mod^2 when |a| < mod), which still fits in a long long.
x=a[j+k],y=a[i+j+k]*Wn[cnt][w+i],w+=tp;
a[j+k]=(x+y)%mod,a[i+j+k]=(x-y)%mod;
}
}
}
if(tp==-1)
{
LL x=getmi(len,mod-2);
rep(i,0,len-1) a[i]=a[i]*x%mod;
}
}
void pre()
{
rep(i,0,17)
{
int x=1<<i;
Wn[i][x]=1;
LL wn=getmi(3,(mod-1)/(x<<1));
rep(j,1,x-1) Wn[i][j+x]=Wn[i][j+x-1]*wn%mod;
wn=getmi(wn,mod-2);
rep(j,1,x-1) Wn[i][-j+x]=Wn[i][-j+x+1]*wn%mod;
}
}
void solve(int l,int r)
{
if(l==r)
{
dat[l].pb(1);
dat[l].pb(l);
return;
}
solve(l,mid);
solve(mid+1,r);
int szl=dat[l].size()-1;
int szr=dat[mid+1].size()-1;
for(len=1; len<=szl+szr; len<<=1);
rep(i,0,len-1) bin[i]=bin[i>>1]>>1|((i&1)*(len>>1)),A[i]=B[i]=0;
rep(i,0,szl) A[i]=dat[l][i];
rep(i,0,szr) B[i]=dat[mid+1][i];
FFT(A,len,1),FFT(B,len,1);
rep(i,0,len-1) A[i]=A[i]*B[i]%mod;
FFT(A,len,-1),dat[l].clear();
rep(i,0,szl+szr) dat[l].pb(A[i]);
}
// Binomial coefficient C(n, m) mod `mod`, computed directly as
// n (n-1) ... (n-m+1) / m! using a single modular inverse.
LL C(int n,int m)
{
    LL denom = 1, numer = 1;
    for (int k = 1; k <= m; ++k) denom = denom * k % mod;
    for (int k = n - m + 1; k <= n; ++k) numer = numer * k % mod;
    return numer * getmi(denom, mod - 2) % mod;
}
// Entry point. Reads n, a, b. The answer is the coefficient of x^p in
// prod_{i=0}^{n-2} (1 + i*x), with p = n-1-(a+b-2), multiplied by C(a+b-2, a-1),
// and printed in [0, mod).
int main()
{
    freopen("permutation.in","r",stdin);
    freopen("permutation.out","w",stdout);
    scanf("%d%d%d",&n,&a,&b);
    if (a == 0 || b == 0) return puts("0"), 0;
    if (n == 1) {
        puts(a == 1 && b == 1 ? "1" : "0");
        return 0;
    }
    pre();
    solve(0, n - 2);
    const int top = static_cast<int>(dat[0].size()) - 1;
    const int p = n - 1 - (a - 1 + b - 1);
    if (p < 0 || p > top) return puts("0"), 0;
    ans = dat[0][p];
    ans = ans * C(a - 1 + b - 1, a - 1) % mod;
    printf("%lld\n", (ans + mod) % mod);  // coefficients may be negative residues
    return 0;
}
三、优先队列 + NTT(每次取次数最小的两个多项式相乘,类似哈夫曼合并):
(膜 B1 dalao)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1<<18;    // largest transform length
const int maxn2 = 1<<17;   // half of maxn; wn[] is used for indices in (-maxn2, maxn2)
const int mod = 998244353;
int n,a,b;
// fac/ifac: factorial tables. wn points into the middle of wnori so that
// negative indices (inverse roots) are valid.
ll fac[maxn],ifac[maxn],wnori[maxn],*wn=wnori+maxn2;
// Bring x from [0, 2*mod) into [0, mod).
inline int MOD(int x)
{
    if (x >= mod) x -= mod;
    return x;
}
// Bring x from (-mod, mod) into [0, mod).
inline int Mod(int x)
{
    if (x < 0) x += mod;
    return x;
}
// a^x mod `mod` by binary exponentiation.
inline ll qpow(ll a,ll x)
{
    ll acc = 1;
    while (x) {
        if (x & 1) acc = acc * a % mod;
        a = a * a % mod;
        x >>= 1;
    }
    return acc;
}
// C(n, r) mod `mod` from the fac/ifac tables; 0 when r lies outside [0, n].
inline ll C(int n,int r)
{
    const bool outside = r < 0 || r > n;
    return outside ? 0 : fac[n] * ifac[r] % mod * ifac[n - r] % mod;
}
// Dense polynomial modulo 998244353: a[i] is the coefficient of x^i.
// Supports in-place NTT multiplication.
struct Poly{
vector<int> a;
// Iterative NTT of A[0..N-1]; N must be a power of two with N <= maxn.
// dft == 1: forward. dft == -1: inverse, including the 1/N scaling.
// Works in static scratch arrays, so it is not reentrant. The bit-reversal index
// rev[i] is computed on the fly and A[i] is scattered straight to a[rev[i]].
// At each level stc = dft * maxn / (2*step), so wn[cur] is a power of a root of
// order 2*step. dft == -1 makes cur negative, selecting the inverse roots.
void fft(vector<int>&A,int N,int dft)
{
static int a[maxn],rev[maxn];
int step,i,j,cur,stc=maxn*dft;
for(i=0;i<N;i++) a[rev[i]=rev[i>>1]>>1|(i&1)*(N>>1)]=A[i];
for(step=1;step<N;step<<=1)
{
stc/=2;
for(i=0;i<N;i+=step<<1) for(j=i,cur=0;j<i+step;j++,cur+=stc)
{
int x=a[j],y=a[j+step]*wn[cur]%mod;
a[j]=MOD(x+y);a[j+step]=Mod(x-y);
}
}
if(!~dft){ll invn=qpow(N,mod-2);for(i=0;i<N;i++) A[i]=a[i]*invn%mod;}
else for(i=0;i<N;i++) A[i]=a[i];
}
// Number of stored coefficients (degree + 1).
inline int size()const{return a.size();}
int &operator [](int x){return a[x];}
// this <- this * b.
// Side effect: b.a is resized to the transform length and left holding b's NTT,
// so b must not be used as a polynomial afterwards.
void operator *=(Poly&b)
{
int n=a.size()+b.size()-1,N,i;
for(N=1;N<n;N<<=1);
a.resize(N);b.a.resize(N);
fft(a,N,1);fft(b.a,N,1);
for(i=0;i<N;i++) a[i]=(ll)a[i]*b[i]%mod;
fft(a,N,-1);
a.resize(n);
}
// Deliberately inverted: a longer polynomial compares as "smaller", so a
// std::priority_queue (max-heap) yields the lowest-degree polynomial first.
bool operator < (const Poly&b)const{return size()>b.size();}
}P[maxn];
// Precompute fac/ifac for 0..maxn-1 and the twiddle table wn[k] = g^k,
// wn[-k] = g^{-k} for 0 <= k < maxn2, where g = 3^((mod-1)/maxn).
void pre()
{
    fac[0] = 1;
    for (int k = 1; k < maxn; ++k) fac[k] = fac[k - 1] * k % mod;
    ifac[maxn - 1] = qpow(fac[maxn - 1], mod - 2);
    for (int k = maxn - 1; k >= 1; --k) ifac[k - 1] = ifac[k] * k % mod;
    wn[0] = 1;
    wn[1] = qpow(3, (mod - 1) / maxn);
    wn[-1] = qpow(wn[1], mod - 2);
    for (int k = 2; k < maxn2; ++k) {
        wn[k] = wn[k - 1] * wn[1] % mod;
        wn[-k] = wn[-k + 1] * wn[-1] % mod;
    }
}
// Heap order over indices into P[]. With Poly::operator< inverted, the
// priority_queue top is the index of the polynomial with the fewest coefficients.
// operator() is const because a comparison object must be callable without
// mutating itself (it has no state anyway).
struct Poly_Cmp{bool operator()(int x,int y)const{return P[x]<P[y];}};
priority_queue<int,vector<int>,Poly_Cmp> Q;
void solve()
{
int i;
//calc s(n-1,a+b-2)
if(a+b-2>n-1 || a+b-2<0) puts("0"),exit(0);
for(i=0;i<n-1;i++) P[i].a.resize(2),P[i][0]=i,P[i][1]=1,Q.push(i);
for(i=0;i<n-2;i++)
{
int x=Q.top();Q.pop();
int y=Q.top();Q.pop();
P[x]*=P[y];
Q.push(x);
}
int x=Q.top();
printf("%lld\n",P[x][a+b-2]*C(a+b-2,a-1)%mod);
}
// Entry point: redirect I/O to permutation.in/.out, precompute tables, then read n, a, b.
// For n == 1 it prints 1 iff a == b == 1; otherwise solve() prints the answer.
int main()
{
    freopen("permutation.in","r",stdin);
    freopen("permutation.out","w",stdout);
    pre();
    scanf("%d%d%d",&n,&a,&b);
    if (n != 1) {
        solve();
        return 0;
    }
    puts(a == 1 && b == 1 ? "1" : "0");
    return 0;
}