先把整个字符串的字串个数求出(详情请看罗穗蹇论文),s[i]=s[i-1]+n-sa[i]-height[i].因为对每个以i开头的前缀个数为这么多(减去了与前面重复的子串)。
求出来之后二分,注意每个字串是字符串的前缀,况且sa数组是有序的,如果这些有序的后缀所组成的子串都不够k的话,就继续找,这是一个二分的过程,然后假设剩下的个数为p,则可以从这个串的头一直到p+height[i],然后就是找开头l,从这个点开始往后找,找height大于等于这个长度的串,取左端最小值即可。
下面是AC代码:
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#define LL int
#define inf 31
#define mod 1000003
#define maxn 200100
using namespace std;
int r[maxn];
int Rank[maxn],sa[maxn],height[maxn];
int wa[maxn],wb[maxn],wv[maxn],ws[maxn];
char a[maxn],b[maxn];
int cmp(int *r,int a,int b,int le)
{
return r[a]==r[b]&&r[a+le]==r[b+le];
}
void da(int *r,int *sa,int n,int m)
{
int i,j,p,*x=wa,*y=wb,*t;
for ( i = 0; i < m; i++) ws[i]=0;
for ( i = 0; i < n; i++) ws[ x[i] = r[i] ]++;
for ( i = 1; i < m; i++) ws[i]+=ws[i-1];
for ( i = n-1; i >= 0; i--) sa[--ws[x[i]]]=i;
for ( j = 1,p=1; p <n ; j*=2,m=p)
{
for ( p = 0,i=n-j; i < n; i++) y[p++]=i;
for ( i = 0; i < n; i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for ( i = 0; i < n; i++) wv[i]=x[y[i]];
for ( i = 0; i < m; i++) ws[i]=0;
for ( i = 0; i < n; i++) ws[wv[i]]++;
for ( i = 1; i < m; i++) ws[i]+=ws[i-1];
for ( i = n-1; i >= 0 ; i--) sa[--ws[wv[i]]]=y[i];
for (t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]] = cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return ;
}
void calheight( int *r,int *sa,int n)
{
int i,j,k=0;
for ( i = 0; i <=n ; i++) Rank[sa[i]]=i;
for(i=0;i<n;height[Rank[i++]]=k)
for(k?k--:0,j=sa[Rank[i]-1];r[i+k]==r[j+k];k++);
return ;
}
struct pi
{
int min;
int xm;
int ym;
}pp[4*maxn];
void build(int le,int ri,int tot)
{
if(le==ri)
{
pp[tot].min=height[le];
pp[tot].xm=min(sa[le],sa[le-1]);
pp[tot].ym=max(sa[le],sa[le-1]);
return ;
}
int mid;
mid=(le+ri)/2;
build(le,mid,2*tot);
build(mid+1,ri,2*tot+1);
pp[tot].min=min(pp[2*tot].min,pp[2*tot+1].min);
pp[tot].xm=min(pp[2*tot].xm,pp[2*tot+1].xm);
pp[tot].ym=max(pp[2*tot].ym,pp[2*tot+1].ym);
return ;
}
int query1(int le,int ri,int tot,int ll,int rr)
{
int p,q;
p=100000000;
q=100000000;
if(le>ri)
return 0;
if(le<=ll&&ri>=rr)
{
return pp[tot].min;
}
int mid;
mid=(ll+rr)/2;
if(le<=mid)
{
p=query1(le,ri,2*tot,ll,mid);
}
if(ri>mid)
{
q=query1(le,ri,2*tot+1,mid+1,rr);
}
if(p==100000000&&q==100000000)
return 0;
return min(p,q);
}
int query2(int le,int ri,int tot,int ll,int rr)
{
int p,q;
p=100000000;
q=100000000;
if(le>ri)
return 0;
if(le<=ll&&ri>=rr)
{
return pp[tot].xm;
}
int mid;
mid=(ll+rr)/2;
if(le<=mid)
{
p=query2(le,ri,2*tot,ll,mid);
}
if(ri>mid)
{
q=query2(le,ri,2*tot+1,mid+1,rr);
}
if(p==100000000&&q==100000000)
return 0;
return min(p,q);
}
int query3(int le,int ri,int tot,int ll,int rr)
{
int p,q;
p=100000000;
q=100000000;
if(le>ri)
return 0;
if(le<=ll&&ri>=rr)
{
return pp[tot].ym;
}
int mid;
mid=(ll+rr)/2;
if(le<=mid)
{
p=query3(le,ri,2*tot,ll,mid);
}
if(ri>mid)
{
q=query3(le,ri,2*tot+1,mid+1,rr);
}
if(p==100000000&&q==100000000)
return 0;
return max(p,q);
}
bool vis[1005][1005];
bool ha[mod];
LL s[10005];
LL jie[10005];
int get_pow(int n,int m){
LL p,k;
p=1;
k=n;
while(m>0){
if(m&1){
p=p*k;
p%=mod;
}
k=k*k;
k%=mod;
m>>=1;
}
return p;
}
void init(void){
int i;
jie[0]=1;
jie[1000]=get_pow(31,mod-1000-1);
for(i=999;i>=1;i--){
jie[i]=jie[i+1]*31;
jie[i]%=mod;
}
}
int main()
{
int i,n,j,p,f,u,x,y;
int ss;
LL q,v;
init();
while(1)
{
scanf("%s",a);
if(a[0]=='#') break;
n=(int)strlen(a);
v=1;
for(i=0;i<n;i++)
{
r[i]=a[i]-'a'+1;
if(i==0) s[i]=r[i];
else{
s[i]=s[i-1]+r[i]*v;
s[i]%=mod;
}
v=v*31;
v%=mod;
}
r[n]=0;//为了使rank从1开始,防止height[rank[i]-1]越界。
da(r,sa,n+1,20002);
calheight(r,sa,n);
ss=0;
for(i=1;i<n;i++){
j=2;
while(j<=n){
x=-1;
y=1000000000;
while(j<=n&&height[j]>=i){
x=max(x,sa[j]);
x=max(x,sa[j-1]);
y=min(y,sa[j]);
y=min(y,sa[j-1]);
j++;
}
if(x!=-1&&y<=1000000&&x-y>=i){
ss++;
}
j++;
}
}
printf("%d\n",ss);
}
return 0;
}