正难则反,考虑长度为 iii 的排列得到正确的结果的方案数。
设 dpidp_idpi 表示长度为 iii 的排列直到循环完也没有提前 return
的方案数。考虑 iii 所放置的位置,由于不会提前 return
,也就说明该数字所在的位置为 [i−k+1,i][i - k + 1,i][i−k+1,i] 的范围中。因此可以枚举 iii 的位置为 jjj,则相当于将 [1,i][1,i][1,i] 的区间分为 [1,j−1],[j],[j+1,i][1,j - 1],[j],[j + 1,i][1,j−1],[j],[j+1,i]。
第一段为 i−1i - 1i−1 个数字中选择 j−1j - 1j−1 个,也就是 (i−1j−1)\binom{i-1}{j-1}(j−1i−1),然后合法的方案数为 dpj−1dp_{j - 1}dpj−1;第二段放最大值 iii,第三段还剩下 i−ji - ji−j 个数字,随意放置,也就是 (i−j)!(i - j)!(i−j)!。虽然说 dpidp_idpi 的状态考虑的是排列,但是显然我们只需要考虑数字之间的相对大小,因此第一段的方案数是合理的。可以得到以下转移:
dpi=∑j=i−k+1i(i−1j−1)×dpj−1×(i−j)!dp_i=\sum_{j=i-k+1}^i \binom{i-1}{j-1}\times dp_{j-1}\times (i-j)!dpi=j=i−k+1∑i(j−1i−1)×dpj−1×(i−j)!
尝试进行化简,可以得到:
dpi=∑j=i−k+1i(i−1)!(j−1)!×((i−1)−(j−1))!×dpj−1×(i−j)!=∑j=i−k+1i(i−1)!(j−1)!×dpj−1=(i−1)!×∑j=i−ki−1dpjj! dp_i = \sum_{j=i-k+1}^i \frac{(i - 1)!}{(j - 1)! \times ((i - 1) - (j - 1))!}\times dp_{j-1}\times (i-j)! \\ = \sum_{j=i-k+1}^i\frac{(i - 1)!}{(j - 1)!} \times dp_{j - 1} \\ = (i - 1)! \times \sum_{j=i-k}^{i - 1} \frac{dp_j}{j!} dpi=j=i−k+1∑i(j−1)!×((i−1)−(j−1))!(i−1)!×dpj−1×(i−j)!=j=i−k+1∑i(j−1)!(i−1)!×dpj−1=(i−1)!×j=i−k∑i−1j!dpj
维护一段长度为 kkk 的 dpii!\frac{dp_i}{i!}i!dpi 的和即可 O(n)O(n)O(n) 求出 dpidp_idpi。
最后再考虑答案。若最后求得的答案是正确的,我们只需要枚举 nnn 所在的位置即可。因此总共合法的方案数为:
ans=∑i=1n(n−1i−1)×dpi−1×(n−i)! ans = \sum_{i = 1}^n \binom {n - 1}{i - 1} \times dp_{i - 1} \times (n - i)! ans=i=1∑n(i−1n−1)×dpi−1×(n−i)!
最后的答案就是 n!−ansn!-ansn!−ans。代码如下:
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
#define pii pair <int,int>
using namespace std;
const int MAX = 1e6 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int n,k;ll tot,sum,dp[MAX],f[MAX],inv[MAX];
ll qpow (ll x,ll y)
{
ll res = 1;
while (y)
{
if (y & 1) res = res * x % MOD;
x = x * x % MOD;
y >>= 1;
}
return res;
}
int main ()
{
n = read ();k = read ();
inv[0] = f[0] = 1;
for (int i = 1;i <= n;++i) f[i] = f[i - 1] * i % MOD;
inv[n] = qpow (f[n],MOD - 2);
for (int i = n - 1;i;--i) inv[i] = inv[i + 1] * (i + 1) % MOD;
dp[0] = sum = 1;
for (int i = 1;i <= n;++i)
{
dp[i] = f[i - 1] * sum % MOD;
sum = (sum + dp[i] * inv[i] % MOD) % MOD;
if (i >= k) sum = (sum - dp[i - k] * inv[i - k] % MOD + MOD) % MOD;
}
for (int i = 1;i <= n;++i) tot = (tot + dp[i - 1] * f[n - 1] % MOD * inv[i - 1] % MOD) % MOD;
printf ("%lld\n",(f[n] - tot + MOD) % MOD);
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}