题目大意
给定一个长度为 n n n的序列 a a a,对于每一种将 a a a划分成若干个子串的方式,我们设当前划分为 k k k个子串,分别描述为 ( l 1 , r 1 ) , ( l 2 , r 2 ) , … , ( l k , r k ) (l_1,r_1),(l_2,r_2),\dots,(l_k,r_k) (l1,r1),(l2,r2),…,(lk,rk)定义长度为 k k k的数组 s s s满足 s i = ∑ j = l i r i a j s_i=\sum\limits_{j=l_i}^{r_i}a_j si=j=li∑riaj。如果 s s s是一个回文数组,则称这种划分方式是好的。
求 a a a有多少种好的划分。答案模 998244353 998244353 998244353。
有多组数据。
题解
我们可以用前缀和与后缀和来维护区间和,然后记录划分时的断点。
对于每一个断点,令其位置为 i i i,则位置 i i i是一个划分子串的右端点,位置 i + 1 i+1 i+1是一个划分子串的左端点。
对于每一个断点,必有另一个断点的后缀和与这个断点的前缀和相等。那么我们可以枚举 a a a序列中左右部分相等的段,然后求出左边前缀和相等的一段 [ l 1 , r 1 ] [l_1,r_1] [l1,r1]和右边后缀和相等的一段 [ l 2 , r 2 ] [l_2,r_2] [l2,r2],枚举其中能有几个断点,断点分配方案的数量为
∑ i = 1 min ( r 1 − l 1 + 1 , r 2 − l 2 + 1 ) C ( r 1 − l 1 + 1 , i ) × C ( r 2 − l 2 + 1 , i ) \sum\limits_{i=1}^{\min(r_1-l_1+1,r_2-l_2+1)}C(r_1-l_1+1,i)\times C(r_2-l_2+1,i) i=1∑min(r1−l1+1,r2−l2+1)C(r1−l1+1,i)×C(r2−l2+1,i)。
最后将各段的断点分配方案的数量相乘即为答案。
如果 [ l 1 , r 1 ] [l_1,r_1] [l1,r1]和 [ l 2 , r 2 ] [l_2,r_2] [l2,r2]有交集,那么这一段的点可以任意选,答案乘上 2 r 2 − l 1 + 1 2^{r_2-l_1+1} 2r2−l1+1。此时整个序列 a a a已经选完,直接退出即可。
根据断点的定义,后缀和的求法为 s 2 i = s 2 i + 1 + a i + 1 s2_i=s2_{i+1}+a_{i+1} s2i=s2i+1+ai+1。
前缀和与后缀和要清零,避免影响后面的计算。不能用memset,否则会TLE。
时间复杂度为 O ( ∑ n ) O(\sum n) O(∑n)。
code
#include<bits/stdc++.h>
using namespace std;
const int N=100000;
int T,n;
long long ans,a[100005],s1[100005],s2[100005],jc[100005],ny[100005],w[100005];
long long mod=998244353;
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void init(){
w[0]=1;
for(int i=1;i<=N;i++) w[i]=w[i-1]*2%mod;
jc[0]=1;
for(int i=1;i<=N;i++) jc[i]=jc[i-1]*i%mod;
ny[N]=mi(jc[N],mod-2);
for(int i=N-1;i>=0;i--) ny[i]=ny[i+1]*(i+1)%mod;
}
long long C(int x,int y){
return jc[x]*ny[y]%mod*ny[x-y]%mod;
}
int main()
{
init();
scanf("%d",&T);
while(T--){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(int i=1;i<=n;i++) s1[i]=s1[i-1]+a[i];
for(int i=n-1;i>=1;i--) s2[i]=s2[i+1]+a[i+1];
ans=1;
for(int l=1,r=n-1,x,y;l<=r;l=x+1,r=y-1){
while(l<=r&&s1[l]!=s2[r]){
if(s1[l]<s2[r]) ++l;
else --r;
}
if(l>r) break;
if(s1[l]==s1[r]){
ans=ans*w[r-l+1]%mod;
break;
}
x=l;y=r;
while(s1[x+1]==s1[l]) ++x;
while(s2[y-1]==s2[r]) --y;
int v=min(x-l+1,r-y+1);
long long re=0;
for(int i=0;i<=v;i++){
re=(re+C(x-l+1,i)*C(r-y+1,i)%mod)%mod;
}
ans=ans*re%mod;
}
printf("%lld\n",ans);
for(int i=1;i<=n;i++) s1[i]=s2[i]=0;
}
return 0;
}