D. New Year and the Permutation Concatenation
Let nnn be an integer. Consider all permutations on integers 1 to nnn in lexicographic order, and concatenate them into one big sequence ppp. For example, if n=3n=3n=3, then p=[1,2,3,1,3,2,2,1,3,2,3,1,3,1,2,3,2,1].p=[1,2,3,1,3,2,2,1,3,2,3,1,3,1,2,3,2,1].p=[1,2,3,1,3,2,2,1,3,2,3,1,3,1,2,3,2,1]. The length of this sequence will be n∗n!n*n!n∗n!.
Let 1≤i≤j≤n∗n!1≤i≤j≤n*n!1≤i≤j≤n∗n! be a pair of indices. We call the sequence (pi,pi+1,…,pj−1,pj)(p_i,p_{i+1},…,p_{j−1},p_j)(pi,pi+1,…,pj−1,pj) a subarray of ppp. Its length is defined as the number of its elements, i.e., j−i+1j-i+1j−i+1. Its sum is the sum of all its elements, i.e.,∑k=ijpk\sum_{k=i}^{j}{p_k}∑k=ijpk.
You are given nnn. Find the number of subarrays of ppp of length nnn having sum n∗(n+1)2\frac{n*(n+1)}{2}2n∗(n+1). Since this number may be large, output it modulo 998244353998244353998244353 (a prime number).
Input
The only line contains one integer n (1≤n≤106)n\ (1≤n≤10^6)n (1≤n≤106), as described in the problem statement.
Output
Output a single integer — the number of subarrays of length nnn having sum n∗(n+1)2\frac{n*(n+1)}{2}2n∗(n+1), modulo 998244353.
Examples
input
3
output
9
input
4
output
56
input
10
output
30052700
Note
In the first sample, there are 16 subarrays of length 3. In order of appearance, they are:
[1,2,3],[2,3,1],[3,1,3],[1,3,2],[3,2,2],[2,2,1],[2,1,3],[1,3,2],[3,2,3],[2,3,1],[3,1,3],[1,3,1],[3,1,2],[1,2,3],[2,3,2],[3,2,1].[1,2,3], [2,3,1], [3,1,3], [1,3,2], [3,2,2], [2,2,1], [2,1,3], [1,3,2], [3,2,3], [2,3,1], [3,1,3], [1,3,1], [3,1,2], [1,2,3], [2,3,2], [3,2,1].[1,2,3],[2,3,1],[3,1,3],[1,3,2],[3,2,2],[2,2,1],[2,1,3],[1,3,2],[3,2,3],[2,3,1],[3,1,3],[1,3,1],[3,1,2],[1,2,3],[2,3,2],[3,2,1].
Their sums are 6,6,7,6,7,5,6,6,8,6,7,5,6,6,7,6.6, 6, 7, 6, 7, 5, 6, 6, 8, 6, 7, 5, 6, 6, 7, 6.6,6,7,6,7,5,6,6,8,6,7,5,6,6,7,6. As n∗(n+1)2\frac{n*(n+1)}{2}2n∗(n+1)=6, the answer is 9.
-
题意:
将一个含有1~n的所有排列按字典序排序并依次首尾拼接,问子序列(连续)的个数使得其元素和为n∗(n+1)2\frac{n*(n+1)}{2}2n∗(n+1) -
解法:
- 至于为什么这些子序列的长度都是n还不会证明。。
- 首先你得知道c++c++c++库函数next_permutationnext\_permutationnext_permutation的实现原理:
-
/** Tips: next permuation based on the ascending order sort * sketch : * current: 3 7 6 2 5 4 3 1 . * | | | | * find i----+ j k +----end * swap i and k : * 3 7 6 3 5 4 2 1 . * | | | | * i----+ j k +----end * reverse j to end : * 3 7 6 3 1 2 4 5 . * | | | | * find i----+ j k +----end * */
-
具体方法为:
-
从后向前查找第一个相邻元素对(i,j)(i,j)(i,j),并且满足A[i]<A[j](i<j)A[i] < A[j](i<j)A[i]<A[j](i<j)
易知,此时从j到end必然是降序。可以用反证法证明,请自行证明
-
在[j,end)[j,end)[j,end)中寻找一个最小的kkk使其满足A[i]<A[k]A[i]<A[k]A[i]<A[k]
由于[j,end)[j,end)[j,end)是降序的,所以必然存在一个k满足上面条件;并且可以从后向前查找第一个满足A[i]<A[k]A[i]<A[k]A[i]<A[k]关系的kkk,此时的kkk必是待找的kkk
-
将A[i]A[i]A[i]与A[k]A[k]A[k]交换
此时,iii处变成比iii大的最小元素,因为下一个全排列必须是与当前排列按照升序排序相邻的排列,故选择最小的元素替代iii
易知,交换后的[j,end)[j,end)[j,end)仍然满足降序排序。因为在(k,end)(k,end)(k,end)中必然小于iii,在[j,k)[j,k)[j,k)中必然大于kkk,并且大于iii
-
逆置[j,end)[j,end)[j,end)
由于此时[j,end)[j,end)[j,end)是降序的,故将其逆置。最终获得下一全排序
-
注意:如果在步骤a)找不到符合的相邻元素对,即此时i=begini=begini=begin,则说明当前[begin,end)[begin,end)[begin,end)为一个降序顺序,即无下一个全排列,STLSTLSTL的方法是将其逆置成升序
-
-
- next_permutationnext\_permutationnext_permutation函数找到当前序列的一个比当前序列字典序大的最小字典序的排列,也就是对应了题目要求,那么本题我们可以这么想:对于可以构成 n∗(n+1)2\frac{n*(n+1)}{2}2n∗(n+1)的连续子序列,他的位置有如下两种情况:
- 恰好为第iii个排列,可以发现所有的n!n!n!个排列都满足要求
- 或者是第iii个排列的后kkk个与第i+1i+1i+1个排列的前n−kn-kn−k个排列拼接而成
- 现在考虑后者:显然要使这nnn个数恰好是一个排列的话,第i个排列的前n−kn-kn−k个一定与第i+1i+1i+1个排列的前n−kn-kn−k个相同,那么联系nextpermutationnext_permutationnextpermutation的过程,不难发现,最长递减后缀长度len<klen<klen<k才能满足,接下来统计会有多少区间
- 正难则反!直接统计不行的会方便很多:后kkk位从nnn个数中选出kkk个递减依次放置在最后棉,方案数为CnkC_{n}^{k}Cnk,前面n−kn-kn−k个直接把剩下的排列就行了,方案数为(n−k)!(n-k)!(n−k)!,所以答案就是n∗n!−∑k=1n−1Cnk∗(n−k)!=n∗n!−∑k=1n−1n!k!n*n!-\sum_{k=1}^{n-1}{C_{n}^{k}*(n-k)!} =n*n!-\sum_{k=1}^{n-1}{\frac{n!}{k!}}n∗n!−k=1∑n−1Cnk∗(n−k)!=n∗n!−k=1∑n−1k!n!
- 有大佬看出来这个题答案是下面这个式子:ans[n]=(ans[n−1]+(n−1)!−1)∗nans[n]=(ans[n-1]+(n-1)!-1)*nans[n]=(ans[n−1]+(n−1)!−1)∗n
-
附代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const int maxn=1e6+10;
int n;
ll fac[maxn];
ll quick_pow(ll a,ll b)
{
ll res=1LL;
while(b){
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}
int main()
{
scanf("%d",&n);fac[0]=1LL;
for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
ll ans=1LL*n*fac[n]%mod;
for(int i=1;i<n;i++) ans=(ans-(fac[n]*quick_pow(fac[i],mod-2)%mod))%mod;
printf("%lld\n",(ans+mod)%mod);
}