Description
给出一个序列a1,...,ana1,...,an,求∑x1=1a1∑x2=1a2..∑xn=1anmax{x1,...,xn}∑x1=1a1∑x2=1a2..∑xn=1anmax{x1,...,xn}
Input
多组用例,每组用例首先输入一整数nn,之后输入个整数a1,...,an(1≤n≤1000,1≤ai≤109)a1,...,an(1≤n≤1000,1≤ai≤109)
Output
输出结果,答案模109+7109+7
Sample Input
2
1 2
5
2 3 3 3 3
Sample Output
3
453
Solution
首先将aa序列升序,考虑最大值的若干取值区间:,假设最大值kk取在第个区间[ai−1+1,ai][ai−1+1,ai],那么前i−1i−1个数字无论取何值都不会超过最大值,故方案数为∏j=1i−1aj∏j=1i−1aj,而后n−i+1n−i+1个数字每一个都可以取到11~,故该部分对答案的贡献为
∑k=ai−1+1aik(kn−i+1−(k−1)n−i+1)∑k=ai−1+1aik(kn−i+1−(k−1)n−i+1)
令F(l,r,x)=∑k=lrk(kx−(k−1)x)F(l,r,x)=∑k=lrk(kx−(k−1)x),则裂项相消有F(l,r,x)=rx+1−l(l−1)x−∑k=lr−1kxF(l,r,x)=rx+1−l(l−1)x−∑k=lr−1kx
令S(n,k)=∑i=1nikS(n,k)=∑i=1nik,只要求出S(n,k)S(n,k)即可得到F(l,r,x)F(l,r,x),进而答案为
∑i=1nF(ai−1+1,ai,n−i+1)∏j=1i−1aj∑i=1nF(ai−1+1,ai,n−i+1)∏j=1i−1aj
O(n2)O(n2)预处理伯努利数列B0=1,Bn=−1n+1∑j=0n−1Cjn+1Bj,n≥1B0=1,Bn=−1n+1∑j=0n−1Cn+1jBj,n≥1,由S(n,k)=1k+1∑i=1k+1Cik+1Bk+1−i(n+1)iS(n,k)=1k+1∑i=1k+1Ck+1iBk+1−i(n+1)i即可O(k)O(k)求出S(n,k)S(n,k),总时间复杂度O(n2)O(n2)
Code
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
#define maxn 1111
#define mod 1000000007
int inv[maxn],B[maxn],C[maxn][maxn];
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int Pow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
void init(int n=maxn-1)
{
inv[1]=1;
for(int i=2;i<=n;i++)inv[i]=mul(mod-mod/i,inv[mod%i]);
C[0][0]=1;
for(int i=1;i<=n;i++)
{
C[i][0]=C[i][i]=1;
for(int j=1;j<i;j++)C[i][j]=add(C[i-1][j-1],C[i-1][j]);
}
B[0]=1;
for(int i=1;i<n;i++)
{
for(int j=0;j<i;j++)B[i]=add(B[i],mul(C[i+1][j],B[j]));
B[i]=mul(B[i],mod-inv[i+1]);
}
}
int Solve(int n,int k)
{
if(n==0)return 0;
int ans=0;
for(int i=1;i<=k+1;i++)
ans=add(ans,mul(mul(C[k+1][i],B[k+1-i]),Pow(n+1,i)));
ans=mul(ans,inv[k+1]);
return ans;
}
int Deal(int l,int r,int k)
{
int ans=add(Pow(r,k+1),mod-mul(l,Pow(l-1,k)));
ans=add(ans,add(mod-Solve(r-1,k),Solve(l-1,k)));
return ans;
}
int n,a[maxn];
int main()
{
init();
while(~scanf("%d",&n))
{
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
sort(a+1,a+n+1);
int ans=0,res=1;
for(int i=1;i<=n;i++)
{
if(a[i]!=a[i-1])ans=add(ans,mul(res,Deal(a[i-1]+1,a[i],n-i+1)));
res=(ll)res*a[i]%mod;
}
printf("%d\n",ans);
}
return 0;
}