【题意】
统计两个串中长度>=k的公共子串的数量。
【题解】
很容易想到将两个串连接起来,求后缀数组,在根据k分组
但统计是件麻烦事。。。
以下是我自己搞出来的方法:
搞一个单调的栈维护lcp,然后维护单调性。
统计的时候要注意,只有不同的串的lcp才能统计。
这个方法严格的来说不能算是O(n),但是很接近。
结合代码分析:
long long work()
{
int i,top=0,tmp;
long long ans=0,ss[3]={0,0,0};
for (i=1;i<=n;i++)
if (h[i]<kk)
{
top=ss[1]=ss[2]=ss[0]=0;
}
else
{
for (tmp=top;tmp && v[tmp]>h[i]-kk+1;tmp--)//维护单调性
{
ss[w[tmp]]+=h[i]-kk+1-v[tmp];
v[tmp]=h[i]-kk+1;
}
v[++top]=h[i]-kk+1;
if (sa[i-1]<len) w[top]=1;//注意是sa[i-1],来判断前一个串是A串还是B串
if (sa[i-1]>len) w[top]=2;
ss[w[top]]+=h[i]-kk+1;
int t;
if (sa[i]<len) t=1;//自己是A串还是B串
if (sa[i]>len) t=2;
ans+=ss[3-t];//累加
}
return ans;
}
【代码】
#include <iostream>
#include <cstring>
using namespace std;
const int maxn=200010;
char s[maxn],s2[maxn];
int sa[maxn],rk[maxn],wa[maxn],wb[maxn],w[maxn],v[maxn],h[maxn];
int n,len,kk;
int cmp(int* r,int a,int b,int l)
{
return r[a]==r[b] && r[a+l]==r[b+l];
}
void da(char* s,int* sa,int n,int m)
{
int *x=wa,*y=wb,*t,i,j,p;
for (i=0;i<m;i++) w[i]=0;
for (i=0;i<n;i++) w[x[i]=s[i]]++;
for (i=1;i<m;i++) w[i]+=w[i-1];
for (i=n-1;i>=0;i--) sa[--w[x[i]]]=i;
for (j=1,p=1;p<n;j*=2,m=p)
{
for (p=0,i=n-j;i<n;i++) y[p++]=i;
for (i=0;i<n;i++) if (sa[i]>=j) y[p++]=sa[i]-j;
for (i=0;i<n;i++) v[i]=x[y[i]];
for (i=0;i<m;i++) w[i]=0;
for (i=0;i<n;i++) w[v[i]]++;
for (i=1;i<m;i++) w[i]+=w[i-1];
for (i=n-1;i>=0;i--) sa[--w[v[i]]]=y[i];
for (t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
void calheight()
{
int i,j,k=0;
for (i=1;i<=n;i++) rk[sa[i]]=i;
for (i=0;i<n;h[rk[i++]]=k)
for (k?k--:k=0,j=sa[rk[i]-1];s[i+k]==s[j+k];k++);
}
long long work()
{
int i,top=0,tmp;
long long ans=0,ss[3]={0,0,0};
for (i=1;i<=n;i++)
if (h[i]<kk)
{
top=ss[1]=ss[2]=ss[0]=0;
}
else
{
for (tmp=top;tmp && v[tmp]>h[i]-kk+1;tmp--)//维护单调性
{
ss[w[tmp]]+=h[i]-kk+1-v[tmp];
v[tmp]=h[i]-kk+1;
}
v[++top]=h[i]-kk+1;
if (sa[i-1]<len) w[top]=1;//注意是sa[i-1],来判断前一个串是A串还是B串
if (sa[i-1]>len) w[top]=2;
ss[w[top]]+=h[i]-kk+1;
int t;
if (sa[i]<len) t=1;//自己是A串还是B串
if (sa[i]>len) t=2;
ans+=ss[3-t];//累加
}
return ans;
}
int main()
{
freopen("pin.txt","r",stdin);
freopen("pou.txt","w",stdout);
while (1)
{
cin >> kk;
if (kk==0) break;
cin >> s;
len=strlen(s);
cin >> s2;
strcat(s,"$");
strcat(s,s2);
n=strlen(s);
s[n]=0;
da(s,sa,n+1,128);
calheight();
cout << work() << endl;
}
return 0;
}