题目大意:
我直接说题目抽象出来的模型吧:
给你一个序列CC的前n项,之后都是0;一开始你有一个多项式,然后m次操作,每次P=P×(aix+bi)P=P×(aix+bi),也就是乘以一个单项式,然后你要输出∑i≥0Ci[xi]P(x)∑i≥0Ci[xi]P(x),并且强制在线。
代码里面的Ci=n−i,ai=inputi,bi=1−aiCi=n−i,ai=inputi,bi=1−ai,并且强制在线的手段是每次inputiinputi要加上上一次的ansans,并且ans(init)=nans(init)=n。
题解:
直接做是平方的,gg。考虑这么一件事情:
假设你在处理第tt个询问,当前多项式是P,然后前若干(不超过)个单项式乘积是L,还要再乘以R才能得到想要的P,那么答案:
anst=∑i≥0Ci[xi](L(x)×R(x))=∑i≥0Ci∑j=0iLjRi−j=∑k=i−j≥0Rk∑i≥kCiLi−kanst=∑i≥0Ci[xi](L(x)×R(x))=∑i≥0Ci∑j=0iLjRi−j=∑k=i−j≥0Rk∑i≥kCiLi−k
其中那两个下标是多项式系数。
这个意味着,除了先算出L(x)R(x)L(x)R(x),然后在和CC做点积这种朴素做法以外,你还可以不依赖于R的先求出一个C’然后去和R做点积之类的。
考虑分治(先不要吐槽为啥强制在线了还分治)
取L是分治区间左半部分的乘积,C是的C′C′,然后这两个做减法卷积得到[1,mid][1,mid]的C′C′,然后递归右半部分。最终的R取为那个单项式即可。
然后注意到这时C超过区间长度的部分是没有意义的,L显然也是区间长度级别的,因此复杂度是对的。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define gc getchar()
#define lint long long
#define p 998244353
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define clr(a,n) memset(a,0,sizeof(int)*(n))
#define cpy(a,b,n) memcpy(a,b,sizeof(int)*(n))
#define N 800010
using namespace std;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
int ans=0,a[N],b[N],r[N],c[N],pl[N],pd[N];
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%p) (k&1)?ans=(lint)ans*x%p:0;return ans; }
inline int NTT(int *a,int n,int s)
{
for(int i=1;i<=n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=2;i<=n;i<<=1)
{
int wn=fast_pow(3,s>0?(p-1)/i:p-1-(p-1)/i);
for(int j=0,t=i>>1,x,y;j<n;j+=i)
for(int k=0,w=1;k<t;k++,w=(lint)w*wn%p)
x=a[j+k],y=(lint)w*a[j+k+t]%p,
((a[j+k]=x+y)>=p?a[j+k]-=p:0),
((a[j+k+t]=x-y)<0?a[j+k+t]+=p:0);
}
int ninv=fast_pow(n,p-2);
if(s<0) for(int i=0;i<n;i++) a[i]=(lint)a[i]*ninv%p;
return 0;
}
inline int tms(int *A,int *B,int *C,int m1,int m2)
{
int n=1,L=0;for(;n<=m1+m2;n<<=1,L++);
for(int i=1;i<=n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
clr(a,n),cpy(a,A,m1+1),NTT(a,n,1),
clr(b,n),cpy(b,B,m2+1),NTT(b,n,1);
for(int i=1;i<=n;i++) a[i]=(lint)a[i]*b[i]%p;
return NTT(a,n,-1),cpy(C,a,m1+m2+1),0;
}
inline int mns(int *A,int *B,int *C,int m1,int m2)
{
reverse(B,B+m2+1),tms(A,B,a,m1,m2);
for(int i=0;i<=m1-m2;i++) C[i]=a[i+m2];
return reverse(B,B+m2+1),0;
}
inline int solve(int l,int r,int *c,int *pd,int *pl)
{
int mid=(l+r)>>1,L=mid-l+1,R=r-mid;
if(l==r) return pd[1]=(ans+inn())%p,pd[0]=(1-pd[1]+p)%p,
printf("%d\n",ans=((lint)c[0]*pd[0]+(lint)c[1]*pd[1])%p);
return solve(l,mid,c,pd,pl),mns(c,pd,pl,L+R,L),
solve(mid+1,r,pl,pd+L+1,pl+R+1),tms(pd,pd+L+1,pd,L,R);
}
int main() { ans=inn();rep(i,0,ans) c[i]=ans-i;return solve(1,inn(),c,pd,pl),0; }