把两个字符串拼起来,然后做后缀数组和LCP,然后开两个单调栈,进行计算即可。
#include<bits/stdc++.h>
using namespace std;
long long len1,len2,n,rank[500000],sa[500000],temp[500000],str[500000],cnt[500000],p[500000];
long long q1[500000],q2[500000],height[500005],sum=0,ans1=0,ans2=0,tail1=0,tail2=0;
long long w1[500000],w2[500000],ans=0;
char s1[200010],s2[200005],s[400050];
bool equ(long long x,long long y,long long l){return rank[x]==rank[y]&&rank[x+l]==rank[y+l];}
void doubling()
{
for(long long i=0;i<n;i++)str[i]=s[i]-'a';
str[len1]=26;
str[n]=rank[n]=-1;
for(long long i=0;i<n;i++)rank[i]=str[i],sa[i]=i;
for(long long i,l=0,pos=0,sig=26;pos<n-1;sig=pos)
{
for(i=n-l,pos=0;i<n;i++)p[pos++]=i;
for(i=0;i<n;i++)if(sa[i]>=l)p[pos++]=sa[i]-l;
memset(cnt,0,sizeof(cnt));
for(i=0;i<n;i++)cnt[rank[p[i]]]++;
for(i=1;i<=sig;i++)cnt[i]+=cnt[i-1];
for(i=n-1;i>=0;i--)sa[--cnt[rank[p[i]]]]=p[i];
for(temp[sa[0]]=pos=0,i=1;i<n;i++)
{
if(!equ(sa[i],sa[i-1],l))pos++;
temp[sa[i]]=pos;
}
for(i=0;i<n;i++)rank[i]=temp[i];
if(!l)l=1;else l<<=1;
}
long long i,k=0;
for(i=k=0;i<n;i++)
{
if(k)k--;
if(rank[i]==0)continue;
for(long long j=sa[rank[i]-1];str[i+k]==str[j+k];)
k++;
height[rank[i]]=k;
}
}
void init()
{
scanf("%s",s1);
scanf("%s",s2);
len1=strlen(s1);len2=strlen(s2);
for(long long i=0;i<len1;i++)s[i]=s1[i];
s[len1]='$';
for(long long i=len1+1;i<len1+len2+1;i++)s[i]=s2[i-len1-1];
n=len1+len2+1;
}
void work()
{
for(long long i=0;i<n;i++)
{
if(q1[tail1])
{
ans=0;
while(q1[tail1]>=height[i]&&tail1)
{
ans1-=w1[tail1]*q1[tail1];
ans+=w1[tail1];
tail1--;
}
q1[++tail1]=height[i];
w1[tail1]=ans;
ans1+=ans*height[i];
}
if(q2[tail2])
{
ans=0;
while(q2[tail2]>=height[i]&&tail2)
{
ans2-=w2[tail2]*q2[tail2];
ans+=w2[tail2];
tail2--;
}
q2[++tail2]=height[i];
w2[tail2]=ans;
ans2+=ans*height[i];
}
if(sa[i]<len1)
{
sum+=ans2;
if(q1[tail1]==n-sa[i])
{
w1[tail1]++;
ans1+=n-sa[i];
}
else
{
q1[++tail1]=n-sa[i];
w1[tail1]=1;
ans1+=n-sa[i];
}
}
else
{
sum+=ans1;
if(q2[tail2]==n-sa[i])
{
w2[tail2]++;
ans2+=n-sa[i];
}
else
{
q2[++tail2]=n-sa[i];
w2[tail2]=1;
ans2+=n-sa[i];
}
}
}
printf("%lld\n",sum);
}
int main()
{
//freopen("xf.in","r",stdin);
//freopen("xf.out","w",stdout);
init();
doubling();
work();
return 0;
}