题目
设a,ba,ba,b分别为1∼n1\sim n1∼n的排列。
求有多少个排列对(a,b)(a,b)(a,b)满足∑i=1nmax{ai,bi}≥m\sum_{i=1}^n\max\{a_i,b_i\}\ge m∑i=1nmax{ai,bi}≥m。
两个排列对(a,b)(a,b)(a,b)和(c,d)(c,d)(c,d)不同当且仅当存在一个iii,使得ai≠cia_i\not=c_iai=ci或者bi≠dib_i\not=d_ibi=di。
数据范围为n≤50,0≤m≤109n\le 50, 0\le m\le 10^9n≤50,0≤m≤109。
分析
首先发现∑i=1nmax{ai,bi}≥m\sum_{i=1}^n \max\{a_i,b_i\}\ge m∑i=1nmax{ai,bi}≥m有一个比较宽松的上界n2n^2n2,因此我们只需要考虑m≤n2m\le n^2m≤n2的情况。
动用一个套路——我们假设ai=ia_i=iai=i,那么我们就可以单纯枚举bbb,找出这种情况下的方案数,再乘上n!n!n!即可。
事实上我们需要做的就是计算有多少个排列(或者可以称为 " 置换 " )ppp满足∑i=1nmax{pi,i}≥m\sum_{i=1}^n\max\{p_i,i\}\ge m∑i=1nmax{pi,i}≥m。而计算这样的排列可以理解为下标和值的对应。
因此可以考虑如下的 DP :
f(i,j,k)f(i,j,k)f(i,j,k):前iii个数中,分别有jjj个下标和值还没有对应上,已经对应的和为kkk的方案数。
考虑转移,分 3 种情况:
1.什么也不干,方案数为f(i−1,j−1,k)f(i-1,j-1,k)f(i−1,j−1,k)。
2.下标iii与一个值配对,或者值iii与一个下标配对。这样会有2j+12j+12j+1种情况(下标iii与值iii配对当然只算一次),因此方案数为(2j+1)f(i−1,j,k−i)(2j+1)f(i-1,j,k-i)(2j+1)f(i−1,j,k−i)。
3.下标iii与一个值配对,且值iii与一个下标配对。注意到这样的话会一次减少一个未配对的下标和未配对的值,所以在进行配对前分别有j+1j+1j+1个下标和值未配对,因此情况为(j+1)2(j+1)^2(j+1)2,方案数为(j+1)2f(i−1,j+1,k−2i)(j+1)^2f(i-1,j+1,k-2i)(j+1)2f(i−1,j+1,k−2i)。
最后统计kkk在[m,n2][m,n^2][m,n2]中的方案总数,并且不要忘了乘上n!n!n!。
代码
#include <cstdio>
const int mod = 998244353;
const int MAXN = 55, MAXS = MAXN * MAXN;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
template<typename _T>
_T MAX( const _T a, const _T b )
{
return a > b ? a : b;
}
int f[MAXN][MAXN][MAXS];
int N, M;
int main()
{
read( N ), read( M );
f[0][0][0] = 1;
for( int i = 1 ; i <= N ; i ++ )
for( int j = 0 ; j <= i ; j ++ )
for( int k = 0 ; k <= N * N ; k ++ )
{
if( j ) ( f[i][j][k] += f[i - 1][j - 1][k] ) %= mod;
if( k >= i ) ( f[i][j][k] += 1ll * ( 2 * j + 1 ) % mod * f[i - 1][j][k - i] % mod ) %= mod;
if( k >= 2 * i ) ( f[i][j][k] += 1ll * ( j + 1 ) * ( j + 1 ) % mod *
f[i - 1][j + 1][k - 2 * i] % mod ) %= mod;
}
int ans = 0;
for( int i = M ; i <= N * N ; i ++ ) ( ans += f[N][0][i] ) %= mod;
for( int i = 2 ; i <= N ; i ++ ) ans = 1ll * ans * i % mod;
write( ans ), putchar( '\n' );
return 0;
}