首先可以考虑求出 f i f_i fi 表示结尾在第 i i i 位的 A A AA AA 串的个数, g i g_i gi 表示开头在第 i i i 位的 A A AA AA 串的个数。
枚举 A A A 的长度 L L L,每 L L L 位放置一个关键点,那么 A A AA AA 必定经过恰好两个相邻的关键点。
枚举两个相邻关键点 i i i 和 i + L i+L i+L 并钦定 A A AA AA 经过这两个关键点,求出 p r e i , p r e i + L pre_{i},pre_{i+L} prei,prei+L 的最长公共后缀 l 1 l_1 l1 和 s u f i + 1 , s u f i + L + 1 suf_{i+1},suf_{i+L+1} sufi+1,sufi+L+1 的最长公共前缀 l 2 l_2 l2。
设 l = i − l 1 + 1 l=i-l_1+1 l=i−l1+1, r = i + L + l 2 r=i+L+l_2 r=i+L+l2,那么对于任意的 l ≤ j ≤ r − L l\leq j\leq r-L l≤j≤r−L,有 S j = S j + L S_j=S_{j+L} Sj=Sj+L。
那么这段区间里任意一个长 2 L 2L 2L 的串都是 A A AA AA 的形式,那么经过 i i i 和 i + L i+L i+L 的 A A AA AA 串就共有 ( min ( r , i + 2 L − 1 ) − max ( l , i − L + 1 ) + 1 ) − ( 2 L − 1 ) \big(\min(r,i+2L-1)-\max(l,i-L+1)+1\big)-(2L-1) (min(r,i+2L−1)−max(l,i−L+1)+1)−(2L−1) 个(其中需满足 l 1 + l 2 ≥ L l_1+l_2\geq L l1+l2≥L,否则没有这样的 A A AA AA 串),它们的左/右端点都是连续的,使用差分维护即可。
容易证明任意一种合法的 A A AA AA 串都会被统计到。
时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n),一个 log \log log 是调和级数,一个 log \log log 是求 LCP(当然用 ST 表就可以 O ( n log n ) O(n\log n) O(nlogn) 预处理+ O ( 1 ) O(1) O(1) 查询)。
这种分段放关键点的方式可以加速枚举 S S S 的每一个子串的每一个循环节的过程,因为 [ l , r ] [l,r] [l,r] 内的任意一个长度大于等于 L L L 的串都有循环节 L L L,而且任意一个串的任意一个循环节也肯定会被枚举这个循环节长度时统计到。
既然都是双
log
\sout{\log}
log 不如直接 hash。
可恶,直接自然溢出竟然会被卡
#include<bits/stdc++.h>
#define N 30010
#define ll long long
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
const int base=1145141;
int bit[N];
int T,n;
int f[N],g[N];
char s[N];
struct LCP
{
int sum[N];
void init()
{
sum[0]=0;
for(int i=1;i<=n;i++)
sum[i]=add(mul(sum[i-1],base),s[i]-'a');
}
int ask(int l,int r)
{
return dec(sum[r],mul(sum[l-1],bit[r-l+1]));
}
int query(int i,int j)
{
int l=0,r=i,ans=-1;
while(l<=r)
{
int mid=(l+r)>>1;
if(ask(i-mid+1,i)==ask(j-mid+1,j)) ans=mid,l=mid+1;
else r=mid-1;
}
assert(ans!=-1);
return ans;
}
}Q[2];
int main()
{
// freopen("P1117_13.in","r",stdin);
// freopen("P1117_13_my.out","w",stdout);
bit[0]=1;
for(int i=1;i<=30001;i++)
bit[i]=mul(bit[i-1],base);
T=read();
while(T--)
{
scanf("%s",s+1);
n=strlen(s+1);
Q[0].init();
reverse(s+1,s+n+1);
Q[1].init();
reverse(s+1,s+n+1);
for(int i=0;i<=n+1;i++) f[i]=g[i]=0;
for(int L=1;L<=n;L++)
{
for(int i=1;i+L<=n;i+=L)
{
int l1=Q[0].query(i,i+L);
int l2=Q[1].query(n-(i+L+1)+1,n-(i+1)+1);
int l=max(i-l1+1,max(i-L+1,1)),r=min(i+L+l2,min(i+2*L-1,n));
int tmp=(r-l+1)-(2*L-1);
if(tmp<=0) continue;
assert(l1+l2>=L);
f[r-tmp+1]++,f[r+1]--;
g[l+tmp-1]++,g[l-1]--;
}
}
for(int i=1;i<=n;i++) f[i]+=f[i-1];
for(int i=n;i>=1;i--) g[i]+=g[i+1];
ll ans=0;
for(int i=1;i<=n;i++)
ans+=1ll*f[i]*g[i+1];
printf("%lld\n",ans);
}
return 0;
}