题目大意:
令 $f_k(n)$ 表示长度为 $k$、每个元素都在 $[1,n]$ 内、且所有元素的 $\gcd$ 为 $1$ 的序列个数。求:$\sum_{i=1}^n f_k(i)$,其中 $n\le10^9,\ k\le10^5$。
题解:
$$
\begin{aligned}
Ans&=\sum_{i=1}^n f_k(i)=\sum_{i=1}^n\sum_{j=1}^i\left\lfloor\frac ij\right\rfloor^k\mu(j)=\sum_{j=1}^n\mu(j)\sum_{i=1}^n\left\lfloor\frac ij\right\rfloor^k\\
&=\sum_{j=1}^n\mu(j)\left[\left(\sum_{i=1}^{\left\lfloor\frac nj\right\rfloor-1}i^kj\right)+\left\lfloor\frac nj\right\rfloor^k\left(n-\left\lfloor\frac nj\right\rfloor j+1\right)\right]\\
&=\sum_{j=1}^n\mu(j)\left(jS_k\left(\left\lfloor\frac nj\right\rfloor-1\right)+\left\lfloor\frac nj\right\rfloor^k(n+1)-\left\lfloor\frac nj\right\rfloor^{k+1}j\right)\\
&=\sum_{j=1}^n\left[\mu(j)j\,S_k\left(\left\lfloor\frac nj\right\rfloor-1\right)+\mu(j)\left\lfloor\frac nj\right\rfloor^k(n+1)-\mu(j)j\left\lfloor\frac nj\right\rfloor^{k+1}\right]
\end{aligned}
$$
(第二步:记 $m=\left\lfloor\frac nj\right\rfloor$,对 $1\le q<m$,满足 $\left\lfloor\frac ij\right\rfloor=q$ 的 $i\in[1,n]$ 恰有 $j$ 个;$q=m$ 时恰有 $n-mj+1$ 个。)
其中:
$S_k(n)=\sum_{i=1}^n i^k$
然后用杜教筛分别求 $\mu(j)$ 与 $\mu(j)j$ 的前缀和;$S_k(n)$ 是关于 $n$ 的 $k+1$ 次多项式,用拉格朗日插值求出即可。
取 $\mathrm{block\_size}=\sqrt{nk}$,可以做到 $O\left(n^{\frac23}+\sqrt{nk}\right)$。
$\mathrm{std}$ 本意是要写一个多点插值,把后半部分做到 $O\left(\sqrt{n}\lg k\right)$,但是数据出小了,就暴力过去了……
#include<bits/stdc++.h>
// ---- Contest shorthand macros ----
// Character source used by inn(); swap this one macro to change the input backend.
#define gc getchar()
// Inclusive loop: i runs from a to b.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Loop i over every index of container v.
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
// Modulus for all modular arithmetic. prelude2() inverts via fast_pow(x, mod-2),
// which relies on mod being prime (998244353 is prime).
#define mod 998244353
#define db double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
// Sieve limit: the tables sg, sh, mu, ik, sk, np, p are indexed in [0, N-1].
// g(), h(), S(), IK() read these tables for arguments below N and fall back
// to Du's sieve / interpolation / fast_pow for larger ones.
#define N 10000010
// Size of the interpolation tables fac, facinv, pre, suf, y.
// S() touches suf[k+4], so this assumes k+4 < K (true for k <= 1e5).
#define K 100020
// cerr debug helpers; not referenced anywhere in the code shown.
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
typedef unordered_map<int,int> mpii;
typedef unordered_map<int,lint> mpil;
// savg: memo of g(n) for n >= N.
// sg[i] = sum_{j<=i} mu(j)*j, exact (not reduced mod p).
// fac/facinv: factorials and inverse factorials mod p, filled by prelude2().
mpil savg;lint sg[N];int fac[K],facinv[K];
// savh: memo of h(n) for n >= N.
// sh[i] = sum_{j<=i} mu(j).
// k: the exponent read in main().
// mu: Moebius function.
// ik[i] = i^k mod p.
// sk[i] = sum_{j<=i} j^k mod p.
mpii savh;int sh[N],k,mu[N],ik[N],sk[N];
// np: "composite" flag set by the sieve.
// p: primes in increasing order, 1-indexed.
// pre/suf: prefix/suffix products of (x-i), scratch space for S().
// y[i] = sum_{j<=i} j^k mod p: the interpolation samples used by S().
bool np[N];int p[N],pre[K],suf[K],y[K];
// Reads the next non-negative decimal integer from the character source gc.
// Any non-digit characters before it are skipped, so a '-' sign is ignored.
// The first non-digit character after the number is consumed.
// Returns 0 if EOF is reached before any digit is seen.
inline int inn()
{
	int ch=gc;
	// Stop at EOF: EOF (-1) compares below '0', so without this check the
	// skip loop would never terminate on empty or truncated input.
	while(ch!=EOF&&(ch<'0'||ch>'9')) ch=gc;
	if(ch==EOF) return 0;
	int x=0;
	for(;ch>='0'&&ch<='9';ch=gc) x=x*10+(ch-'0');
	return x;
}
// Binary exponentiation: returns ans * x^k mod p, with ans defaulting to 1.
// Parameter k shadows the global exponent k inside this function.
inline int fast_pow(int x,int k,int ans=1)
{
	while(k)
	{
		if(k&1) ans=(lint)ans*x%mod;
		x=(lint)x*x%mod;
		k>>=1;
	}
	return ans;
}
// Linear (Euler) sieve over [1, n]. It fills:
//   p / np : primes and composite flags
//   mu     : Moebius function
//   ik[i]  : i^k mod p (completely multiplicative, so products of smaller entries)
//   sg, sh : prefix sums of mu(i)*i and mu(i)
//   sk[i]  : prefix sum of i^k mod p
// Precondition: the global k already holds the exponent.
// Always returns 0.
inline int prelude(int n)
{
	mu[1]=1;
	ik[1]=1;
	sg[1]=1;
	sh[1]=1;
	sk[1]=1;
	int cnt=0;
	for(int i=2;i<=n;i++)
	{
		if(!np[i])
		{
			p[++cnt]=i;
			mu[i]=-1;
			ik[i]=fast_pow(i,k);
		}
		sg[i]=sg[i-1]+mu[i]*i;
		sh[i]=sh[i-1]+mu[i];
		sk[i]=sk[i-1]+ik[i];
		if(sk[i]>=mod) sk[i]-=mod;
		// Only primes with p*i <= n produce in-range multiples.
		const int lim=n/i;
		for(int j=1;j<=cnt&&p[j]<=lim;j++)
		{
			const int q=p[j];
			const int x=q*i;
			np[x]=1;
			ik[x]=(lint)ik[i]*ik[q]%mod;
			if(i%q==0)
			{
				// q^2 divides x: mu vanishes, and each composite is struck
				// exactly once by its least prime factor.
				mu[x]=0;
				break;
			}
			mu[x]=-mu[i];
		}
	}
	return 0;
}
inline int prelude2(int n)
{
rep(i,1,n) y[i]=y[i-1]+fast_pow(i,k),(y[i]>=mod?y[i]-=mod:0);
rep(i,fac[0]=1,n) fac[i]=(lint)fac[i-1]*i%mod;
facinv[n]=fast_pow(fac[n],mod-2);
for(int i=n-1;i>=0;i--) facinv[i]=(i+1ll)*facinv[i+1]%mod;
return 0;
}
// g(n) = sum_{i=1}^{n} mu(i)*i, as an exact (non-reduced) 64-bit value.
// For n < N it reads the sieved table sg. Otherwise it applies the
// recurrence g(n) = 1 - sum_{d=2}^{n} d * g(floor(n/d)), grouping d into
// blocks with equal floor(n/d), and memoises the result in savg.
inline lint g(int n)//mu(i)*i
{
	if(n<N) return sg[n];
	auto it=savg.find(n);   // single hash lookup instead of count() + operator[]
	if(it!=savg.end()) return it->second;
	lint ans=0;
	// 64-bit block bounds: in int, s+t overflows once n >= 2^30, and
	// s=t+1 overflows when t reaches INT_MAX.
	for(lint s=2,t;s<=n;s=t+1)
	{
		const lint q=n/s;
		t=n/q;
		// (s+t)*(t-s+1)/2 = s + (s+1) + ... + t
		ans+=(s+t)*(t-s+1)/2*g((int)q);
	}
	return savg[n]=1-ans;
}
inline int h(int n)//mu(i)
{
if(n<N) return sh[n];int ans=0;
if(savh.count(n)) return savh[n];
for(int s=2,t;s<=n;s=t+1) t=n/(n/s),ans+=h(n/s)*(t-s+1);
return savh[n]=1-ans;
}
inline int g(int l,int r) { return ((g(r)-g(l-1))%mod+mod)%mod; }
// sum_{i=l}^{r} mu(i) mod p.
// The +mod shift lands in [0, mod) whenever h(r)-h(l-1) > -mod.
inline int h(int l,int r)
{
	const int diff=h(r)-h(l-1);
	return (diff+mod)%mod;
}
inline int S(int x)
{
if(x<N) return sk[x];int n=k+3;static lint xs[3],ans;
pre[0]=suf[n+1]=1,xs[0]=1,xs[1]=-1,ans=0;
for(int i=1;i<=n;i++) pre[i]=(lint)pre[i-1]*(x-i)%mod;
for(int i=n;i>=1;i--) suf[i]=(lint)suf[i+1]*(x-i)%mod;
for(int i=1;i<=n;i++)
ans+=xs[(n-i)&1]*y[i]*pre[i-1]%mod*suf[i+1]%mod*facinv[i-1]%mod*facinv[n-i]%mod;
return int((ans%mod+mod)%mod);
}
// i^k mod p: table lookup below the sieve limit N, fast_pow above it.
inline int IK(int i)
{
	return i<N ? ik[i] : fast_pow(i,k);
}
// Brute-force sum_{i=1}^{n} i^kk mod p, with kk defaulting to the global k.
// Not called anywhere in this file; presumably kept as a reference for
// checking S() by hand.
inline int qs(int n,int kk=k)
{
	int total=0;
	for(int i=1;i<=n;i++) total=(total+fast_pow(i,kk))%mod;
	return total;
}
int main()
{
int n=inn();k=inn();lint ans=0;prelude(N-1),prelude2(K-1);
for(int l=1,r,t;l<=n;l=r+1) r=n/(n/l),t=n/l,
ans+=(g(l,r)*(S(t-1)-(lint)IK(t)*t%mod)+h(l,r)*(n+1ll)%mod*IK(t)%mod)%mod;
return !printf("%lld\n",(ans%mod+mod)%mod);
}