思路:先考虑单个序列的最大值怎么求,然后考虑算贡献
可以参考一下:https://zhuanlan.zhihu.com/p/2487534128,知乎大佬的解释。
package CodeforcesRound979Div2;
import java.io.*;
import java.util.*;
public class e {
public static BufferedReader rd=new BufferedReader(new InputStreamReader(System.in));
public static BufferedWriter wd=new BufferedWriter(new OutputStreamWriter(System.out));
public static long inv[],fac[],infac[];
public static int mod=998244353;
public static void init(int x) {
inv=new long [x+10];fac=new long [x+10];infac=new long [x+10];
inv[1]=fac[0]=infac[0]=1;
for(int i=1;i<=x;i++) fac[i]=fac[i-1]*i%mod;
for(int i=2;i<=x;i++) inv[i]=(-mod/i*inv[mod%i]%mod+mod)%mod;
for(int i=1;i<=x;i++) infac[i]=infac[i-1]*inv[i]%mod;
}
public static long c(int a,int b) {
if(a<b) return 0;
return fac[a]*infac[b]%mod*infac[a-b]%mod;
}
public static void solve()throws Exception{
int n=Integer.parseInt(rd.readLine());
int a[]=new int [n+10],cnt []=new int [n+10];
String dr[]=rd.readLine().split(" ");
for(int i=1;i<=n;i++) {
a[i]=Integer.parseInt(dr[i-1]);
cnt[a[i]]++;
}
int suf[]=new int [n+10];
for(int i=n-1;i>=0;i--) suf[i]=suf[i+1]+cnt[i];
long zcmi[]=new long [n+10];
zcmi[0]=1;
for(int i=1;i<=n;i++) zcmi[i]=zcmi[i-1]*2%mod;
long dp[][]=new long [n+10][],sum[][]=new long [n+10][];
dp[0]=new long [cnt[0]+2];sum[0]=new long [cnt[0]+2];
long ans=0;
for(int j=0;j<=cnt[0];j++) {
dp[0][j]=c(cnt[0],j);
ans=(ans+dp[0][j]*j%mod*zcmi[suf[1]]%mod)%mod;
}
sum[0][cnt[0]]=dp[0][cnt[0]];
for(int i=cnt[0]-1;i>=0;i--) sum[0][i]=(dp[0][i]+sum[0][i+1])%mod;
for(int i=1;i<n;i++) {
long qz=0;
dp[i]=new long [cnt[i]+2];
sum[i]=new long [cnt[i]+2];
for(int j=cnt[i];j>=0;j--) {
qz=(qz+c(cnt[i],j+1))%mod;
if(j<=cnt[i-1]) {
//最小值来自我这个j;
dp[i][j]=sum[i-1][j]*c(cnt[i],j)%mod;
//最小值来自前面的j;
dp[i][j]=(dp[i][j]+dp[i-1][j]*qz%mod)%mod;
}
else dp[i][j]=0;
}
sum[i][cnt[i]]=dp[i][cnt[i]];
for(int j=cnt[i]-1;j>=0;j--) sum[i][j]=(dp[i][j]+sum[i][j+1])%mod;
for(int j=0;j<=cnt[i];j++) {
ans=(ans+dp[i][j]*j%mod*zcmi[suf[i+1]]%mod)%mod;
}
}
wd.write(ans+"\n");
}
public static void main(String[] args)throws Exception{
int t=Integer.parseInt(rd.readLine());
init((int)3e5+10);
//wd.write(c(6,2)+"\n");
while(t-->0) {
solve();
wd.flush();
}
}
}