题目描述
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
输入输出格式
输入格式:
两行,两个字符串s1,s2s1,s2,长度分别为n1,n2n1,n2。
1<=n1,n2<=2000001<=n1,n2<=200000,字符串中只有小写字母。
输出格式:
输出一个整数表示答案
输入输出样例
输入样例#1:
aabb
bbaa
输出样例#1:
10
分析:
因为如果有一个nn长度的相同串,必有一个长度的相同串。考虑AA中的任意开头的后缀的贡献,显然等于这个串与任意后缀的lcplcp的和。
问题就变成了,求
∑i=1len1∑j=1len2lcp(Ai,Bj)∑i=1len1∑j=1len2lcp(Ai,Bj)
其中AiAi表示从ii开始的后缀,同理。
我们设S=A+ch+BS=A+ch+B,其中chch表示某个分隔符(比如说空格)。给这个串跑出heightheight数组,加入分隔符防止BB串前面的几个字母对的影响。
对于每一对Ai,BjAi,Bj,我们让下标大的统计下标小的。对于每一个位置ii,单调增。用一个单调栈维护一下就可以了。
注意,heightiheighti影响到的范围是[0,i−1][0,i−1],不能影响到ii,注意下标问题就没事了。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cmath>
#include <cstdio>
#include <cstring>
#define LL long long
const int maxn=4e5+7;
using namespace std;
char s[maxn],s1[maxn];
int sa[maxn],x[maxn],y[maxn],c[maxn],height[maxn],sum[maxn];
int n,m,len,top;
LL ans;
struct node{
int x;
LL sum;
}sta[maxn];
void getsa()
{
int m=1000;
for (int i=0;i<=m;i++) c[i]=0;
for (int i=0;i<n;i++) x[i]=s[i];
for (int i=0;i<n;i++) c[x[i]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for (int k=1;k<=n;k<<=1)
{
int num=0;
for (int i=n-k;i<n;i++) y[num++]=i;
for (int i=0;i<n;i++) if (sa[i]>=k) y[num++]=sa[i]-k;
for (int i=0;i<=m;i++) c[i]=0;
for (int i=0;i<n;i++) c[x[i]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i],y[i]=0;
swap(x,y);
num=1;
x[sa[0]]=1;
for (int i=1;i<n;i++)
{
if ((y[sa[i]]!=y[sa[i-1]]) || (y[sa[i]+k]!=y[sa[i-1]+k]))
{
x[sa[i]]=++num;
}
else x[sa[i]]=num;
}
if (num>=n) break;
m=num;
}
for (int i=0;i<n;i++) x[sa[i]]=i;
}
void getheight()
{
int k=0;
for (int i=0;i<n;i++)
{
if (k) k--;
int j=sa[x[i]-1];
while ((i+k<n) && (j+k<n) && (s[i+k]==s[j+k])) k++;
height[x[i]]=k;
}
}
int main()
{
scanf("%s",s);
n=len=strlen(s);
scanf("%s",s1);
m=strlen(s1);
s[n]=' ';
for (int i=0;i<m;i++) s[n+i+1]=s1[i];
n=n+m+1;
getsa();
getheight();
for (int i=1;i<n;i++) sum[i]=sum[i-1]+(sa[i]<len);
sta[0]=(node){1,0},top=0;
for (int i=1;i<n;i++)
{
while (top&&(height[sta[top].x]>=height[i])) top--;
sta[++top]=(node){i,sta[top-1].sum+height[i]*((LL)sum[i-1]-(LL)sum[sta[top-1].x-1])};
if (sa[i]>=len) ans+=sta[top].sum;
}
for (int i=0;i<n;i++) sum[i]=0;
for (int i=1;i<n;i++) sum[i]=sum[i-1]+(sa[i]>=len);
sta[0]=(node){1,0},top=0;
for (int i=1;i<n;i++)
{
while (top&&(height[sta[top].x]>=height[i])) top--;
sta[++top]=(node){i,sta[top-1].sum+(LL)height[i]*((LL)sum[i-1]-(LL)sum[sta[top-1].x-1])};
if (sa[i]<len) ans+=sta[top].sum;
}
printf("%lld",ans);
}