Description
给出一长度为nn的序列,如果ai=0ai=0则aiai等概率取[1,m][1,m]中任意一个整数,给出权值序列v1,...,vmv1,...,vm,求
的期望
Input
第一行一整数TT表示用例组数,每组用例首先输入两个整数,之后输入nn个整数,最后输入mm个整数
(1≤T≤10,4≤n≤100,1≤m≤100,0≤ai≤m,1≤vi≤109)(1≤T≤10,4≤n≤100,1≤m≤100,0≤ai≤m,1≤vi≤109)
Output
输出期望值,结果模109+7109+7
Sample Input
2
6 8
4 8 8 4 6 5
10 20 30 40 50 60 70 80
4 3
0 0 0 0
3 2 4
Sample Output
8000
3
Solution
考虑朴素dpdp,以dp[i][x][y][z]dp[i][x][y][z]表示ai−2=x,ai−1=y,ai=zai−2=x,ai−1=y,ai=z时的答案,每次枚举ai+1ai+1的值转移即可,这样的时间复杂度为O(nm4)O(nm4),注意到实际有用的不是ai−2,ai−1ai−2,ai−1的值,而是gcd(ai−2,ai−1,ai)gcd(ai−2,ai−1,ai)和gcd(ai−1,ai)gcd(ai−1,ai)的值,那么以dp[i][x][y][z]dp[i][x][y][z]表示gcd(ai−2,ai−1,ai)=x,gcd(ai−1,ai)=y,ai=zgcd(ai−2,ai−1,ai)=x,gcd(ai−1,ai)=y,ai=z时的答案,那么显然需要x|y|z≤mx|y|z≤m,这样的三元组在m=100m=100时只有res=1471res=1471个,进而可以直接转移
给所有三元组(x,y,z)(x,y,z)编号id(x,y,z)id(x,y,z),以dp[i][j]dp[i][j]表示gcd(ai−2,ai−1,ai),gcd(ai−1,ai),aigcd(ai−2,ai−1,ai),gcd(ai−1,ai),ai为第jj个三元组时的答案,并预处理表示第ii个三元组加上后转移到的三元组编号,val(i,j)val(i,j)表示第ii个三元组加上的gcdgcd值,进而有转移
dp[i][suf(j,k)]+=vval(j,k)⋅dp[i−1][j],k∈[li,ri]dp[i][suf(j,k)]+=vval(j,k)⋅dp[i−1][j],k∈[li,ri]
其中[li,ri][li,ri]表示aiai的取值区间,答案即为∑dp[n][i]mnum∑dp[n][i]mnum,其中numnum为ai=0ai=0的个数
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
#define mod 1000000007
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int x,int y)
{
int ans=1;
while(y)
{
if(y&1)ans=(ll)ans*x%mod;
x=(ll)x*x%mod;
y>>=1;
}
return ans;
}
int T,n,m,a[105],v[105];
int res,dp[105][1500],id[105][105][105],gcd[105][105],suf[1500][105],val[1500][105];
void init(int n=100)
{
for(int i=1;i<=n;i++)gcd[i][0]=gcd[0][i]=i;
for(int i=1;i<=n;i++)
{
gcd[i][i]=i;
for(int j=1;j<i;j++)
gcd[i][j]=gcd[j][i]=gcd[j][i%j];
}
res=0;
for(int x=1;x<=n;x++)
for(int y=x;y<=n;y+=x)
for(int z=y;z<=n;z+=y)
id[x][y][z]=res++;
for(int x=1;x<=n;x++)
for(int y=x;y<=n;y+=x)
for(int z=y;z<=n;z+=y)
for(int w=1;w<=n;w++)
{
suf[id[x][y][z]][w]=id[gcd[y][w]][gcd[z][w]][w];
val[id[x][y][z]][w]=gcd[x][w];
}
}
int main()
{
init();
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1;i<=m;i++)scanf("%d",&v[i]);
memset(dp,0,sizeof(dp));
int cnt=id[m][m][m]+1;
for(int i=(a[1]?a[1]:1);i<=(a[1]?a[1]:m);i++)
for(int j=(a[2]?a[2]:1);j<=(a[2]?a[2]:m);j++)
for(int k=(a[3]?a[3]:1);k<=(a[3]?a[3]:m);k++)
dp[3][id[gcd[i][gcd[j][k]]][gcd[j][k]][k]]++;
for(int i=4;i<=n;i++)
for(int j=0;j<cnt;j++)
if(dp[i-1][j])
for(int k=(a[i]?a[i]:1);k<=(a[i]?a[i]:m);k++)
dp[i][suf[j][k]]=add(dp[i][suf[j][k]],mul(v[val[j][k]],dp[i-1][j]));
int ans=0;
for(int i=0;i<cnt;i++)ans=add(ans,dp[n][i]);
int num=0;
for(int i=1;i<=n;i++)
if(!a[i])num++;
ans=mul(ans,Pow(Pow(m,mod-2),num));
printf("%d\n",ans);
}
return 0;
}