Description
Input
共一行包括两个正整数N和M。
Output
共一行为所求表达式的值对10^9+7取模的值。
Sample Input
5 3
Sample Output
36363
HINT
1<=N<=10^9,1<=M<=500000
解题思路:
说一个不仅能求原题,还能求其拓展形式∑i=0nf(i)qi (f(i)是一个k次多项式)∑i=0nf(i)qi (f(i)是一个k次多项式)的O(k)O(k)算法。
(最好f(i)f(i)可以快速算或以点值形式给出,当然dalao们可以直接上多项式多点求值)
首先q=1q=1时还是要特判。
当q≠1q≠1时,我们设S(n)=∑i=0n−1f(i)qiS(n)=∑i=0n−1f(i)qi,所以我们要求S(n+1)S(n+1)。
有一个结论:存在一个≤k≤k次多项式g(x)g(x)使得S(n)=qng(n)−g(0)S(n)=qng(n)−g(0)
考虑用归纳法证明:
当n=0n=0时,显然成立。
当n=1n=1时,S(n)=f(0)=qg(1)−g(0)S(n)=f(0)=qg(1)−g(0),g(1)=(f(0)+g(0))/qg(1)=(f(0)+g(0))/q,也成立。
假设n≤mn≤m时都成立,那么当n=m+1n=m+1时
qS(n)=∑i=0n−1f(i)qi+1=qn+1g(n)−qg(0)qS(n)=∑i=0n−1f(i)qi+1=qn+1g(n)−qg(0)
S(n+1)=∑i=0nf(i)qi=qn+1g(n+1)−g(0)S(n+1)=∑i=0nf(i)qi=qn+1g(n+1)−g(0)
两式相减可得∑i=1n(f(i)−f(i−1))qi+f(0)=qn+1(g(n+1)−g(n))+(q−1)g(0)∑i=1n(f(i)−f(i−1))qi+f(0)=qn+1(g(n+1)−g(n))+(q−1)g(0)
而左边∑i=1n−1(f(i)−f(i−1))qi+f(0)+(f(n)−f(n−1))qn=qn(g(n)−g(n−1))+(q−1)g(0)+(f(n)−f(n−1))qn∑i=1n−1(f(i)−f(i−1))qi+f(0)+(f(n)−f(n−1))qn=qn(g(n)−g(n−1))+(q−1)g(0)+(f(n)−f(n−1))qn
代入原式,两边同除以qnqn可得:g(n)−g(n−1)+f(n)−f(n−1)=q(g(n+1)−g(n))g(n)−g(n−1)+f(n)−f(n−1)=q(g(n+1)−g(n))
所以g(n)=((q+1)g(n)−g(n−1)+f(n)−f(n−1))/qg(n)=((q+1)g(n)−g(n−1)+f(n)−f(n−1))/q也是个≤k≤k次多项式。得证。
现在问题就是如何求g(x)g(x)。
S(n)−S(n−1)=qng(n)−qn−1g(n−1)=qn−1f(n−1)S(n)−S(n−1)=qng(n)−qn−1g(n−1)=qn−1f(n−1)
即是g(n)=(g(n−1)+f(n−1))/qg(n)=(g(n−1)+f(n−1))/q
所以如果我们知道了g(0)g(0)的值,就能推算出g(1),g(2),...g(k+1)g(1),g(2),...g(k+1)
不妨设g(0)=xg(0)=x,由递推式可知g(1),g(2),g(k+1)g(1),g(2),g(k+1)都可表示为一次函数的形式。
而kk次多项式的次差分为0,所以可以列出k+1k+1阶差分表达式:∑i=0k+1(−1)i(k+1i)g(k+1−i)=0∑i=0k+1(−1)i(k+1i)g(k+1−i)=0从而解出g(0)g(0)。
然后直接朗格朗日差值求出g(n+1)g(n+1)就可以得到S(n+1)S(n+1)了。
代码如下,其实很简洁:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll getll()
{
ll i=0,f=1;char c;
for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=500005,mod=1e9+7;
ll n,m;
ll f[N],a[N][2],g[N],fac[N],inv[N],pre[N],suf[N];
int Pow(ll x,ll y)
{
ll res=1;
for(;y;y>>=1,x=x*x%mod)
if(y&1)res=res*x%mod;
return res;
}
inline int C(int x,int y)
{
return fac[x]*inv[y]%mod*inv[x-y]%mod;
}
int G(ll x)
{
if(x<m)return g[x];
ll res=0,t=1;suf[m+1]=1;x%=mod;
for(int i=0;i<=m;i++)pre[i]=t*(x-i)%mod,t=pre[i];
for(int i=m;i>=0;i--)suf[i]=suf[i+1]*(x-i)%mod;
for(int i=0;i<=m;i++)
{
ll t=suf[i+1]*(i==0?1:pre[i-1])%mod;
t=t*g[i]%mod*inv[i]%mod*(((m-i)&1)?mod-inv[m-i]:inv[m-i])%mod;
res=(res+t)%mod;
}
return (res+mod)%mod;
}
int main()
{
//freopen("lx.in","r",stdin);
scanf("%lld%lld",&n,&m);
if(m==1)
{
printf("%lld\n",n*(n+1)%mod*Pow(2,mod-2)%mod);
return 0;
}
for(int i=0;i<=m;i++)f[i]=Pow(i,m);
fac[0]=1;
for(int i=1;i<=m+1;i++)fac[i]=fac[i-1]*i%mod;
inv[m+1]=Pow(fac[m+1],mod-2);
for(int i=m;i>=0;i--)inv[i]=inv[i+1]*(i+1)%mod;
ll t=Pow(m,mod-2);a[0][1]=1,a[0][0]=0;
for(int i=1;i<=m+1;i++)
a[i][1]=a[i-1][1]*t%mod,a[i][0]=(a[i-1][0]+f[i-1])*t%mod;
ll c=0,d=0;
for(int i=0;i<=m+1;i++)
{
c=(c+a[m+1-i][1]*C(m+1,i)*((i&1)?1:-1))%mod;
d=(d+a[m+1-i][0]*C(m+1,i)*((i&1)?1:-1))%mod;
}
g[0]=(-d)*Pow(c,mod-2)%mod;
for(int i=1;i<=m+1;i++)g[i]=(g[i-1]+f[i-1])*t%mod;
ll ans=G(n+1);
ans=(ans*Pow(m,n+1)%mod-g[0]+mod)%mod;
printf("%lld\n",ans);
return 0;
}